diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index ddb70e882b..0c7010ce9f 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -24,6 +24,7 @@ LocalOptimStateDictConfig = MagicMock() ShardedOptimStateDictConfig = MagicMock() +from llmfoundry.optim import DecoupledLionW from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit warnings.filterwarnings('ignore') @@ -406,8 +407,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore @pytest.mark.parametrize('use_errors', [False, True]) @pytest.mark.parametrize('state_sharding', [_FULL_STATE, _SHARDED_STATE, _LOCAL_STATE]) +@pytest.mark.parametrize('save_as_lion8b, load_as_lion8b', [(False, True), + (True, False), + (True, True)]) def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, - state_sharding: fsdp.StateDictType): + state_sharding: fsdp.StateDictType, + save_as_lion8b: bool, load_as_lion8b: bool): device = 'cuda' if torch.cuda.device_count() < 2: pytest.skip(f'This test requires 2+ GPUs.') @@ -419,6 +424,10 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, dist.init_process_group(backend='nccl') assert dist.get_world_size() >= 2, 'Misconfigured test run!' + # nb: this is the line that causes: + # `Warning: Deallocating Tensor that still has live PyObject references.` + # suggesting this warning isn't an issue with our test code. It's also + # going to stdout (probably from cpp) so we can't suppress it with warnings mod = FSDP(_DummyModule(device=device, dtype=dtype)) # actual forward pass instead of setting p.grad to avoid FSDP issues @@ -429,7 +438,10 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, p.grad = torch.rand_like(p) # create optimizer and have it step so that state gets populated - opt = Lion8bit(mod.parameters(), error_correction=use_errors) + if save_as_lion8b: + opt = Lion8bit(mod.parameters(), error_correction=use_errors) + else: + opt = DecoupledLionW(mod.parameters()) opt.step() opt.zero_grad() @@ -449,13 +461,22 @@ def _set_state_dict_type(model: nn.Module): FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, optim_cfg) + def _local_shard(t: torch.Tensor) -> torch.Tensor: + try: # can't operate on ShardedTensors directly + return t.local_tensor() # type: ignore + except AttributeError: + return t + # load FSDP state dict _set_state_dict_type(mod) opt_state_dict = FSDP.optim_state_dict(mod, opt) # make a new model and optimizer mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) - opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) + if load_as_lion8b: + opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) + else: + opt_new = DecoupledLionW(mod_new.parameters()) _set_state_dict_type(mod_new) # load state dict into the new optimizer @@ -480,22 +501,26 @@ def _set_state_dict_type(model: nn.Module): mom_new = d_new['exp_avg'] assert mom_orig.shape == mom_new.shape - assert mom_orig.dtype == mom_new.dtype - if use_errors and (dtype != torch.float32): - errs_orig = d_orig['errors'] - errs_new = d_new['errors'] - assert errs_orig.shape == errs_new.shape - assert errs_orig.dtype == errs_new.dtype - - if state_sharding != _FULL_STATE: - continue # more detailed checks lean on FSDP impl details + both_lion8b = save_as_lion8b and load_as_lion8b + check_errors = both_lion8b and use_errors and (dtype != torch.float32) + if both_lion8b: + assert mom_orig.dtype == mom_new.dtype + if check_errors: + errs_orig = d_orig['errors'] + errs_new = d_new['errors'] + assert errs_orig.shape == errs_new.shape + assert errs_orig.dtype == errs_new.dtype # momentums may not be bit-for-bit identical because Optimizer upcasts # to f32 and we convert back to bf16, possibly with different rounding - torch.testing.assert_close(mom_orig, mom_new) + torch.testing.assert_close(_local_shard(mom_orig).float(), + _local_shard(mom_new).float(), + atol=1e-4, + rtol=1. / 128) # errors not bit-for-bit identical because scales get upcast too - if use_errors and (dtype != torch.float32): - torch.testing.assert_close(d_orig['errors'], d_new['errors']) + if check_errors: + torch.testing.assert_close(_local_shard(d_orig['errors']), + _local_shard(d_new['errors'])) @pytest.mark.gpu