[fix] fix fp8 overlap code
duanjunwen committed Nov 19, 2024
1 parent cb9e5cc commit 8a0bad9
Showing 3 changed files with 1 addition and 41 deletions.
colossalai/booster/plugin/hybrid_parallel_plugin.py (10 changes: 0 additions & 10 deletions)
@@ -116,15 +116,10 @@ def __init__(
 
         super().__init__(module)
         self.op_hooks = []
         if use_fp8:
             self.op_hooks.append(FP8Hook())
-        self.op_hooks = []
-        if use_fp8:
-            self.op_hooks.append(FP8Hook())
         if overlap_allgather:
             self.op_hooks.append(ZeroOpHook())
         if use_fp8 or overlap_allgather:
-            self.op_hooks.append(ZeroOpHook())
-        if use_fp8 or overlap_allgather:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:

@@ -237,9 +232,6 @@ def _force_wait_all_gather(self):
     def _hook_context(self):
         return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
 
-    def _hook_context(self):
-        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
-
 
 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
@@ -995,8 +987,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
-        fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
-        use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
         overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
         inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
             It's advisable not to tune this (especially in single-node settings) and let it be set heuristically based on topology by default.
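
For the options documented above, a minimal usage sketch (not part of this commit; the values are illustrative, other HybridParallelPlugin arguments keep their defaults, and the launch call assumes a torchrun-managed environment):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes the script was started with torchrun

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision="bf16",
    use_fp8=True,  # fp8 mixed precision training, as described in the docstring above
    fp8_communication=True,  # fp8 communication, as described in the docstring above
    make_vocab_size_divisible_by=64,
)
booster = Booster(plugin=plugin)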
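Taken together, the deletions in the first two hunks of this file leave each hook registered exactly once and a single _hook_context method. A standalone sketch of the resulting logic (build_op_hooks and hook_context are hypothetical helper names, and the import paths are assumptions that may differ from the repository):

from contextlib import nullcontext

from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level.zero_hook import ZeroOpHook


def build_op_hooks(use_fp8: bool, overlap_allgather: bool) -> list:
    # Each hook is appended once; the commit removes the duplicated FP8Hook and
    # ZeroOpHook registrations that had crept into __init__.
    op_hooks = []
    if use_fp8:
        op_hooks.append(FP8Hook())
    if overlap_allgather:
        op_hooks.append(ZeroOpHook())
    return op_hooks


def hook_context(op_hooks: list):
    # Same expression as the single _hook_context kept above: activate the hooks if
    # any are registered, otherwise fall back to a no-op context manager.
    return ColoParamOpHookManager.use_hooks(*op_hooks) if len(op_hooks) > 0 else nullcontext()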
examples/language/llama/benchmark.py (5 changes: 1 addition & 4 deletions)
@@ -129,12 +129,9 @@ def empty_init():
         {
             "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
-                # num_ckpt_layers_per_stage=[48, 48, 48, 48],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            # "num_layers_per_stage": [48, 48, 48, 48],
-            # "pp_style": "interleaved",
-            "pp_style": "1f1b",
+            "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}
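
With this edit, the benchmark's custom-checkpoint branch resolves to roughly the following (a sketch based only on the hunk above; the variable name and the import path for PipelineGradientCheckpointConfig are assumptions):

from colossalai.shardformer import PipelineGradientCheckpointConfig  # import path assumed

custom_ckpt_kwargs = {  # hypothetical name for the dict built in the hunk above
    "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
        num_ckpt_layers_per_stage=[19, 19, 19, 13],  # layers checkpointed on each pipeline stage
    ),
    "num_layers_per_stage": [19, 20, 20, 21],
    "pp_style": "interleaved",  # switched from "1f1b" by this commit
}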
tests/test_shardformer/test_model/test_shard_llama.py (27 changes: 0 additions & 27 deletions)
@@ -277,33 +277,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        # # TODO: assert layer error
-        # {
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "pp_style": "zbv",
-        #     "num_model_chunks": 2,
-        #     "num_microbatches": 4,
-        #     "enable_all_optimization": False,
-        #     "precision": "fp16",
-        #     "zero_stage": 0,
-        #     "initial_scale": 1,
-        #     "enable_gradient_checkpointing": True,
-        #     "parallel_output": False,
-        # },
-        # {
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "pp_style": "zbv",
-        #     "num_model_chunks": 2,
-        #     "num_microbatches": 4,
-        #     "enable_all_optimization": False,
-        #     "precision": "fp16",
-        #     "zero_stage": 1,
-        #     "initial_scale": 1,
-        #     "enable_gradient_checkpointing": True,
-        #     "parallel_output": False,
-        # },
     ],
 )
 def run_llama_test(test_config):
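
For orientation, the list being trimmed above is the value list of a parameterization decorator wrapping run_llama_test. A minimal sketch of that pattern (the decorator import and the surviving config values follow common ColossalAI testing conventions and are not taken verbatim from this file):

from colossalai.testing import parameterize  # assumed import, as used across ColossalAI tests


@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
    ],
)
def run_llama_test(test_config):
    ...  # each remaining config dict is fed to the shardformer forward/backward checks in turn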
