[moe] support low level zero optim #4429

Merged 5 commits on Aug 14, 2023
Changes from 1 commit
27 changes: 24 additions & 3 deletions colossalai/zero/low_level/low_level_optim.py
@@ -131,16 +131,23 @@ def __init__(
self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
self._bucket_store = BucketStore(self.dp_pg)

# moe params should not be stored in working_groups
# because they use a different parallel strategy,
# so we need to store them separately in param_groups
# instead of working_groups
moe_params = list()

# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self.optim.param_groups):
group_params = list()
for param in param_group['params']:
# skip moe param
if hasattr(param, "moe_info"):
continue
if param.requires_grad:
# skip moe param
if hasattr(param, "moe_info"):
moe_params.append(param)
continue
group_params.append(param)

# add the working params to working_param_groups for bookkeeping
@@ -155,6 +162,15 @@ def __init__(
# managed by this data parallel rank
param_group['params'] = master_param_current_rank

# if there are moe params, store them in an additional group in optim
if len(moe_params) > 0:
param_group = dict()
for key, value in self.optim.param_groups[0].items():
if key != 'params':
param_group[key] = value
param_group['params'] = moe_params
self.optim.param_groups.append(param_group)

# initialize communication stream for
# communication-computation overlapping
if self._overlap_communication:
@@ -418,6 +434,11 @@ def step(self, closure=None):
# update the parameters
self.optim.step()

# release the moe grad
if len(self.param_groups) > len(self._working_param_groups):
for param in self.param_groups[-1]['params']:
param.grad = None

# release the grad
grad_partition_groups = []
for group_id in range(self.num_param_groups):
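Taken together, the hunks above (a) pull every parameter carrying a `moe_info` attribute out of the ZeRO working groups and collect it in one extra optimizer param group, and (b) release that group's gradients right after `optim.step()`. The sketch below mirrors the same logic in plain PyTorch as a rough illustration only; the `moe_info` tag and the helper names are hypothetical stand-ins, not part of this PR's API.

import torch

def build_moe_aware_groups(optim: torch.optim.Optimizer) -> torch.optim.Optimizer:
    # Collect params tagged with the (hypothetical) `moe_info` attribute; they
    # follow a different parallel strategy, so ZeRO should not flatten/shard them.
    moe_params = []
    for param_group in optim.param_groups:
        zero_params = []
        for param in param_group['params']:
            if param.requires_grad:
                if hasattr(param, 'moe_info'):
                    moe_params.append(param)
                    continue
                zero_params.append(param)
        param_group['params'] = zero_params    # only these stay ZeRO-managed

    if moe_params:
        # Copy hyperparameters (lr, betas, ...) from the first group and append
        # one extra, unsharded group owned directly by the inner optimizer.
        extra = {k: v for k, v in optim.param_groups[0].items() if k != 'params'}
        extra['params'] = moe_params
        optim.param_groups.append(extra)
    return optim

def release_moe_grads(optim: torch.optim.Optimizer, num_working_groups: int) -> None:
    # Mirrors the step() hunk: when the extra MoE group exists, free its grads
    # after the parameter update so they do not linger on the params.
    if len(optim.param_groups) > num_working_groups:
        for param in optim.param_groups[-1]['params']:
            param.grad = None

Appending the dict directly, rather than going through add_param_group, keeps the extra group exactly as assembled, which matches what the diff does.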
173 changes: 61 additions & 112 deletions tests/test_moe/test_moe_zero_optim.py
@@ -1,144 +1,93 @@
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.context import MOE_CONTEXT
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.nn import MoeLoss
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
from colossalai.zero.low_level._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeModel


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
if loose:
return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
return torch.allclose(tensor_a, tensor_b)
def split_ddp_grad(grad, world_size):
with torch.no_grad():
grad = grad.clone().detach().flatten()
padding_size = (world_size - grad.numel() % world_size) % world_size
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
splited_grad = grad.split(grad.numel() // world_size)
return splited_grad


def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
rank = dist.get_rank()
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
if zero_p.colo_attr.param_is_sharded:
zero_p = zero_p.colo_attr.data_payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
p = chunks[rank].float()
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
else:
zero_p = zero_p.colo_attr.data_payload.to(p.device)

assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype)
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'


def _run_step(model, optimizer, data, label, criterion, grad_handler):
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
optimizer.zero_grad()

if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()

loss = loss.float()
if isinstance(model, ShardedModelV2):
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y

if grad_handler is not None:
grad_handler.handle_gradient()

optimizer.step()


@parameterize("cpu_offload", [True])
@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug
@parameterize("reuse_fp16_shard", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload,
shard_strategy_class,
use_cpuadam,
reuse_fp16_shard,
gpu_margin_mem_ratio=0.0):
shard_strategy = shard_strategy_class()
if use_cpuadam and cpu_offload is False:
return
MOE_CONTEXT.reset_loss()
get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model')
_, train_dataloader, _, optimizer_class, _ = get_components_func()
def run_zero_optim_test(local_rank, world_size, stage=1):
criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)

with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
shard_strategy=shard_strategy,
shard_param=True):
zero_model = MoeModel(checkpoint=True)

zero_model = ShardedModelV2(zero_model,
shard_strategy,
tensor_placement_policy='cpu' if cpu_offload else 'cuda',
reuse_fp16_shard=reuse_fp16_shard)

# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))

model = MoeModel(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
model = model.cuda().float()

if use_cpuadam:
optimizer_class = CPUAdam
optim = optimizer_class(model.parameters(), lr=1e-3)
sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model,
sharded_optim,
initial_scale=2**5,
gpu_margin_mem_ratio=gpu_margin_mem_ratio)

amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
apex_grad_handler = MoeGradientHandler(model)

for i, (data, label) in enumerate(train_dataloader):
if i > 5:
break
data, label = data.cuda(), label.cuda()
_run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler)
_run_step(zero_model, sharded_optim, data, label, criterion, None)
check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
for param in model.parameters():
assert not has_inf_or_nan(param)


def _run_dist(rank, world_size, port):
zero_model = MoeModel(checkpoint=True)
zero_optimizer = torch.optim.Adam(zero_model.parameters())
plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
booster = Booster(plugin=plugin)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

torch_model = MoeModel(checkpoint=True)
for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
torch_param.data.copy_(zero_param.data)
torch_optimizer = torch.optim.Adam(torch_model.parameters())
torch_model = torch_model.cuda()
grad_handler = MoeGradientHandler(torch_model)

for _ in range(2):
data = torch.randn(16, 4).cuda() / (local_rank + 1)
label = torch.randint(0, 4, (16,)).cuda()
run_fwd_bwd(torch_model, data, label, criterion, None)
run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
grad_handler.handle_gradient()

torch_optimizer.step()
zero_optimizer.step()

for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(),
zero_model.named_parameters()):
assert torch.allclose(
torch_param.data,
zero_param.data), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"

torch_optimizer.zero_grad()
zero_optimizer.zero_grad()


def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MOE_CONTEXT.setup(seed=42)
_run_test_sharded_optim_v2()
run_zero_optim_test(rank, world_size, stage=1)
run_zero_optim_test(rank, world_size, stage=2)


# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size):
spawn(_run_dist, world_size)
spawn(run_dist, world_size)


if __name__ == '__main__':
test_moe_zero_optim(world_size=4)
test_moe_zero_optim(world_size=2)
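The rewritten test drops the legacy ShardedModelV2/ShardedOptimizerV2 path and instead boosts the MoE model with LowLevelZeroPlugin (stage 1 and stage 2), comparing its parameters against a plain torch.optim.Adam baseline whose MoE gradients are synchronized by MoeGradientHandler. The small helper split_ddp_grad pads a flattened gradient so it divides evenly across ranks; a quick standalone check of that behaviour (toy values chosen here purely for illustration) might look like:

import torch

def split_ddp_grad(grad, world_size):
    # Same helper as in the test above: flatten, right-pad with zeros to a
    # multiple of world_size, then split into equal per-rank chunks.
    with torch.no_grad():
        grad = grad.clone().detach().flatten()
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        splited_grad = grad.split(grad.numel() // world_size)
    return splited_grad

grad = torch.arange(10, dtype=torch.float32)    # 10 elements, world_size = 4
chunks = split_ddp_grad(grad, world_size=4)     # padded to 12 -> four chunks of 3
assert len(chunks) == 4
assert all(chunk.numel() == 3 for chunk in chunks)
assert chunks[-1][-1].item() == 0.0             # the padding is zero-filled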