Commit

precommit
dakinggg committed Oct 6, 2023
1 parent f10be3e commit 70e92ac
Showing 1 changed file with 31 additions and 28 deletions.
59 changes: 31 additions & 28 deletions tests/test_lion8b.py
@@ -56,11 +56,11 @@ def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype,
     W_orig = W.detach().clone()

     opt = Lion8bit([W],
-        lr=1.0,
-        _fused=fused,
-        betas=(.75, .75),
-        weight_decay=.2,
-        error_correction=use_errors)
+                   lr=1.0,
+                   _fused=fused,
+                   betas=(.75, .75),
+                   weight_decay=.2,
+                   error_correction=use_errors)

     Y = X @ W
     loss = Y.sum()
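For context, this hunk sits in a test asserting that one optimizer step both moves the weights and populates momentum state. A minimal runnable sketch of that pattern, using torch.optim.SGD with momentum as a stand-in for Lion8bit so it runs without llm-foundry installed (the sizes N and D are hypothetical; the real test parametrizes them):

```python
import torch

torch.manual_seed(123)
N, D = 16, 8  # hypothetical sizes; the real test parametrizes these
X = torch.randn(N, D)
W = torch.randn(D, D, requires_grad=True)
W_orig = W.detach().clone()

# Stand-in for Lion8bit: any optimizer with momentum state illustrates
# the assertion pattern the test uses.
opt = torch.optim.SGD([W], lr=1.0, momentum=0.9)

loss = (X @ W).sum()
loss.backward()
opt.step()

assert not torch.equal(W.detach(), W_orig)   # weights changed
assert 'momentum_buffer' in opt.state[W]     # momentum state populated
```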
@@ -99,7 +99,7 @@ def test_changes_with_zero_grads(N: int, D: int, device: str,
                                  fused: bool, use_errors: bool) -> None:
     if (device == 'cpu') and (fused or use_errors):
         return
-    
+
     error_context = contextlib.nullcontext()
     if use_errors and version.parse(
             torch.__version__) >= version.parse('2.1.0'):
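The error_context lines above gate a warning check on the installed torch version. A sketch of that gating pattern, assuming the packaging library (a torch dependency) is available; the context actually entered on torch >= 2.1 is truncated in this view, so pytest.warns(UserWarning) here is a hypothetical stand-in:

```python
import contextlib
import warnings

import pytest
import torch
from packaging import version

# No-op context by default; on newer torch the test expects a warning.
# The specific warning type is hypothetical -- the real test's choice
# is cut off in this diff view.
error_context = contextlib.nullcontext()
if version.parse(torch.__version__) >= version.parse('2.1.0'):
    error_context = pytest.warns(UserWarning)

with error_context:
    # Stands in for the optimizer step under test.
    warnings.warn('placeholder', UserWarning)
```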
@@ -112,11 +112,11 @@ def test_changes_with_zero_grads(N: int, D: int, device: str,
     W_orig = W.detach().clone()

     opt = Lion8bit([W],
-        _fused=fused,
-        betas=(.5, .5),
-        quantize=(device != 'cpu'),
-        weight_decay=weight_decay,
-        error_correction=use_errors)
+                   _fused=fused,
+                   betas=(.5, .5),
+                   quantize=(device != 'cpu'),
+                   weight_decay=weight_decay,
+                   error_correction=use_errors)

     zeros_grad = torch.zeros_like(W)
     for _ in range(5):
@@ -147,7 +147,7 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool,
                   use_errors: bool) -> None:
     if (device == 'cpu') and (fused or use_errors):
         return
-    
+
     error_context = contextlib.nullcontext()
     if use_errors and version.parse(
             torch.__version__) >= version.parse('2.1.0'):
@@ -161,11 +161,11 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool,

     # we use tiny beta1 so we move almost entirely in the gradient direction
     opt = Lion8bit([W],
-        lr=1e-2,
-        betas=(.5, .5),
-        quantize=(device != 'cpu'),
-        _fused=fused,
-        error_correction=use_errors)
+                   lr=1e-2,
+                   betas=(.5, .5),
+                   quantize=(device != 'cpu'),
+                   _fused=fused,
+                   error_correction=use_errors)

     prev_loss = np.inf
     prev_momentum = None
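This hunk's test drives the loss down over repeated steps and asserts monotone descent. A sketch of that loop's structure, with torch.optim.SGD standing in for Lion8bit so it runs anywhere; a plain linear loss makes the per-step decrease guaranteed:

```python
import numpy as np
import torch

torch.manual_seed(123)
N, D = 16, 8  # hypothetical sizes; the real test parametrizes these
X = torch.randn(N, D)
W = torch.randn(D, D, requires_grad=True)

# Stand-in for Lion8bit; the small lr mirrors the test's lr=1e-2.
opt = torch.optim.SGD([W], lr=1e-2)

prev_loss = np.inf
for _ in range(10):
    loss = (X @ W).sum()  # linear in W, so each SGD step strictly lowers it
    opt.zero_grad()
    loss.backward()
    opt.step()
    assert loss.item() < prev_loss
    prev_loss = loss.item()
```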
@@ -361,14 +361,17 @@ def test_state_dict_save_load(device: str, quantized_state: bool,
     torch.manual_seed(123)
     params = []
     for shape in _MANY_PARAM_SHAPES:
-        p = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
+        p = torch.rand(shape,
+                       device=device,
+                       dtype=dtype,
+                       requires_grad=True)
         p.grad = torch.rand_like(p)
         params.append(p)

     # create optimizer and have it step so that state gets populated
     opt = Lion8bit(params,
-        compress_state_dict=quantized_state,
-        error_correction=use_errors)
+                   compress_state_dict=quantized_state,
+                   error_correction=use_errors)
     if device == 'cpu':
         with pytest.raises(NotImplementedError):
             opt.step()
@@ -380,8 +383,8 @@ def test_state_dict_save_load(device: str, quantized_state: bool,
     # copy state dict into a new instance
     state_dict = opt.state_dict()
     opt_new = Lion8bit(params,
-        compress_state_dict=quantized_state,
-        error_correction=use_errors)
+                       compress_state_dict=quantized_state,
+                       error_correction=use_errors)
     opt_new.load_state_dict(state_dict)

     for p in params:
@@ -398,16 +401,16 @@ def test_state_dict_save_load(device: str, quantized_state: bool,
         assert torch.all(mom_orig.quantized == mom_new.quantized)
         if dtype == torch.bfloat16:
             torch.testing.assert_close(mom_orig.scales,
-                mom_new.scales,
-                atol=1e-3,
-                rtol=1e-2)
+                                       mom_new.scales,
+                                       atol=1e-3,
+                                       rtol=1e-2)
         else:
             assert torch.all(mom_orig.scales == mom_new.scales)

         torch.testing.assert_close(mom_orig.materialize(),
-            mom_new.materialize(),
-            atol=1. / (2 * 127),
-            rtol=np.inf)
+                                   mom_new.materialize(),
+                                   atol=1. / (2 * 127),
+                                   rtol=np.inf)
         if use_errors and (dtype != torch.float32):
             torch.testing.assert_close(d_orig['errors'], d_new['errors'])

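The last three hunks all belong to one state-dict round trip: step once so state exists, serialize it, load it into a fresh optimizer, and compare state tensors. A minimal sketch of that round trip with torch.optim.SGD standing in for Lion8bit (the quantized .scales / .materialize() accessors checked above are Lion8bit-specific and not modeled here; _MANY_PARAM_SHAPES is replaced by two hypothetical shapes):

```python
import torch

torch.manual_seed(123)
params = [torch.rand(shape, requires_grad=True) for shape in [(4,), (8, 8)]]
for p in params:
    p.grad = torch.rand_like(p)

# Create optimizer and step once so state gets populated.
opt = torch.optim.SGD(params, lr=0.1, momentum=0.9)
opt.step()

# Copy the state dict into a new instance, as the test does.
state_dict = opt.state_dict()
opt_new = torch.optim.SGD(params, lr=0.1, momentum=0.9)
opt_new.load_state_dict(state_dict)

# State for each param should survive the round trip exactly.
for p in params:
    torch.testing.assert_close(opt.state[p]['momentum_buffer'],
                               opt_new.state[p]['momentum_buffer'])
```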