Skip to content

Commit

Permalink
[shardformer] hybridparallelplugin support gradients accumulation. (#…
Browse files Browse the repository at this point in the history
…5246)

* support gradients acc

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

* fix

fix

* fix

fix

fix
  • Loading branch information
flybird11111 authored Jan 17, 2024
1 parent 2a0558d commit 46e0916
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 8 deletions.
20 changes: 12 additions & 8 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
Returns:
None
"""

if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
if grads is not None:
# Synchronize provided gradient tensors across the tensor parallelism group.
Expand Down Expand Up @@ -487,7 +486,6 @@ def backward(self, loss: Tensor, *args, **kwargs):
Returns:
None
"""

# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)

Expand All @@ -513,7 +511,6 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
Returns:
None
"""

# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)

Expand Down Expand Up @@ -674,7 +671,6 @@ def sync_dp_grads(self):
Returns:
None
"""

# Call the superclass `_sync_grad` method to synchronize gradients.
super()._sync_grad()

Expand Down Expand Up @@ -1081,7 +1077,7 @@ def control_precision(self) -> bool:
return True

def support_no_sync(self) -> bool:
return False
return True

def control_checkpoint_io(self) -> bool:
return True
Expand Down Expand Up @@ -1175,9 +1171,14 @@ def execute_pipeline(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)

# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs

# Synchronize the grads of shared parameters of the model.
model.sync_shared_params()

# Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads()

Expand Down Expand Up @@ -1241,5 +1242,8 @@ def seed_worker(worker_id):
def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)

def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (
self.zero_stage != 2
), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
162 changes: 162 additions & 0 deletions tests/test_booster/test_plugin/test_3d_plugin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import copy
from contextlib import nullcontext
from typing import Optional

import torch
import torch.distributed as dist
from torch.testing import assert_close
from torch.utils.data import Dataset

import colossalai
from colossalai.booster import Booster
Expand All @@ -11,9 +14,33 @@
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device, set_seed
from tests.kit.model_zoo import model_zoo


class RandomDataset(Dataset):
def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: int = 32000):
self.num_samples = num_samples
self.max_length = max_length
set_seed(42)
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)

def __len__(self):
return self.num_samples

def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}


def move_to_cuda(batch):
return {k: v.cuda() for k, v in batch.items()}


@clear_cache_before_run()
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
try:
Expand Down Expand Up @@ -85,10 +112,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


@parameterize(
"test_args",
[
{
"batch_size": 8,
"num_steps": 4,
"tp": 2,
"pp": 2,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 4,
"zero": 0,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{
"batch_size": 8,
"num_steps": 4,
"tp": 1,
"pp": 2,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 4,
"zero": 1,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{
"batch_size": 1,
"num_steps": 4,
"tp": 2,
"pp": 1,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 1,
"zero": 2,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{
"batch_size": 1,
"num_steps": 4,
"tp": 2,
"pp": 1,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 1,
"zero": 0,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
],
)
def run_grad_acc_test(test_args):
model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()))
model = model_fn()
optimizer = HybridAdam(model.parameters())
origin_model = copy.deepcopy(model).cuda()
origin_optimizer = HybridAdam(origin_model.parameters())

plugin = HybridParallelPlugin(
tp_size=test_args["tp"],
pp_size=test_args["pp"],
pp_style=test_args["pp_style"],
zero_stage=test_args["zero"],
num_model_chunks=test_args["num_model_chunks"],
enable_fused_normalization=True,
num_microbatches=test_args["num_microbatches"],
precision=test_args["precision"],
)
booster = Booster(plugin=plugin)

dataset = RandomDataset(
num_samples=test_args["batch_size"] * test_args["num_steps"] * plugin.dp_size,
max_length=test_args["max_length"],
vocab_size=model.config.vocab_size,
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=test_args["batch_size"], shuffle=True, drop_last=True)

model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

grad_accu_step = test_args["gradient_accumulation_step"]
for step, batch in enumerate(dataloader):
batch = move_to_cuda(batch)
# train origin model
origin_output = origin_model(**batch)
origin_loss = origin_output[0] / grad_accu_step
origin_loss.backward()

if (step + 1) % grad_accu_step != 0 and test_args["zero"] != 2:
ctx = booster.no_sync(model, optimizer)
else:
ctx = nullcontext()

with ctx:
if plugin.stage_manager is not None:
batch = iter([batch])
booster.execute_pipeline(
batch,
model,
criterion=lambda outputs, inputs: outputs[0] / grad_accu_step,
optimizer=optimizer,
return_loss=False,
)
else:
outputs = model(**batch)
loss = outputs[0] / grad_accu_step
booster.backward(loss, optimizer)

if (step + 1) % grad_accu_step == 0:
# update origin model weight
origin_optimizer.step()
origin_optimizer.zero_grad()

# update sharded model
optimizer.step()
optimizer.zero_grad()

# tricky code here, shard the origin model inorder to check the parameters in the same stage.
origin_model, origin_optimizer, _, dataloader, _ = booster.boost(
origin_model, origin_optimizer, dataloader=dataloader
)
for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)


def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
check_3d_plugin(early_stop=early_stop)
run_grad_acc_test()


@rerun_if_address_is_in_use()
Expand Down

0 comments on commit 46e0916

Please sign in to comment.