Fixes to lion8b test for torch 2.1 #649

Merged (11 commits) on Oct 7, 2023
Changes from 4 commits
llmfoundry/optim/lion8b.py (9 changes: 8 additions & 1 deletion)

@@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, Iterable, Optional, Tuple

import torch
+from packaging import version


class DecoupledLionW_8bit(torch.optim.Optimizer):
@@ -53,7 +54,7 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
by retaining information across optimizer steps.

Raises:
-        NotImplemenetedError - If any of `quantize`, `compress_state_dict`,
+        NotImplementedError - If any of `quantize`, `compress_state_dict`,
or `error_correction` are `True` and either a) there is no CUDA
device, or b) step() is executed on a non-CUDA parameter.
"""
@@ -67,6 +68,12 @@ def __init__(self,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True): # XXX this flag is mostly for testing...
+        if version.parse(torch.__version__) >= version.parse(
+                '2.1.0') and error_correction:
+            raise RuntimeError(
+                'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0'
+            )
+
if lr < 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= betas[0] <= 1.0:
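
Note on the new guard above: constructing DecoupledLionW_8bit with error_correction=True now raises a RuntimeError on PyTorch 2.1.0 or newer. A minimal sketch of how calling code might check the version up front and fall back to plain 8-bit state instead; the helper name make_lion8b and its fallback policy are illustrative, not part of this PR:

import torch
from packaging import version

from llmfoundry.optim.lion8b import DecoupledLionW_8bit


def make_lion8b(params, lr: float = 1e-4, want_error_correction: bool = True):
    # Error correction is only supported on torch < 2.1.0 after this change,
    # so request it only when the running torch version allows it.
    use_ec = want_error_correction and version.parse(
        torch.__version__) < version.parse('2.1.0')
    return DecoupledLionW_8bit(params, lr=lr, error_correction=use_ec)
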
tests/test_lion8b.py (170 changes: 89 additions & 81 deletions)

@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

+import contextlib
import os
import time
import warnings
@@ -414,88 +415,95 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
if version.parse(torch.__version__) < version.parse('2.0.1'):
pytest.skip(f'This test requires torch 2.0.1 or greater.')

-    torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp
-    if not dist.is_initialized():
-        dist.init_process_group()
-    assert dist.get_world_size() >= 2, 'Misconfigured test run!'
-
-    mod = FSDP(_DummyModule(device=device, dtype=dtype))
-
-    # actual forward pass instead of setting p.grad to avoid FSDP issues
-    X = torch.rand(size=(5, 4), device=device, dtype=dtype)
-    Y = mod(X)
-    Y.sum().backward()
-    for p in mod.parameters():
-        p.grad = torch.rand_like(p)
-
-    # create optimizer and have it step so that state gets populated
-    opt = Lion8bit(mod.parameters(), error_correction=use_errors)
-    opt.step()
-    opt.zero_grad()
-
-    def _set_state_dict_type(model: nn.Module):
-        # for mapping between state dict types and optim state dict types, see:
-        # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa
-        state_dict_cfg = {
-            _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False),
-            _SHARDED_STATE: fsdp.ShardedStateDictConfig(),
-            _LOCAL_STATE: fsdp.LocalStateDictConfig(),
-        }[state_sharding]
-        optim_cfg = {
-            _FULL_STATE: FullOptimStateDictConfig(rank0_only=False),
-            _SHARDED_STATE: ShardedOptimStateDictConfig(),
-            _LOCAL_STATE: LocalOptimStateDictConfig(),
-        }[state_sharding]
-        FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg,
-                                 optim_cfg)
-
-    # load FSDP state dict
-    _set_state_dict_type(mod)
-    opt_state_dict = FSDP.optim_state_dict(mod, opt)
-
-    # make a new model and optimizer
-    mod_new = FSDP(_DummyModule(device=device, dtype=dtype))
-    opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors)
-    _set_state_dict_type(mod_new)
-
-    # load state dict into the new optimizer
-    opt_state_dict_slice = FSDP.optim_state_dict_to_load(
-        optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new)
-    opt_new.load_state_dict(opt_state_dict_slice)
-
-    new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new)
-
-    orig_state = opt_state_dict['state']
-    orig_param_groups = opt_state_dict['param_groups']
-    new_state = new_opt_state_dict['state']
-    new_param_groups = new_opt_state_dict['param_groups']
-
-    all_keys = set(orig_state.keys()) | set(new_state.keys())
-    assert orig_param_groups == new_param_groups # works since strs, not ptrs
-    for k in all_keys: # keys are param paths in module as strings
-        d_orig = orig_state[k]
-        d_new = new_state[k]
-        assert list(d_orig.keys()) == list(d_new.keys())
-        mom_orig = d_orig['exp_avg']
-        mom_new = d_new['exp_avg']
+    error_context = contextlib.nullcontext()
+    if use_errors and version.parse(
+            torch.__version__) >= version.parse('2.1.0'):
+        error_context = pytest.raises(
+            RuntimeError, match='DecoupledLionW_8bit with error correction')
+
+    with error_context:
+        torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp
+        if not dist.is_initialized():
+            dist.init_process_group(backend='nccl')
+        assert dist.get_world_size() >= 2, 'Misconfigured test run!'
+
+        mod = FSDP(_DummyModule(device=device, dtype=dtype))
+
+        # actual forward pass instead of setting p.grad to avoid FSDP issues
+        X = torch.rand(size=(5, 4), device=device, dtype=dtype)
+        Y = mod(X)
+        Y.sum().backward()
+        for p in mod.parameters():
+            p.grad = torch.rand_like(p)
+
+        # create optimizer and have it step so that state gets populated
+        opt = Lion8bit(mod.parameters(), error_correction=use_errors)
+        opt.step()
+        opt.zero_grad()

-        assert mom_orig.shape == mom_new.shape
-        assert mom_orig.dtype == mom_new.dtype
-        if use_errors:
-            errs_orig = d_orig['errors']
-            errs_new = d_new['errors']
-            assert errs_orig.shape == errs_new.shape
-            assert errs_orig.dtype == errs_new.dtype
-
-        if state_sharding != _FULL_STATE:
-            continue # more detailed checks lean on FSDP impl details
-
-        # momentums may not be bit-for-bit identical because Optimizer upcasts
-        # to f32 and we convert back to bf16, possibly with different rounding
-        torch.testing.assert_close(mom_orig, mom_new)
-        # errors not bit-for-bit identical because scales get upcast too
-        if use_errors and (dtype != torch.float32):
-            torch.testing.assert_close(d_orig['errors'], d_new['errors'])
+        def _set_state_dict_type(model: nn.Module):
+            # for mapping between state dict types and optim state dict types, see:
+            # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa
+            state_dict_cfg = {
+                _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False),
+                _SHARDED_STATE: fsdp.ShardedStateDictConfig(),
+                _LOCAL_STATE: fsdp.LocalStateDictConfig(),
+            }[state_sharding]
+            optim_cfg = {
+                _FULL_STATE: FullOptimStateDictConfig(rank0_only=False),
+                _SHARDED_STATE: ShardedOptimStateDictConfig(),
+                _LOCAL_STATE: LocalOptimStateDictConfig(),
+            }[state_sharding]
+            FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg,
+                                     optim_cfg)
+
+        # load FSDP state dict
+        _set_state_dict_type(mod)
+        opt_state_dict = FSDP.optim_state_dict(mod, opt)
+
+        # make a new model and optimizer
+        mod_new = FSDP(_DummyModule(device=device, dtype=dtype))
+        opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors)
+        _set_state_dict_type(mod_new)
+
+        # load state dict into the new optimizer
+        opt_state_dict_slice = FSDP.optim_state_dict_to_load(
+            optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new)
+        opt_new.load_state_dict(opt_state_dict_slice)
+
+        new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new)
+
+        orig_state = opt_state_dict['state']
+        orig_param_groups = opt_state_dict['param_groups']
+        new_state = new_opt_state_dict['state']
+        new_param_groups = new_opt_state_dict['param_groups']
+
+        all_keys = set(orig_state.keys()) | set(new_state.keys())
+        assert orig_param_groups == new_param_groups # works since strs, not ptrs
+        for k in all_keys: # keys are param paths in module as strings
+            d_orig = orig_state[k]
+            d_new = new_state[k]
+            assert list(d_orig.keys()) == list(d_new.keys())
+            mom_orig = d_orig['exp_avg']
+            mom_new = d_new['exp_avg']
+
+            assert mom_orig.shape == mom_new.shape
+            assert mom_orig.dtype == mom_new.dtype
+            if use_errors and (dtype != torch.float32):
+                errs_orig = d_orig['errors']
+                errs_new = d_new['errors']
+                assert errs_orig.shape == errs_new.shape
+                assert errs_orig.dtype == errs_new.dtype
+
+            if state_sharding != _FULL_STATE:
+                continue # more detailed checks lean on FSDP impl details
+
+            # momentums may not be bit-for-bit identical because Optimizer upcasts
+            # to f32 and we convert back to bf16, possibly with different rounding
+            torch.testing.assert_close(mom_orig, mom_new)
+            # errors not bit-for-bit identical because scales get upcast too
+            if use_errors and (dtype != torch.float32):
+                torch.testing.assert_close(d_orig['errors'], d_new['errors'])


@pytest.mark.gpu
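
The reworked test above wraps its whole body in a context manager that is either a no-op (contextlib.nullcontext()) or a pytest.raises expectation, depending on the torch version and the use_errors parameter, so one parametrized test covers both the "runs cleanly" and "raises" paths. A stripped-down sketch of the same pattern, with a hypothetical might_fail() standing in for the optimizer construction and FSDP setup:

import contextlib

import pytest
import torch
from packaging import version


def might_fail(use_errors: bool) -> None:
    # Stand-in for building DecoupledLionW_8bit with error_correction=True,
    # which raises a RuntimeError on torch >= 2.1.0 after this PR.
    if use_errors and version.parse(
            torch.__version__) >= version.parse('2.1.0'):
        raise RuntimeError(
            'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0')


@pytest.mark.parametrize('use_errors', [False, True])
def test_conditional_expectation(use_errors: bool):
    error_context = contextlib.nullcontext()
    if use_errors and version.parse(
            torch.__version__) >= version.parse('2.1.0'):
        # On newer torch the body must raise; otherwise it must run cleanly.
        error_context = pytest.raises(RuntimeError, match='error correction')

    with error_context:
        might_fail(use_errors)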