Skip to content

Commit

Permalink
add test coverage for lion and lion8b checkpoint interop
Browse files Browse the repository at this point in the history
  • Loading branch information
dblalock committed Oct 17, 2023
1 parent aecadc9 commit 2ec5d18
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions tests/test_lion8b.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
LocalOptimStateDictConfig = MagicMock()
ShardedOptimStateDictConfig = MagicMock()

from llmfoundry.optim import DecoupledLionW
from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit

warnings.filterwarnings('ignore')
Expand Down Expand Up @@ -406,8 +407,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore
@pytest.mark.parametrize('use_errors', [False, True])
@pytest.mark.parametrize('state_sharding',
[_FULL_STATE, _SHARDED_STATE, _LOCAL_STATE])
@pytest.mark.parametrize('save_as_lion8b, load_as_lion8b', [(False, True),
(True, False),
(True, True)])
def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
state_sharding: fsdp.StateDictType):
state_sharding: fsdp.StateDictType,
save_as_lion8b: bool, load_as_lion8b: bool):
device = 'cuda'
if torch.cuda.device_count() < 2:
pytest.skip(f'This test requires 2+ GPUs.')
Expand All @@ -419,6 +424,10 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
dist.init_process_group(backend='nccl')
assert dist.get_world_size() >= 2, 'Misconfigured test run!'

# nb: this is the line that causes:
# `Warning: Deallocating Tensor that still has live PyObject references.`
# suggesting this warning isn't an issue with our test code. It's also
# going to stdout (probably from cpp) so we can't suppress it with warnings
mod = FSDP(_DummyModule(device=device, dtype=dtype))

# actual forward pass instead of setting p.grad to avoid FSDP issues
Expand All @@ -429,7 +438,10 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
p.grad = torch.rand_like(p)

# create optimizer and have it step so that state gets populated
opt = Lion8bit(mod.parameters(), error_correction=use_errors)
if save_as_lion8b:
opt = Lion8bit(mod.parameters(), error_correction=use_errors)
else:
opt = DecoupledLionW(mod.parameters())
opt.step()
opt.zero_grad()

Expand All @@ -449,13 +461,22 @@ def _set_state_dict_type(model: nn.Module):
FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg,
optim_cfg)

def _local_shard(t: torch.Tensor) -> torch.Tensor:
try: # can't operate on ShardedTensors directly
return t.local_tensor() # type: ignore
except AttributeError:
return t

# load FSDP state dict
_set_state_dict_type(mod)
opt_state_dict = FSDP.optim_state_dict(mod, opt)

# make a new model and optimizer
mod_new = FSDP(_DummyModule(device=device, dtype=dtype))
opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors)
if load_as_lion8b:
opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors)
else:
opt_new = DecoupledLionW(mod_new.parameters())
_set_state_dict_type(mod_new)

# load state dict into the new optimizer
Expand All @@ -480,22 +501,26 @@ def _set_state_dict_type(model: nn.Module):
mom_new = d_new['exp_avg']

assert mom_orig.shape == mom_new.shape
assert mom_orig.dtype == mom_new.dtype
if use_errors and (dtype != torch.float32):
errs_orig = d_orig['errors']
errs_new = d_new['errors']
assert errs_orig.shape == errs_new.shape
assert errs_orig.dtype == errs_new.dtype

if state_sharding != _FULL_STATE:
continue # more detailed checks lean on FSDP impl details
both_lion8b = save_as_lion8b and load_as_lion8b
check_errors = both_lion8b and use_errors and (dtype != torch.float32)
if both_lion8b:
assert mom_orig.dtype == mom_new.dtype
if check_errors:
errs_orig = d_orig['errors']
errs_new = d_new['errors']
assert errs_orig.shape == errs_new.shape
assert errs_orig.dtype == errs_new.dtype

# momentums may not be bit-for-bit identical because Optimizer upcasts
# to f32 and we convert back to bf16, possibly with different rounding
torch.testing.assert_close(mom_orig, mom_new)
torch.testing.assert_close(_local_shard(mom_orig).float(),
_local_shard(mom_new).float(),
atol=1e-4,
rtol=1. / 128)
# errors not bit-for-bit identical because scales get upcast too
if use_errors and (dtype != torch.float32):
torch.testing.assert_close(d_orig['errors'], d_new['errors'])
if check_errors:
torch.testing.assert_close(_local_shard(d_orig['errors']),
_local_shard(d_new['errors']))


@pytest.mark.gpu
Expand Down

0 comments on commit 2ec5d18

Please sign in to comment.