Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[moe] merge branch feature/pipeline #4432

Closed
wants to merge 83 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
1015f04
[cluster] add process group mesh (#4039)
ver217 Jun 20, 2023
b10821a
[pipeline] add stage manager (#4093)
ver217 Jun 27, 2023
bd6b0a3
[pipeline] implement p2p communication (#4100)
ver217 Jun 28, 2023
faeac9d
[pipeline] refactor 1f1b schedule (#4115)
ver217 Jun 29, 2023
9852743
[pipeline]add pipeline policy and bert forward (#4130)
CjhHa1 Jul 4, 2023
3be0c35
[cluster] add process group mesh (#4039)
ver217 Jun 20, 2023
18c7539
[pipeline] add stage manager (#4093)
ver217 Jun 27, 2023
5a467e9
[pipeline] implement p2p communication (#4100)
ver217 Jun 28, 2023
9526f44
[pipeline] refactor 1f1b schedule (#4115)
ver217 Jun 29, 2023
836a3a2
[pipeline]add pipeline policy and bert forward (#4130)
CjhHa1 Jul 4, 2023
ef1f972
Merge pull request #4166 from ver217/sync/main
FrankLeeeee Jul 4, 2023
386d34e
[pipeline] build bloom model and policy , revise the base class of po…
CjhHa1 Jul 5, 2023
15a4e82
[pipeline] update shardformer policy
ver217 Jul 5, 2023
8b6679d
[pipeline] update shardformer docstring
ver217 Jul 5, 2023
9143556
[test] update shardformer tests
ver217 Jul 5, 2023
0cbe423
[test] add shard util tests
ver217 Jul 5, 2023
1a87dd7
[shardformer] rename policy file name
ver217 Jul 5, 2023
d4b96ab
[shardformer] fix type hint
ver217 Jul 5, 2023
12e6d5d
Merge pull request #4176 from ver217/feature/pipeline-policy
CjhHa1 Jul 5, 2023
15b34e0
[pipeline] add bert_for_pretraining bert_lmhead forward and policy (#…
CjhHa1 Jul 6, 2023
ec217de
Feature/vit support (#4182)
klhhhhh Jul 7, 2023
c6f9c2c
[pipeline] move bert related pipeline components to shardformer (#4187)
CjhHa1 Jul 7, 2023
0192011
[shardformer] support lazy init (#4202)
ver217 Jul 10, 2023
b30d1b9
[pipeline] Bert pipeline for shardformer and its tests (#4197)
CjhHa1 Jul 10, 2023
a2619c3
[pipeline] Llama pipeline (#4205)
CjhHa1 Jul 11, 2023
981764c
[pipeline] Llama causal lm and llama for sequence classification pipe…
CjhHa1 Jul 11, 2023
3595eba
[pipeline] add bloom model pipeline (#4210)
CjhHa1 Jul 13, 2023
236f294
[pipeline] Add Pipeline Forward for GPT2Model Shardformer (#4224)
Fridge003 Jul 13, 2023
ad2687c
[shardformer] fix base policy (#4229)
ver217 Jul 14, 2023
ddecf73
[shardformer] support SAM (#4231)
FoolPlayer Jul 14, 2023
afcf4a0
[shardformer] support whisper (#4212)
FoolPlayer Jul 17, 2023
383d2e3
[pipeline] add pipeline forward for variants of gpt2 (#4238)
Fridge003 Jul 17, 2023
7b8756f
[pipeline] All bert models (#4233)
CjhHa1 Jul 17, 2023
a895458
[pipeline] finish bloom models pipeline and tests (#4223)
CjhHa1 Jul 17, 2023
843158b
[bugs] hot fix some testing bugs for new models (#4268)
CjhHa1 Jul 18, 2023
3918898
[pipeline] support shardformer for GPT2ForQuestionAnswering & complet…
Fridge003 Jul 19, 2023
7b5a155
[shardformer] support inplace sharding (#4251)
ver217 Jul 20, 2023
cc120c6
[pipeline] refactor gpt2 pipeline forwards (#4287)
Fridge003 Jul 20, 2023
7b583c0
[pipeline] OPT model pipeline (#4258)
CjhHa1 Jul 20, 2023
3b92e4a
[hotfix] fix opt pipeline (#4293)
CjhHa1 Jul 20, 2023
77cc087
Feature/chatglm (#4240)
klhhhhh Jul 20, 2023
6c2acf0
[shardformer] added tests
klhhhhh Jul 4, 2023
7668b24
[shardformer] vit test finish and support
klhhhhh Jul 6, 2023
b135b75
import chatglm
klhhhhh Jul 7, 2023
e3cd5cb
[shardformer] add test kit in model zoo for chatglm
klhhhhh Jul 7, 2023
30574a7
[sharformer] add first version of policy of chatglm
klhhhhh Jul 10, 2023
28677d4
[shardformer] polish chatglm code
klhhhhh Jul 12, 2023
28319c2
[shardformer] polish code
klhhhhh Jul 13, 2023
3f19de9
[shardformer] support chatglm without layernorm
klhhhhh Jul 14, 2023
2a4bbcf
[shardformer] delete some file
klhhhhh Jul 17, 2023
32448e3
[shardformer] ChatGLM support layernorm sharding
klhhhhh Jul 17, 2023
eb1c71a
[shardformer] register without auto policy
klhhhhh Jul 18, 2023
127e385
[shardformer] pre-commit check files
klhhhhh Jul 19, 2023
9d5b141
[shardformer] support ChatGLMForConditionalGeneration & add fusedlaye…
klhhhhh Jul 20, 2023
d7e584c
[pipeline] reformat for unified design (#4283)
CjhHa1 Jul 21, 2023
9605805
[pipeline] add pipeline support for T5Stack/T5EncoderModel (#4300)
Fridge003 Jul 21, 2023
805f342
Merge pull request #4297 from klhhhhh/feature/support_ChatGLMForCondi…
klhhhhh Jul 21, 2023
f48a8bb
[shardformer] support Blip2 (#4243)
FoolPlayer Jul 25, 2023
965bf20
[pipeline] test pure pipeline process using llama (#4218)
CjhHa1 Jul 25, 2023
28e6980
[pipeline] add pipeline support for all T5 models (#4310)
Fridge003 Jul 25, 2023
2e93d9b
[shardformer] support pipeline base vit model (#4284)
FoolPlayer Jul 25, 2023
78dd508
[plugin] add 3d parallel plugin (#4295)
ver217 Jul 25, 2023
8ad05d1
[hotfix] fix gemini and zero test (#4333)
ver217 Jul 27, 2023
d547377
[pipeline] fix return_dict/fix pure_pipeline_test (#4331)
Fridge003 Jul 27, 2023
b941e65
[pipeline] add unit test for 1f1b (#4303)
Gy-Lu Jul 31, 2023
7d5b144
[pipeline] refactor test pipeline and remove useless utils in pipelin…
CjhHa1 Aug 1, 2023
01ef6c5
Merge remote-tracking branch 'upstream/feature/pipeline' into feature…
FoolPlayer Aug 1, 2023
992cbb7
[pipeline] support fp32 for HybridPlugin/merge shardformer test and p…
Fridge003 Aug 1, 2023
260df9e
update some module with new api version
FoolPlayer Aug 1, 2023
5403578
[test] skip some not compatible models
FoolPlayer Aug 2, 2023
b849657
Merge pull request #4358 from hpcaitech/feature/shardformer-models
FrankLeeeee Aug 2, 2023
3bfdd53
[test] Hotfix/fix some model test and refactor check util api (#4369)
FoolPlayer Aug 3, 2023
21c6bb0
[shardformer] add util functions for shardformer tests/fix sync_share…
Fridge003 Aug 3, 2023
c5f4844
[pipeline] add chatglm (#4363)
CjhHa1 Aug 4, 2023
7c84f51
[Shardformer] Merge flash attention branch to pipeline branch (#4362)
flybird11111 Aug 7, 2023
2e77e57
[pipeline] rewrite t5 tests & support multi-tensor transmitting in pi…
Fridge003 Aug 8, 2023
c14920a
[shardformer] update shardformer to use flash attention 2 (#4392)
flybird11111 Aug 9, 2023
ed2c229
[shardformer] test all optimizations (#4399)
flybird11111 Aug 10, 2023
9916a19
[pipeline] rewrite bert tests and fix some bugs (#4409)
CjhHa1 Aug 11, 2023
fcbf80f
[shardformer]fix, test gpt2 for AMP+TP (#4403)
flybird11111 Aug 11, 2023
1e518ae
[shardformer] rewrite tests for opt/bloom/llama/vit/chatglm (#4395)
Fridge003 Aug 11, 2023
d4a3a10
[shardformer] update tests for all optimization (#4413)
flybird11111 Aug 11, 2023
2e962f6
Merge remote-tracking branch 'remotes/ColossalAI/feature/pipeline' in…
oahzxl Aug 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions colossalai/amp/naive_amp/mixed_precision_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import Dict, List

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper

from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin


class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

def __init__(self,
working_params: List[Parameter],
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
self.params = working_params

def check_local_overflow(self) -> bool:
for p in self.params:
if p.grad is not None and not torch.isfinite(p.grad).all():
return True
return False


class MixedPrecisionOptimizer(OptimizerWrapper):

def __init__(self,
optim: Optimizer,
precision: str = 'fp16',
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0):
super().__init__(optim)
if precision == 'fp16':
working_params = []
for group in self.optim.param_groups:
for p in group['params']:
working_params.append(p)
self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
elif precision == 'bf16':
self.mixed_precision = BF16MixedPrecisionMixin()
else:
raise ValueError(f'Unsupported precision: {precision}')
if max_norm > 0.0:
raise NotImplementedError('max_norm is not supported yet.')
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}

# create master weights
for group in self.optim.param_groups:
master_params = []
for p in group['params']:
if p.requires_grad:
master_p = p
if p.dtype != torch.float:
master_p = p.detach().float()
self.working_to_master_map[p] = master_p
self.master_to_working_map[master_p] = p
master_params.append(master_p)
group['params'] = master_params

def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
loss.backward(*args, **kwargs)

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)

def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys():
p.grad = None
self.mixed_precision.pre_zero_grad()
return super().zero_grad(*args, **kwargs)

def _unscale_and_clip_grads(self, total_norm: float) -> None:
div_scale = 1.0
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()

if self.max_norm > 0.:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
p.grad.data.mul_(1. / div_scale)

def _compute_grad_norm(self) -> float:
if self.max_norm <= 0.:
return 0.
grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
if len(grads) == 0:
return 0.
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
return total_norm.item()

def step(self, *args, **kwargs):
if self.mixed_precision.should_skip_step():
self.zero_grad()
return
# prepare grads
for group in self.optim.param_groups:
for p in group['params']:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
if working_param.grad is not None:
p.grad = working_param.grad.data.float()
working_param.grad = None
total_norm = self._compute_grad_norm()
self._unscale_and_clip_grads(total_norm)
self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
for p in group['params']:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
working_param.data.copy_(p.data)
14 changes: 8 additions & 6 deletions colossalai/booster/booster.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from contextlib import contextmanager
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Iterator, List, Optional, Union

import torch
import torch.nn as nn
Expand All @@ -14,6 +14,7 @@
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
from .plugin.pp_plugin_base import PipelinePluginBase

__all__ = ['Booster']

Expand Down Expand Up @@ -144,14 +145,15 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
def execute_pipeline(self,
data_iter: Iterator,
model: nn.Module,
criterion: Callable[[torch.Tensor], torch.Tensor],
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optimizer,
return_loss: bool = True,
return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
# TODO: implement this method
return_outputs: bool = False) -> dict:
# run pipeline forward backward pass
# return loss or outputs if needed
pass
assert isinstance(self.plugin,
PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)

def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
"""Context manager to disable gradient synchronization across DP process groups.
Expand All @@ -166,7 +168,7 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
"""
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model, optimizer)
return self.plugin.no_sync(model)

def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
"""Load model from checkpoint.
Expand Down
3 changes: 2 additions & 1 deletion colossalai/booster/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin

__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']

import torch
from packaging import version
Expand Down
Loading
Loading