From b24f09585c05f2caa582f0a6ab5b974189d2d8d5 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 21:03:37 +0000 Subject: [PATCH 01/11] set backend to nccl --- tests/test_lion8b.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 7d517269fc..d1648157c0 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -416,7 +416,7 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp if not dist.is_initialized(): - dist.init_process_group() + dist.init_process_group(backend='nccl') assert dist.get_world_size() >= 2, 'Misconfigured test run!' mod = FSDP(_DummyModule(device=device, dtype=dtype)) From 10aba2000273968831c7925733182318c381c33f Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 21:19:03 +0000 Subject: [PATCH 02/11] skip testing errors when not present --- tests/test_lion8b.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index d1648157c0..088a6ab909 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -481,7 +481,7 @@ def _set_state_dict_type(model: nn.Module): assert mom_orig.shape == mom_new.shape assert mom_orig.dtype == mom_new.dtype - if use_errors: + if use_errors and (dtype != torch.float32): errs_orig = d_orig['errors'] errs_new = d_new['errors'] assert errs_orig.shape == errs_new.shape From 9bc5f7fc22a0cb9b1e3e3e4390a8c79bea8a57ee Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 15:36:18 -0700 Subject: [PATCH 03/11] add explicit error for 2.1 and use errors --- llmfoundry/optim/lion8b.py | 7 +- tests/test_lion8b.py | 168 +++++++++++++++++++------------------ 2 files changed, 93 insertions(+), 82 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 806dbdbd14..7e2a03a59b 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, Tuple import torch +from packaging import version class DecoupledLionW_8bit(torch.optim.Optimizer): @@ -53,7 +54,7 @@ class DecoupledLionW_8bit(torch.optim.Optimizer): by retaining information across optimizer steps. Raises: - NotImplemenetedError - If any of `quantize`, `compress_state_dict`, + NotImplementedError - If any of `quantize`, `compress_state_dict`, or `error_correction` are `True` and either a) there is no CUDA device, or b) step() is executed on a non-CUDA parameter. """ @@ -67,6 +68,10 @@ def __init__(self, compress_state_dict: bool = False, error_correction: bool = False, _fused: bool = True): # XXX this flag is mostly for testing... 
+ if version.parse(torch.__version__) >= version.parse('2.1.0') and error_correction: + raise RuntimeError( + 'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0') + if lr < 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) if not 0.0 <= betas[0] <= 1.0: diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 088a6ab909..cfc5970663 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -4,6 +4,7 @@ import os import time import warnings +import contextlib import numpy as np import packaging.version as version @@ -414,88 +415,93 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, if version.parse(torch.__version__) < version.parse('2.0.1'): pytest.skip(f'This test requires torch 2.0.1 or greater.') - torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp - if not dist.is_initialized(): - dist.init_process_group(backend='nccl') - assert dist.get_world_size() >= 2, 'Misconfigured test run!' - - mod = FSDP(_DummyModule(device=device, dtype=dtype)) - - # actual forward pass instead of setting p.grad to avoid FSDP issues - X = torch.rand(size=(5, 4), device=device, dtype=dtype) - Y = mod(X) - Y.sum().backward() - for p in mod.parameters(): - p.grad = torch.rand_like(p) - - # create optimizer and have it step so that state gets populated - opt = Lion8bit(mod.parameters(), error_correction=use_errors) - opt.step() - opt.zero_grad() - - def _set_state_dict_type(model: nn.Module): - # for mapping between state dict types and optim state dict types, see: - # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa - state_dict_cfg = { - _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False), - _SHARDED_STATE: fsdp.ShardedStateDictConfig(), - _LOCAL_STATE: fsdp.LocalStateDictConfig(), - }[state_sharding] - optim_cfg = { - _FULL_STATE: FullOptimStateDictConfig(rank0_only=False), - _SHARDED_STATE: ShardedOptimStateDictConfig(), - _LOCAL_STATE: LocalOptimStateDictConfig(), - }[state_sharding] - FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, - optim_cfg) - - # load FSDP state dict - _set_state_dict_type(mod) - opt_state_dict = FSDP.optim_state_dict(mod, opt) - - # make a new model and optimizer - mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) - opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) - _set_state_dict_type(mod_new) - - # load state dict into the new optimizer - opt_state_dict_slice = FSDP.optim_state_dict_to_load( - optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new) - opt_new.load_state_dict(opt_state_dict_slice) - - new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new) - - orig_state = opt_state_dict['state'] - orig_param_groups = opt_state_dict['param_groups'] - new_state = new_opt_state_dict['state'] - new_param_groups = new_opt_state_dict['param_groups'] - - all_keys = set(orig_state.keys()) | set(new_state.keys()) - assert orig_param_groups == new_param_groups # works since strs, not ptrs - for k in all_keys: # keys are param paths in module as strings - d_orig = orig_state[k] - d_new = new_state[k] - assert list(d_orig.keys()) == list(d_new.keys()) - mom_orig = d_orig['exp_avg'] - mom_new = d_new['exp_avg'] + error_context = contextlib.null_context() + if use_errors and version.parse(torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises(RuntimeError, match='DecoupledLionW_8bit with error correction') + + with error_context: + 
torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp + if not dist.is_initialized(): + dist.init_process_group(backend='nccl') + assert dist.get_world_size() >= 2, 'Misconfigured test run!' + + mod = FSDP(_DummyModule(device=device, dtype=dtype)) + + # actual forward pass instead of setting p.grad to avoid FSDP issues + X = torch.rand(size=(5, 4), device=device, dtype=dtype) + Y = mod(X) + Y.sum().backward() + for p in mod.parameters(): + p.grad = torch.rand_like(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(mod.parameters(), error_correction=use_errors) + opt.step() + opt.zero_grad() - assert mom_orig.shape == mom_new.shape - assert mom_orig.dtype == mom_new.dtype - if use_errors and (dtype != torch.float32): - errs_orig = d_orig['errors'] - errs_new = d_new['errors'] - assert errs_orig.shape == errs_new.shape - assert errs_orig.dtype == errs_new.dtype - - if state_sharding != _FULL_STATE: - continue # more detailed checks lean on FSDP impl details - - # momentums may not be bit-for-bit identical because Optimizer upcasts - # to f32 and we convert back to bf16, possibly with different rounding - torch.testing.assert_close(mom_orig, mom_new) - # errors not bit-for-bit identical because scales get upcast too - if use_errors and (dtype != torch.float32): - torch.testing.assert_close(d_orig['errors'], d_new['errors']) + def _set_state_dict_type(model: nn.Module): + # for mapping between state dict types and optim state dict types, see: + # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa + state_dict_cfg = { + _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False), + _SHARDED_STATE: fsdp.ShardedStateDictConfig(), + _LOCAL_STATE: fsdp.LocalStateDictConfig(), + }[state_sharding] + optim_cfg = { + _FULL_STATE: FullOptimStateDictConfig(rank0_only=False), + _SHARDED_STATE: ShardedOptimStateDictConfig(), + _LOCAL_STATE: LocalOptimStateDictConfig(), + }[state_sharding] + FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, + optim_cfg) + + # load FSDP state dict + _set_state_dict_type(mod) + opt_state_dict = FSDP.optim_state_dict(mod, opt) + + # make a new model and optimizer + mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) + opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) + _set_state_dict_type(mod_new) + + # load state dict into the new optimizer + opt_state_dict_slice = FSDP.optim_state_dict_to_load( + optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new) + opt_new.load_state_dict(opt_state_dict_slice) + + new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new) + + orig_state = opt_state_dict['state'] + orig_param_groups = opt_state_dict['param_groups'] + new_state = new_opt_state_dict['state'] + new_param_groups = new_opt_state_dict['param_groups'] + + all_keys = set(orig_state.keys()) | set(new_state.keys()) + assert orig_param_groups == new_param_groups # works since strs, not ptrs + for k in all_keys: # keys are param paths in module as strings + d_orig = orig_state[k] + d_new = new_state[k] + assert list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + + assert mom_orig.shape == mom_new.shape + assert mom_orig.dtype == mom_new.dtype + if use_errors and (dtype != torch.float32): + errs_orig = d_orig['errors'] + errs_new = d_new['errors'] + assert errs_orig.shape == errs_new.shape + assert errs_orig.dtype == errs_new.dtype + + if 
state_sharding != _FULL_STATE: + continue # more detailed checks lean on FSDP impl details + + # momentums may not be bit-for-bit identical because Optimizer upcasts + # to f32 and we convert back to bf16, possibly with different rounding + torch.testing.assert_close(mom_orig, mom_new) + # errors not bit-for-bit identical because scales get upcast too + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) @pytest.mark.gpu From 9d211f6930e1326ab04fc3f3bb75920582bdb665 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 15:37:02 -0700 Subject: [PATCH 04/11] precommit --- llmfoundry/optim/lion8b.py | 6 ++++-- tests/test_lion8b.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 7e2a03a59b..f09d4a86c5 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -68,9 +68,11 @@ def __init__(self, compress_state_dict: bool = False, error_correction: bool = False, _fused: bool = True): # XXX this flag is mostly for testing... - if version.parse(torch.__version__) >= version.parse('2.1.0') and error_correction: + if version.parse(torch.__version__) >= version.parse( + '2.1.0') and error_correction: raise RuntimeError( - 'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0') + 'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0' + ) if lr < 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index cfc5970663..7e4d97390c 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -1,10 +1,10 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import os import time import warnings -import contextlib import numpy as np import packaging.version as version @@ -415,10 +415,12 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, if version.parse(torch.__version__) < version.parse('2.0.1'): pytest.skip(f'This test requires torch 2.0.1 or greater.') - error_context = contextlib.null_context() - if use_errors and version.parse(torch.__version__) >= version.parse('2.1.0'): - error_context = pytest.raises(RuntimeError, match='DecoupledLionW_8bit with error correction') - + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') + with error_context: torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp if not dist.is_initialized(): @@ -453,7 +455,7 @@ def _set_state_dict_type(model: nn.Module): _LOCAL_STATE: LocalOptimStateDictConfig(), }[state_sharding] FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, - optim_cfg) + optim_cfg) # load FSDP state dict _set_state_dict_type(mod) From f10be3e75a5cc2acf79c0242ae9eadb02b859eb0 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 15:59:07 -0700 Subject: [PATCH 05/11] add the error contexts --- tests/test_lion8b.py | 317 +++++++++++++++++++++++-------------------- 1 file changed, 173 insertions(+), 144 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 7e4d97390c..af8c7b00d7 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -42,40 +42,47 @@ (True, True)]) def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype, fused: bool, use_errors: bool) -> None: - device = 
'cuda' - torch.manual_seed(123) - X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) - W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) - W_orig = W.detach().clone() - - opt = Lion8bit([W], - lr=1.0, - _fused=fused, - betas=(.75, .75), - weight_decay=.2, - error_correction=use_errors) - - Y = X @ W - loss = Y.sum() - loss.backward() - torch.testing.assert_close(W_orig, W) # no weight modification yet - opt.step() - opt.zero_grad() - - with pytest.raises(AssertionError): # opt step modified the weights - torch.testing.assert_close(W_orig, W) - - # Every momentum should be nonzero with infinite precision, but - # might be zero after quantization. We turn the _MaybeQuantizedTensor - # instance into a regular torch Tensor to simplify this check. - param_state = opt.state[W] # type:ignore using tensor as key - momentum = param_state['exp_avg'].materialize() - assert momentum.shape == (D, D) - momentum = momentum.ravel() - if momentum.numel() == 1: - assert momentum.item() != 0 - else: - assert torch.std(momentum).item() > 0 + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') + + with error_context: + device = 'cuda' + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + W_orig = W.detach().clone() + + opt = Lion8bit([W], + lr=1.0, + _fused=fused, + betas=(.75, .75), + weight_decay=.2, + error_correction=use_errors) + + Y = X @ W + loss = Y.sum() + loss.backward() + torch.testing.assert_close(W_orig, W) # no weight modification yet + opt.step() + opt.zero_grad() + + with pytest.raises(AssertionError): # opt step modified the weights + torch.testing.assert_close(W_orig, W) + + # Every momentum should be nonzero with infinite precision, but + # might be zero after quantization. We turn the _MaybeQuantizedTensor + # instance into a regular torch Tensor to simplify this check. 
+ param_state = opt.state[W] # type:ignore using tensor as key + momentum = param_state['exp_avg'].materialize() + assert momentum.shape == (D, D) + momentum = momentum.ravel() + if momentum.numel() == 1: + assert momentum.item() != 0 + else: + assert torch.std(momentum).item() > 0 @pytest.mark.gpu @@ -92,33 +99,40 @@ def test_changes_with_zero_grads(N: int, D: int, device: str, fused: bool, use_errors: bool) -> None: if (device == 'cpu') and (fused or use_errors): return + + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') - torch.manual_seed(123) - W = torch.rand((D, D), device=device, requires_grad=True) - W_orig = W.detach().clone() - - opt = Lion8bit([W], - _fused=fused, - betas=(.5, .5), - quantize=(device != 'cpu'), - weight_decay=weight_decay, - error_correction=use_errors) - - zeros_grad = torch.zeros_like(W) - for _ in range(5): - W.grad = zeros_grad - opt.step() - opt.zero_grad() + with error_context: + torch.manual_seed(123) + W = torch.rand((D, D), device=device, requires_grad=True) + W_orig = W.detach().clone() + + opt = Lion8bit([W], + _fused=fused, + betas=(.5, .5), + quantize=(device != 'cpu'), + weight_decay=weight_decay, + error_correction=use_errors) + + zeros_grad = torch.zeros_like(W) + for _ in range(5): + W.grad = zeros_grad + opt.step() + opt.zero_grad() - mom = opt.state[W]['exp_avg'] # type:ignore using tensor as key - assert torch.all(mom.materialize() == 0) - if mom.is_quantized(): - assert torch.all(mom.quantized == 0) + mom = opt.state[W]['exp_avg'] # type:ignore using tensor as key + assert torch.all(mom.materialize() == 0) + if mom.is_quantized(): + assert torch.all(mom.quantized == 0) - if weight_decay: - assert torch.all(W_orig.abs() > W.abs()) - else: - torch.testing.assert_close(W_orig, W) # no weight modification + if weight_decay: + assert torch.all(W_orig.abs() > W.abs()) + else: + torch.testing.assert_close(W_orig, W) # no weight modification @pytest.mark.gpu @@ -133,43 +147,51 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, use_errors: bool) -> None: if (device == 'cpu') and (fused or use_errors): return - torch.manual_seed(123) - X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) - W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) - - # we use tiny beta1 so we move almost entirely in the gradient direction - opt = Lion8bit([W], - lr=1e-2, - betas=(.5, .5), - quantize=(device != 'cpu'), - _fused=fused, - error_correction=use_errors) - - prev_loss = np.inf - prev_momentum = None - num_iters = 10 if device == 'cuda' else 2 # keep test fast - for _ in range(num_iters): - Y = X @ W - loss = (Y * Y).mean() - loss.backward() - opt.step() - opt.zero_grad() + + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') + + with error_context: + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + + # we use tiny beta1 so we move almost entirely in the gradient direction + opt = Lion8bit([W], + lr=1e-2, + betas=(.5, .5), + quantize=(device != 'cpu'), + _fused=fused, + error_correction=use_errors) + + prev_loss = np.inf + 
prev_momentum = None + num_iters = 10 if device == 'cuda' else 2 # keep test fast + for _ in range(num_iters): + Y = X @ W + loss = (Y * Y).mean() + loss.backward() + opt.step() + opt.zero_grad() - loss_val = loss.item() - assert loss_val < prev_loss - prev_loss = loss_val + loss_val = loss.item() + assert loss_val < prev_loss + prev_loss = loss_val - # since we're getting the same batch every time and have a small - # learning rate, our gradients should point in the same direction - # at each step. Consequently, our momentum should grow each step. - state_for_param = opt.state[W] # type:ignore using tensor as key - momentum = state_for_param['exp_avg'].materialize() - assert momentum is not None and momentum.shape == W.shape - if prev_momentum is not None: - momentum_abs_changes = (momentum - prev_momentum).abs() - assert torch.all(momentum_abs_changes >= 0) - assert momentum_abs_changes.max() > 0 - prev_momentum = momentum.clone() # {gpu, f32 on cpu} write in place + # since we're getting the same batch every time and have a small + # learning rate, our gradients should point in the same direction + # at each step. Consequently, our momentum should grow each step. + state_for_param = opt.state[W] # type:ignore using tensor as key + momentum = state_for_param['exp_avg'].materialize() + assert momentum is not None and momentum.shape == W.shape + if prev_momentum is not None: + momentum_abs_changes = (momentum - prev_momentum).abs() + assert torch.all(momentum_abs_changes >= 0) + assert momentum_abs_changes.max() > 0 + prev_momentum = momentum.clone() # {gpu, f32 on cpu} write in place def _nmse(vals_true: torch.Tensor, @@ -329,58 +351,65 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, @pytest.mark.parametrize('use_errors', [False, True]) def test_state_dict_save_load(device: str, quantized_state: bool, dtype: torch.dtype, use_errors: bool): - torch.manual_seed(123) - params = [] - for shape in _MANY_PARAM_SHAPES: - p = torch.rand(shape, device=device, dtype=dtype, requires_grad=True) - p.grad = torch.rand_like(p) - params.append(p) - - # create optimizer and have it step so that state gets populated - opt = Lion8bit(params, - compress_state_dict=quantized_state, - error_correction=use_errors) - if device == 'cpu': - with pytest.raises(NotImplementedError): + error_context = contextlib.nullcontext() + if use_errors and version.parse( + torch.__version__) >= version.parse('2.1.0'): + error_context = pytest.raises( + RuntimeError, match='DecoupledLionW_8bit with error correction') + + with error_context: + torch.manual_seed(123) + params = [] + for shape in _MANY_PARAM_SHAPES: + p = torch.rand(shape, device=device, dtype=dtype, requires_grad=True) + p.grad = torch.rand_like(p) + params.append(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(params, + compress_state_dict=quantized_state, + error_correction=use_errors) + if device == 'cpu': + with pytest.raises(NotImplementedError): + opt.step() + return + else: opt.step() - return - else: - opt.step() - opt.zero_grad() - - # copy state dict into a new instance - state_dict = opt.state_dict() - opt_new = Lion8bit(params, - compress_state_dict=quantized_state, - error_correction=use_errors) - opt_new.load_state_dict(state_dict) - - for p in params: - d_orig = opt.state[p] - d_new = opt_new.state[p] - assert list(d_orig.keys()) == list(d_new.keys()) - mom_orig = d_orig['exp_avg'] - mom_new = d_new['exp_avg'] - if quantized_state: - # Optimizer load_state_dict insists on 
converting scales to - # dtype of param, which is lossy for bf16 params. - # Ideally we'd require == for everything but it's less complexity - # to just relax the bf16 test - assert torch.all(mom_orig.quantized == mom_new.quantized) - if dtype == torch.bfloat16: - torch.testing.assert_close(mom_orig.scales, - mom_new.scales, - atol=1e-3, - rtol=1e-2) - else: - assert torch.all(mom_orig.scales == mom_new.scales) - - torch.testing.assert_close(mom_orig.materialize(), - mom_new.materialize(), - atol=1. / (2 * 127), - rtol=np.inf) - if use_errors and (dtype != torch.float32): - torch.testing.assert_close(d_orig['errors'], d_new['errors']) + opt.zero_grad() + + # copy state dict into a new instance + state_dict = opt.state_dict() + opt_new = Lion8bit(params, + compress_state_dict=quantized_state, + error_correction=use_errors) + opt_new.load_state_dict(state_dict) + + for p in params: + d_orig = opt.state[p] + d_new = opt_new.state[p] + assert list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + if quantized_state: + # Optimizer load_state_dict insists on converting scales to + # dtype of param, which is lossy for bf16 params. + # Ideally we'd require == for everything but it's less complexity + # to just relax the bf16 test + assert torch.all(mom_orig.quantized == mom_new.quantized) + if dtype == torch.bfloat16: + torch.testing.assert_close(mom_orig.scales, + mom_new.scales, + atol=1e-3, + rtol=1e-2) + else: + assert torch.all(mom_orig.scales == mom_new.scales) + + torch.testing.assert_close(mom_orig.materialize(), + mom_new.materialize(), + atol=1. / (2 * 127), + rtol=np.inf) + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) class _DummyModule(nn.Module): From 70e92ac663ce224b747ff0eccacbe730c49d1d66 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 15:59:44 -0700 Subject: [PATCH 06/11] precommit --- tests/test_lion8b.py | 59 +++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index af8c7b00d7..255f1bb208 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -56,11 +56,11 @@ def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype, W_orig = W.detach().clone() opt = Lion8bit([W], - lr=1.0, - _fused=fused, - betas=(.75, .75), - weight_decay=.2, - error_correction=use_errors) + lr=1.0, + _fused=fused, + betas=(.75, .75), + weight_decay=.2, + error_correction=use_errors) Y = X @ W loss = Y.sum() @@ -99,7 +99,7 @@ def test_changes_with_zero_grads(N: int, D: int, device: str, fused: bool, use_errors: bool) -> None: if (device == 'cpu') and (fused or use_errors): return - + error_context = contextlib.nullcontext() if use_errors and version.parse( torch.__version__) >= version.parse('2.1.0'): @@ -112,11 +112,11 @@ def test_changes_with_zero_grads(N: int, D: int, device: str, W_orig = W.detach().clone() opt = Lion8bit([W], - _fused=fused, - betas=(.5, .5), - quantize=(device != 'cpu'), - weight_decay=weight_decay, - error_correction=use_errors) + _fused=fused, + betas=(.5, .5), + quantize=(device != 'cpu'), + weight_decay=weight_decay, + error_correction=use_errors) zeros_grad = torch.zeros_like(W) for _ in range(5): @@ -147,7 +147,7 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, use_errors: bool) -> None: if (device == 'cpu') and (fused or use_errors): return - + error_context = contextlib.nullcontext() if 
use_errors and version.parse( torch.__version__) >= version.parse('2.1.0'): @@ -161,11 +161,11 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, # we use tiny beta1 so we move almost entirely in the gradient direction opt = Lion8bit([W], - lr=1e-2, - betas=(.5, .5), - quantize=(device != 'cpu'), - _fused=fused, - error_correction=use_errors) + lr=1e-2, + betas=(.5, .5), + quantize=(device != 'cpu'), + _fused=fused, + error_correction=use_errors) prev_loss = np.inf prev_momentum = None @@ -361,14 +361,17 @@ def test_state_dict_save_load(device: str, quantized_state: bool, torch.manual_seed(123) params = [] for shape in _MANY_PARAM_SHAPES: - p = torch.rand(shape, device=device, dtype=dtype, requires_grad=True) + p = torch.rand(shape, + device=device, + dtype=dtype, + requires_grad=True) p.grad = torch.rand_like(p) params.append(p) # create optimizer and have it step so that state gets populated opt = Lion8bit(params, - compress_state_dict=quantized_state, - error_correction=use_errors) + compress_state_dict=quantized_state, + error_correction=use_errors) if device == 'cpu': with pytest.raises(NotImplementedError): opt.step() @@ -380,8 +383,8 @@ def test_state_dict_save_load(device: str, quantized_state: bool, # copy state dict into a new instance state_dict = opt.state_dict() opt_new = Lion8bit(params, - compress_state_dict=quantized_state, - error_correction=use_errors) + compress_state_dict=quantized_state, + error_correction=use_errors) opt_new.load_state_dict(state_dict) for p in params: @@ -398,16 +401,16 @@ def test_state_dict_save_load(device: str, quantized_state: bool, assert torch.all(mom_orig.quantized == mom_new.quantized) if dtype == torch.bfloat16: torch.testing.assert_close(mom_orig.scales, - mom_new.scales, - atol=1e-3, - rtol=1e-2) + mom_new.scales, + atol=1e-3, + rtol=1e-2) else: assert torch.all(mom_orig.scales == mom_new.scales) torch.testing.assert_close(mom_orig.materialize(), - mom_new.materialize(), - atol=1. / (2 * 127), - rtol=np.inf) + mom_new.materialize(), + atol=1. 
/ (2 * 127), + rtol=np.inf) if use_errors and (dtype != torch.float32): torch.testing.assert_close(d_orig['errors'], d_new['errors']) From 42499ab2c66873cd272b3df290327c979e51b13a Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 16:27:46 -0700 Subject: [PATCH 07/11] more skips --- tests/test_lion8b.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 255f1bb208..4ce5de7138 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -255,11 +255,16 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, opt_uq = Lion8bit([W_uq], quantize=False, **kwargs) opt_uf = Lion8bit([W_uf], _fused=False, **kwargs) opt_fq = Lion8bit([W_fq], _fused=True, **kwargs) - opt_fqe = Lion8bit([W_fqe], _fused=True, error_correction=True, **kwargs) opt_sgd = torch.optim.SGD([W_sgd], lr=lr) - W_list = [W_true, W_uq, W_uf, W_fq, W_fqe, W_sgd] - opt_list = [opt_true, opt_uq, opt_uf, opt_fq, opt_fqe, opt_sgd] + W_list = [W_true, W_uq, W_uf, W_fq, W_sgd] + opt_list = [opt_true, opt_uq, opt_uf, opt_fq, opt_sgd] + + # error correction not supported on torch 2.1 + if version.parse(torch.__version__) < version.parse('2.1.0'): + opt_fqe = Lion8bit([W_fqe], _fused=True, error_correction=True, **kwargs) + W_list.append(W_fqe) + opt_list.append(opt_fqe) if grad_strategy == 'zero': grads = torch.zeros_like(W0) @@ -555,7 +560,13 @@ def _time_kernels(N: int, D: int, min_elems_traversed: int): times = {} kwargs = {'weight_decay': .01} + combos = [(True, False), (True, True), (False, False), ('NA', False)] + # use_errors not currently supported on torch 2.1 + if version.parse( + torch.__version__) >= version.parse('2.1.0'): + combos = [(True, False), (False, False), ('NA', False)] + for fused, use_errors in combos: if fused == 'NA': opt = Lion8bit( From 9cd5ec465bdd5c7a106a7d060e4bc3d73c50701b Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 16:33:22 -0700 Subject: [PATCH 08/11] precommit --- tests/test_lion8b.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 4ce5de7138..312480f61a 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -259,10 +259,13 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, W_list = [W_true, W_uq, W_uf, W_fq, W_sgd] opt_list = [opt_true, opt_uq, opt_uf, opt_fq, opt_sgd] - + # error correction not supported on torch 2.1 if version.parse(torch.__version__) < version.parse('2.1.0'): - opt_fqe = Lion8bit([W_fqe], _fused=True, error_correction=True, **kwargs) + opt_fqe = Lion8bit([W_fqe], + _fused=True, + error_correction=True, + **kwargs) W_list.append(W_fqe) opt_list.append(opt_fqe) @@ -563,8 +566,7 @@ def _time_kernels(N: int, D: int, min_elems_traversed: int): combos = [(True, False), (True, True), (False, False), ('NA', False)] # use_errors not currently supported on torch 2.1 - if version.parse( - torch.__version__) >= version.parse('2.1.0'): + if version.parse(torch.__version__) >= version.parse('2.1.0'): combos = [(True, False), (False, False), ('NA', False)] for fused, use_errors in combos: From 82ae8246a8e87e0e915ba4353d7db3c7ab6dd5a4 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 23:51:16 +0000 Subject: [PATCH 09/11] more test skipping --- tests/test_lion8b.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 
312480f61a..018ce5f7cd 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -331,13 +331,15 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, # at all; latter is "ground truth" assert cossim(diffs_true, diffs_fq, dim=-1) > min_cossim assert _nmse(diffs_true, diffs_fq) < max_nmse + + # error correction not supported on torch 2.1 + if version.parse(torch.__version__) < version.parse('2.1.0'): + # fused impl with errors should also be close to "true" updates; + assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_fqe) < max_nmse - # fused impl with errors should also be close to "true" updates; - assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim - assert _nmse(diffs_true, diffs_fqe) < max_nmse - - # error correction should reduce error, or at least do no worse - assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) + # error correction should reduce error, or at least do no worse + assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) # if sgd weights aren't different than LION weights, we haven't # changed them enough to meaningfully test the LION logic @@ -601,12 +603,10 @@ def _time_kernels(N: int, D: int, min_elems_traversed: int): try: assert times[True] < times[False] + atol assert times[True] < times['NA'] + atol - assert times['ecc'] < times['NA'] + atol - print('') - print('time fused (ms): ', times[True] * 1e3) - print('time fused+ecc (ms): ', times['ecc'] * 1e3) - print('time unfused (ms): ', times[False] * 1e3) - print('time unquantized (ms): ', times['NA'] * 1e3) + + # error correction not supported on torch 2.1 + if version.parse(torch.__version__) < version.parse('2.1.0'): + assert times['ecc'] < times['NA'] + atol break except AssertionError as e: if it >= 2: # allow 3 retries to avoid flakiness From 42244ad54171635001bab9e72bd73543eb285b03 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 16:51:50 -0700 Subject: [PATCH 10/11] precommit --- tests/test_lion8b.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 018ce5f7cd..35368be593 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -331,7 +331,7 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, # at all; latter is "ground truth" assert cossim(diffs_true, diffs_fq, dim=-1) > min_cossim assert _nmse(diffs_true, diffs_fq) < max_nmse - + # error correction not supported on torch 2.1 if version.parse(torch.__version__) < version.parse('2.1.0'): # fused impl with errors should also be close to "true" updates; From 3c43bad8d1b2bbe8b3cc78e3dbff1160d56e9560 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 6 Oct 2023 17:49:06 -0700 Subject: [PATCH 11/11] switch to cu118 image --- .github/workflows/pr-gpu.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index f0650f6179..769b345e39 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -24,7 +24,7 @@ jobs: markers: 'gpu' pytest_command: 'coverage run -m pytest' - name: 'gpu-2.0.1' - container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04 + container: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04 markers: 'gpu' pytest_command: 'coverage run -m pytest' - name: 'gpu-2.1.0'
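
The patches above all apply the same guard: on PyTorch >= 2.1.0 the 8-bit Lion optimizer refuses to run with error correction, so each affected test either expects a RuntimeError or drops the error-correction variant. A minimal standalone sketch of that pattern follows; it is illustrative only and not part of the patch series (the helper name expected_error_context does not appear in the diffs), assuming torch, pytest, and packaging are available.

import contextlib

import pytest
import torch
from packaging import version


def expected_error_context(use_errors: bool):
    # On torch >= 2.1.0, constructing DecoupledLionW_8bit with
    # error_correction=True raises, so tests that exercise error
    # correction expect that RuntimeError (match string taken from
    # the guard added in patch 03).
    if use_errors and version.parse(
            torch.__version__) >= version.parse('2.1.0'):
        return pytest.raises(
            RuntimeError, match='DecoupledLionW_8bit with error correction')
    # Otherwise the test body runs normally under a no-op context.
    return contextlib.nullcontext()


# Usage mirroring the patched tests (sketch only):
#     with expected_error_context(use_errors=True):
#         opt = DecoupledLionW_8bit(model.parameters(), error_correction=True)
#         opt.step()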