[shardformer] support pipeline and hybrid parallelism #4441

Merged: 80 commits, Aug 15, 2023

Commits
1c2f032
[cluster] add process group mesh (#4039)
ver217 Jun 20, 2023
d71b8e4
[pipeline] add stage manager (#4093)
ver217 Jun 27, 2023
a636ea4
[pipeline] implement p2p communication (#4100)
ver217 Jun 28, 2023
aa690f9
[pipeline] refactor 1f1b schedule (#4115)
ver217 Jun 29, 2023
3c2fbe2
[pipeline]add pipeline policy and bert forward (#4130)
CjhHa1 Jul 4, 2023
b902be2
[pipeline] add stage manager (#4093)
ver217 Jun 27, 2023
35213bc
[pipeline]add pipeline policy and bert forward (#4130)
CjhHa1 Jul 4, 2023
c80f8cd
[pipeline] build bloom model and policy , revise the base class of po…
CjhHa1 Jul 5, 2023
223de46
[pipeline] update shardformer policy
ver217 Jul 5, 2023
318b5e8
[pipeline] update shardformer docstring
ver217 Jul 5, 2023
71470a2
[test] update shardformer tests
ver217 Jul 5, 2023
fb4b18d
[test] add shard util tests
ver217 Jul 5, 2023
d6101a8
[shardformer] rename policy file name
ver217 Jul 5, 2023
951bde4
[shardformer] fix type hint
ver217 Jul 5, 2023
c9511ad
[pipeline] add bert_for_pretraining bert_lmhead forward and policy (#…
CjhHa1 Jul 6, 2023
6f47636
[pipeline] move bert related pipeline components to shardformer (#4187)
CjhHa1 Jul 7, 2023
b66870a
[shardformer] support lazy init (#4202)
ver217 Jul 10, 2023
612a5d1
[pipeline] Bert pipeline for shardformer and its tests (#4197)
CjhHa1 Jul 10, 2023
fa4515f
[pipeline] Llama pipeline (#4205)
CjhHa1 Jul 11, 2023
8161405
[pipeline] Llama causal lm and llama for sequence classification pipe…
CjhHa1 Jul 11, 2023
701257a
[pipeline] add bloom model pipeline (#4210)
CjhHa1 Jul 13, 2023
6c9b161
[pipeline] Add Pipeline Forward for GPT2Model Shardformer (#4224)
Fridge003 Jul 13, 2023
b7673a1
[shardformer] fix base policy (#4229)
ver217 Jul 14, 2023
59cbbe0
[pipeline] add pipeline forward for variants of gpt2 (#4238)
Fridge003 Jul 17, 2023
f269575
[pipeline] All bert models (#4233)
CjhHa1 Jul 17, 2023
db33d4a
[pipeline] finish bloom models pipeline and tests (#4223)
CjhHa1 Jul 17, 2023
1fcafe9
[bugs] hot fix some testing bugs for new models (#4268)
CjhHa1 Jul 18, 2023
e902544
[pipeline] support shardformer for GPT2ForQuestionAnswering & complet…
Fridge003 Jul 19, 2023
1d3a5c4
[shardformer] support inplace sharding (#4251)
ver217 Jul 20, 2023
bd6cf03
[pipeline] refactor gpt2 pipeline forwards (#4287)
Fridge003 Jul 20, 2023
6221a24
[pipeline] OPT model pipeline (#4258)
CjhHa1 Jul 20, 2023
29efdc3
[hotfix] fix opt pipeline (#4293)
CjhHa1 Jul 20, 2023
7d32333
[pipeline] reformat for unified design (#4283)
CjhHa1 Jul 21, 2023
63c2f30
[pipeline] add pipeline support for T5Stack/T5EncoderModel (#4300)
Fridge003 Jul 21, 2023
2fe393f
[pipeline] test pure pipeline process using llama (#4218)
CjhHa1 Jul 25, 2023
f172208
[pipeline] add pipeline support for all T5 models (#4310)
Fridge003 Jul 25, 2023
668ab10
[shardformer] support pipeline base vit model (#4284)
FoolPlayer Jul 25, 2023
9e4018d
[plugin] add 3d parallel plugin (#4295)
ver217 Jul 25, 2023
1cf9e01
[hotfix] fix gemini and zero test (#4333)
ver217 Jul 27, 2023
f24f3c2
[pipeline] fix return_dict/fix pure_pipeline_test (#4331)
Fridge003 Jul 27, 2023
064cea0
[pipeline] add unit test for 1f1b (#4303)
Gy-Lu Jul 31, 2023
daef153
[pipeline] refactor test pipeline and remove useless utils in pipelin…
CjhHa1 Aug 1, 2023
1920ce1
[pipeline] support fp32 for HybridPlugin/merge shardformer test and p…
Fridge003 Aug 1, 2023
a02ddd2
Feature/vit support (#4182)
klhhhhh Jul 7, 2023
3896e92
[shardformer] support SAM (#4231)
FoolPlayer Jul 14, 2023
98e382c
[shardformer] support whisper (#4212)
FoolPlayer Jul 17, 2023
7dea2e9
Feature/chatglm (#4240)
klhhhhh Jul 20, 2023
e48ce72
[shardformer] added tests
klhhhhh Jul 4, 2023
fd29900
[shardformer] vit test finish and support
klhhhhh Jul 6, 2023
65d663e
import chatglm
klhhhhh Jul 7, 2023
b0a15c3
[shardformer] add test kit in model zoo for chatglm
klhhhhh Jul 7, 2023
9576cd7
[sharformer] add first version of policy of chatglm
klhhhhh Jul 10, 2023
f1ec91a
[shardformer] polish chatglm code
klhhhhh Jul 12, 2023
a7c54bd
[shardformer] polish code
klhhhhh Jul 13, 2023
955b1d8
[shardformer] support chatglm without layernorm
klhhhhh Jul 14, 2023
e9ebe69
[shardformer] delete some file
klhhhhh Jul 17, 2023
965eb53
[shardformer] ChatGLM support layernorm sharding
klhhhhh Jul 17, 2023
fb7b53c
[shardformer] register without auto policy
klhhhhh Jul 18, 2023
2f94308
[shardformer] pre-commit check files
klhhhhh Jul 19, 2023
2bb51ea
[shardformer] support ChatGLMForConditionalGeneration & add fusedlaye…
klhhhhh Jul 20, 2023
1e7b804
[shardformer] support Blip2 (#4243)
FoolPlayer Jul 25, 2023
ee54290
update some module with new api version
FoolPlayer Aug 1, 2023
383ce5b
[test] skip some not compatible models
FoolPlayer Aug 2, 2023
4b1a574
[test] Hotfix/fix some model test and refactor check util api (#4369)
FoolPlayer Aug 3, 2023
5834885
[shardformer] add util functions for shardformer tests/fix sync_share…
Fridge003 Aug 3, 2023
0e3a8da
[pipeline] add chatglm (#4363)
CjhHa1 Aug 4, 2023
4711174
[Shardformer] Merge flash attention branch to pipeline branch (#4362)
flybird11111 Aug 7, 2023
eff7572
[pipeline] rewrite t5 tests & support multi-tensor transmitting in pi…
Fridge003 Aug 8, 2023
a3631fd
[shardformer] update shardformer to use flash attention 2 (#4392)
flybird11111 Aug 9, 2023
1fa9a26
[shardformer] test all optimizations (#4399)
flybird11111 Aug 10, 2023
4a7b0f1
[pipeline] rewrite bert tests and fix some bugs (#4409)
CjhHa1 Aug 11, 2023
58f9fce
[shardformer]fix, test gpt2 for AMP+TP (#4403)
flybird11111 Aug 11, 2023
ebe81c3
[shardformer] rewrite tests for opt/bloom/llama/vit/chatglm (#4395)
Fridge003 Aug 11, 2023
d79f3cf
[shardformer] update tests for all optimization (#4413)
flybird11111 Aug 11, 2023
afbee8e
[shardformer]update t5 tests for using all optimizations. (#4407)
flybird11111 Aug 14, 2023
6968eb8
[shardformer] update bloom/llama/vit/chatglm tests (#4420)
flybird11111 Aug 14, 2023
e7dc4a3
[misc] resolve code factor issues (#4433)
ver217 Aug 14, 2023
f7b82f5
[misc] update requirements
ver217 Aug 15, 2023
48970f2
[shardformer] fix embedding
ver217 Aug 15, 2023
025c543
[shardformer] fix import
ver217 Aug 15, 2023
149 changes: 149 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -0,0 +1,149 @@
from typing import Dict, List

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper

from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin


class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

    def __init__(self,
                 working_params: List[Parameter],
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
                         max_scale)
        self.params = working_params

    def check_local_overflow(self) -> bool:
        for p in self.params:
            if p.grad is not None and not torch.isfinite(p.grad).all():
                return True
        return False


class MixedPrecisionOptimizer(OptimizerWrapper):

    def __init__(self,
                 optim: Optimizer,
                 precision: str = 'fp16',
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0.0):
        super().__init__(optim)
        if precision == 'fp16':
            working_params = []
            for group in self.optim.param_groups:
                for p in group['params']:
                    working_params.append(p)
            self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
                                                                initial_scale=initial_scale,
                                                                min_scale=min_scale,
                                                                growth_factor=growth_factor,
                                                                backoff_factor=backoff_factor,
                                                                growth_interval=growth_interval,
                                                                hysteresis=hysteresis,
                                                                max_scale=max_scale)
        elif precision == 'bf16':
            self.mixed_precision = BF16MixedPrecisionMixin()
        else:
            raise ValueError(f'Unsupported precision: {precision}')
        if max_norm > 0.0:
            raise NotImplementedError('max_norm is not supported yet.')
        self.max_norm = max_norm
        self.working_to_master_map: Dict[Parameter, Tensor] = {}
        self.master_to_working_map: Dict[Tensor, Parameter] = {}

        # create master weights
        for group in self.optim.param_groups:
            master_params = []
            for p in group['params']:
                if p.requires_grad:
                    master_p = p
                    if p.dtype != torch.float:
                        master_p = p.detach().float()
                    self.working_to_master_map[p] = master_p
                    self.master_to_working_map[master_p] = p
                    master_params.append(master_p)
            group['params'] = master_params

    def backward(self, loss: Tensor, *args, **kwargs):
        loss = self.mixed_precision.pre_backward(loss)
        loss.backward(*args, **kwargs)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
        tensor.backward(grad)

    def zero_grad(self, *args, **kwargs):
        for p in self.working_to_master_map.keys():
            p.grad = None
        self.mixed_precision.pre_zero_grad()
        return super().zero_grad(*args, **kwargs)

    def _unscale_and_clip_grads(self, total_norm: float) -> None:
        div_scale = 1.0
        if self.mixed_precision is not None:
            div_scale = self.mixed_precision.get_grad_div_scale()

        if self.max_norm > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
            if clip > 1:
                div_scale = clip * div_scale

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad.data.mul_(1. / div_scale)

    def _compute_grad_norm(self) -> float:
        if self.max_norm <= 0.:
            return 0.
        grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
        if len(grads) == 0:
            return 0.
        device = grads[0].device
        # TODO(ver217): support tp
        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        return total_norm.item()

    def step(self, *args, **kwargs):
        if self.mixed_precision.should_skip_step():
            self.zero_grad()
            return
        # prepare grads
        for group in self.optim.param_groups:
            for p in group['params']:
                working_param = self.master_to_working_map[p]
                if p is working_param:
                    continue
                if working_param.grad is not None:
                    p.grad = working_param.grad.data.float()
                    working_param.grad = None
        total_norm = self._compute_grad_norm()
        self._unscale_and_clip_grads(total_norm)
        self.optim.step(*args, **kwargs)
        # update working params
        for group in self.optim.param_groups:
            for p in group['params']:
                working_param = self.master_to_working_map[p]
                if p is working_param:
                    continue
                working_param.data.copy_(p.data)
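
For reference, a minimal usage sketch of the wrapper defined above; the import path follows the new file's location, while the toy model, the loss, and the GPU assumption are illustrative only and not taken from the PR:

import torch
import torch.nn as nn
from torch.optim import SGD

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer

# Toy fp16 working model; assumes a CUDA device is available.
model = nn.Linear(16, 4).cuda().half()
optim = MixedPrecisionOptimizer(SGD(model.parameters(), lr=1e-2), precision='fp16')

x = torch.randn(8, 16, device='cuda', dtype=torch.float16)
loss = model(x).float().mean()

optim.zero_grad()
optim.backward(loss)    # scales the loss before calling loss.backward() on the fp16 graph
optim.step()            # copies fp16 grads into fp32 master params, unscales, steps, syncs working params

If a gradient overflows, should_skip_step() turns step() into a zero_grad and return; the loss-scale backoff itself lives in the FP16MixedPrecisionMixin, which is not shown in this diff.
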
14 changes: 8 additions & 6 deletions colossalai/booster/booster.py
@@ -1,6 +1,6 @@
 import warnings
 from contextlib import contextmanager
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -14,6 +14,7 @@
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
 from .plugin import Plugin
+from .plugin.pp_plugin_base import PipelinePluginBase
 
 __all__ = ['Booster']

@@ -138,20 +139,21 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
             loss (torch.Tensor): The loss to be backpropagated.
             optimizer (Optimizer): The optimizer to be updated.
         """
-        # TODO: implement this method with plugin
+        # TODO(frank lee): implement this method with plugin
         optimizer.backward(loss)
 
     def execute_pipeline(self,
                          data_iter: Iterator,
                          model: nn.Module,
-                         criterion: Callable[[torch.Tensor], torch.Tensor],
+                         criterion: Callable[[Any, Any], torch.Tensor],
                          optimizer: Optimizer,
                          return_loss: bool = True,
-                         return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
-        # TODO: implement this method
+                         return_outputs: bool = False) -> dict:
         # run pipeline forward backward pass
         # return loss or outputs if needed
-        pass
+        assert isinstance(self.plugin,
+                          PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
+        return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
 
     def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
         """Context manager to disable gradient synchronization across DP process groups.
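
Given the signature above, a pipelined training step might look like the sketch below; the criterion is called as criterion(outputs, inputs) per the new Callable[[Any, Any], torch.Tensor] hint, and reading the loss out of the returned dict via a 'loss' key is an assumption rather than something the diff spells out:

import torch

def train_step(booster, model, optimizer, criterion, dataloader) -> torch.Tensor:
    # Forward and backward are both driven by the pipeline schedule inside execute_pipeline;
    # only optimizer.step() remains to be called afterwards.
    model.train()
    data_iter = iter(dataloader)
    optimizer.zero_grad()
    outputs = booster.execute_pipeline(data_iter,
                                       model,
                                       criterion,
                                       optimizer,
                                       return_loss=True,
                                       return_outputs=False)
    optimizer.step()
    # Assumed key: only the last pipeline stage is expected to hold a non-None loss.
    return outputs.get('loss')
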
3 changes: 2 additions & 1 deletion colossalai/booster/plugin/__init__.py
@@ -1,9 +1,10 @@
 from .gemini_plugin import GeminiPlugin
+from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
+__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']
 
 import torch
 from packaging import version
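
The export above makes HybridParallelPlugin importable from colossalai.booster.plugin. As a closing orientation, here is a hedged end-to-end sketch of wiring it into the Booster API; the launch call, the plugin's constructor arguments (tp_size, pp_size, num_microbatches, precision), and the five-element return of boost() are assumptions drawn from the commit titles rather than verified against the merged code:

# Sketch only: assumes `torchrun --nproc_per_node=4 train.py` and a model with a
# shardformer policy (BERT pipelines are added by the commits in this PR).
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from torch.optim import Adam
from transformers import BertConfig, BertForSequenceClassification

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='fp16')
booster = Booster(plugin=plugin)

model = BertForSequenceClassification(BertConfig(num_hidden_layers=4))
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = lambda outputs, inputs: outputs.loss    # labels are expected inside each micro-batch

model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
# From here, booster.execute_pipeline(...) drives the pipeline schedule as sketched earlier.
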