Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shardformer] support pp+tp+zero1 tests #4531

Merged
merged 17 commits into from
Aug 30, 2023
15 changes: 13 additions & 2 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,12 +333,23 @@ def backward(self, loss, retain_graph=False):
self.zero_grad()

def backward_by_grad(self, tensor, grad):
# in lower stage which grad is transfered by higher stage
# we need to pass the optim state down.
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad)

if not self.require_grad_sync:
return
self._reduce_grad(self._partition_grads)

# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()

self.zero_grad()

def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
Expand Down
9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ def run_bert_test(test_config):
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ def run_bloom_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_bloom_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ def run_chatglm_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_chatglm_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,16 @@ def run_gpt2_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
@clear_cache_before_run()
def run_gpt2_3d_test(test_config):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,16 @@ def run_llama_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_llama_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ def run_opt_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_opt_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ def run_t5_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_t5_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
Expand Down
9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ def run_vit_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
])
def run_vit_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
Expand Down
11 changes: 10 additions & 1 deletion tests/test_shardformer/test_model/test_shard_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 5e-4, 5e-4
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
Expand Down Expand Up @@ -195,6 +195,15 @@ def run_whisper_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
])
def run_whisper_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
Expand Down