diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py
index af8c7b00d7..255f1bb208 100644
--- a/tests/test_lion8b.py
+++ b/tests/test_lion8b.py
@@ -56,11 +56,11 @@ def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype,
     W_orig = W.detach().clone()
 
     opt = Lion8bit([W],
-        lr=1.0,
-        _fused=fused,
-        betas=(.75, .75),
-        weight_decay=.2,
-        error_correction=use_errors)
+                   lr=1.0,
+                   _fused=fused,
+                   betas=(.75, .75),
+                   weight_decay=.2,
+                   error_correction=use_errors)
 
     Y = X @ W
     loss = Y.sum()
@@ -99,7 +99,7 @@ def test_changes_with_zero_grads(N: int, D: int, device: str,
                                  fused: bool, use_errors: bool) -> None:
     if (device == 'cpu') and (fused or use_errors):
         return
-
+
     error_context = contextlib.nullcontext()
     if use_errors and version.parse(
             torch.__version__) >= version.parse('2.1.0'):
@@ -112,11 +112,11 @@ def test_changes_with_zero_grads(N: int, D: int, device: str,
     W_orig = W.detach().clone()
 
     opt = Lion8bit([W],
-        _fused=fused,
-        betas=(.5, .5),
-        quantize=(device != 'cpu'),
-        weight_decay=weight_decay,
-        error_correction=use_errors)
+                   _fused=fused,
+                   betas=(.5, .5),
+                   quantize=(device != 'cpu'),
+                   weight_decay=weight_decay,
+                   error_correction=use_errors)
 
     zeros_grad = torch.zeros_like(W)
     for _ in range(5):
@@ -147,7 +147,7 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool,
                   use_errors: bool) -> None:
     if (device == 'cpu') and (fused or use_errors):
         return
-
+
     error_context = contextlib.nullcontext()
     if use_errors and version.parse(
             torch.__version__) >= version.parse('2.1.0'):
@@ -161,11 +161,11 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool,
 
     # we use tiny beta1 so we move almost entirely in the gradient direction
     opt = Lion8bit([W],
-        lr=1e-2,
-        betas=(.5, .5),
-        quantize=(device != 'cpu'),
-        _fused=fused,
-        error_correction=use_errors)
+                   lr=1e-2,
+                   betas=(.5, .5),
+                   quantize=(device != 'cpu'),
+                   _fused=fused,
+                   error_correction=use_errors)
 
     prev_loss = np.inf
     prev_momentum = None
@@ -361,14 +361,17 @@ def test_state_dict_save_load(device: str, quantized_state: bool,
     torch.manual_seed(123)
     params = []
     for shape in _MANY_PARAM_SHAPES:
-        p = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
+        p = torch.rand(shape,
+                       device=device,
+                       dtype=dtype,
+                       requires_grad=True)
         p.grad = torch.rand_like(p)
         params.append(p)
 
     # create optimizer and have it step so that state gets populated
     opt = Lion8bit(params,
-        compress_state_dict=quantized_state,
-        error_correction=use_errors)
+                   compress_state_dict=quantized_state,
+                   error_correction=use_errors)
     if device == 'cpu':
         with pytest.raises(NotImplementedError):
             opt.step()
@@ -380,8 +383,8 @@ def test_state_dict_save_load(device: str, quantized_state: bool,
     # copy state dict into a new instance
     state_dict = opt.state_dict()
     opt_new = Lion8bit(params,
-        compress_state_dict=quantized_state,
-        error_correction=use_errors)
+                       compress_state_dict=quantized_state,
+                       error_correction=use_errors)
     opt_new.load_state_dict(state_dict)
 
     for p in params:
@@ -398,16 +401,16 @@ def test_state_dict_save_load(device: str, quantized_state: bool,
             assert torch.all(mom_orig.quantized == mom_new.quantized)
             if dtype == torch.bfloat16:
                 torch.testing.assert_close(mom_orig.scales,
-                    mom_new.scales,
-                    atol=1e-3,
-                    rtol=1e-2)
+                                           mom_new.scales,
+                                           atol=1e-3,
+                                           rtol=1e-2)
             else:
                 assert torch.all(mom_orig.scales == mom_new.scales)
 
         torch.testing.assert_close(mom_orig.materialize(),
-            mom_new.materialize(),
-            atol=1. / (2 * 127),
-            rtol=np.inf)
+                                   mom_new.materialize(),
+                                   atol=1. / (2 * 127),
+                                   rtol=np.inf)
 
     if use_errors and (dtype != torch.float32):
         torch.testing.assert_close(d_orig['errors'], d_new['errors'])
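
As a side note, the `atol=1. / (2 * 127)` in the last hunk is half of one int8 quantization step on a [-1, 1] range, i.e. the worst-case rounding error of a symmetric 127-level quantizer. A minimal sketch of that bound, assuming an absmax int8 scheme (`roundtrip_int8` is a hypothetical helper for illustration, not the optimizer's internal API):

```python
import torch

def roundtrip_int8(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical symmetric absmax quantization: scale so the largest
    # magnitude maps to 127, round to int8, then rescale back.
    scale = x.abs().max() / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q.to(x.dtype) * scale

x = torch.rand(1024) * 2 - 1  # values in [-1, 1], so scale <= 1/127
err = (roundtrip_int8(x) - x).abs().max()
# Rounding moves each value by at most half a step, i.e. scale / 2,
# which for unit-scale data is 1 / (2 * 127) -- the atol in the test.
assert err <= 1. / (2 * 127) + 1e-6
```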