[Zerobubble] merge main. #6142

Merged · 132 commits · Nov 19, 2024

Commits (132)
ee9baed
[feat] add zerobubble pp (just a frame now); add POC test for dx_dw; …
duanjunwen Aug 22, 2024
c18ef06
[feat] add dw test;
duanjunwen Aug 23, 2024
203033e
[fix] fix weight not close;
duanjunwen Aug 23, 2024
107230d
[update] update text;
duanjunwen Aug 26, 2024
fd5526b
Merge branch 'main' into dev/zero_bubble
duanjunwen Aug 26, 2024
1d75045
[feat] add test run_fwd_bwd automatic scheduling;
duanjunwen Aug 26, 2024
5e09c8b
[feat] split communication and calculation; fix pop empty send_bwd_bu…
duanjunwen Aug 27, 2024
f1c1a87
[feat] add test for p & p grad;
duanjunwen Aug 27, 2024
1b4bb2b
[feat] add comments for ZBV func;
duanjunwen Aug 27, 2024
283c9ff
[fix] rm useless assign and comments;
duanjunwen Aug 27, 2024
9e0bd1a
[fix] fix ci test; add pytest;
duanjunwen Aug 27, 2024
8b37323
[feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p…
duanjunwen Aug 27, 2024
fe20916
[feat] add apply v_schedule graph; p & p.grad assert err exist;
duanjunwen Aug 27, 2024
29383b2
[fix] update
duanjunwen Aug 28, 2024
d6e3d7d
[feat] fix ci; add assert;
duanjunwen Aug 28, 2024
b5f7b4d
[feat] fix poc format
duanjunwen Aug 28, 2024
582ba0d
[feat] fix func name & ci; add comments;
duanjunwen Aug 28, 2024
b1419ef
[fix] fix poc test; add comments in poc;
duanjunwen Aug 28, 2024
4c4b01b
[feat] add optim backward_b_by_grad
duanjunwen Aug 29, 2024
48ba22d
[feat] fix optimizer bwd b & w; support return accum loss & output
duanjunwen Aug 29, 2024
6af81d8
[feat] add fwd_bwd_step, run_fwd_only;
duanjunwen Aug 30, 2024
8eb6eac
[fix] fix optim bwd; add license for v_schedule; remove redundant att…
duanjunwen Aug 30, 2024
a7b767b
[fix] fix communication_map;
duanjunwen Aug 30, 2024
6d18d38
[feat] update test; rm comments;
duanjunwen Sep 2, 2024
77fe442
[fix] rm zbv in hybridplugin
duanjunwen Sep 2, 2024
591a13b
[fix] fix optim bwd;
duanjunwen Sep 2, 2024
a48afc4
[fix] fix optim bwd;
duanjunwen Sep 3, 2024
ab643c9
[fix] rm output.data after send fwd;
duanjunwen Sep 3, 2024
4c1f81c
[fix] fix bwd step if condition; remove useless comments and format i…
duanjunwen Sep 3, 2024
b4103f1
[fix] fix detach output & release output;
duanjunwen Sep 3, 2024
20503cd
[fix] rm requir_grad for output;
duanjunwen Sep 3, 2024
e6e1a97
[fix] fix requir grad position and detach position and input&output l…
duanjunwen Sep 4, 2024
2f09c37
[feat] add memory assertation;
duanjunwen Sep 4, 2024
4a35834
[fix] fix mem check;
duanjunwen Sep 4, 2024
400e5e5
[fix] mem assertation'
duanjunwen Sep 9, 2024
35a7b63
[fix] fix mem assertation
duanjunwen Sep 9, 2024
a5ec3d4
[fix] fix mem; use a new model shape; only assert mem less and equal …
duanjunwen Sep 9, 2024
fed8b15
[fix] fix model zoo import;
duanjunwen Sep 9, 2024
7568b34
[fix] fix redundant detach & clone; add buffer assertation in the end;
duanjunwen Sep 9, 2024
ce58d8e
[fix] add output_obj_grad assert None at bwd b step; replace input_ob…
duanjunwen Sep 9, 2024
8366a78
[fix] update optim state dict assert (include param group & state); f…
duanjunwen Sep 9, 2024
6c2a120
[fix] add testcase with microbatch 4;
duanjunwen Sep 9, 2024
11ae684
[zerobubble]Support ZeroBubble Pipeline (#6034)
duanjunwen Sep 10, 2024
9bc3b6e
[feat] moehybrid support zerobubble;
duanjunwen Sep 12, 2024
3dbad10
[fix] fix zerobubble pp for shardformer type input;
duanjunwen Sep 18, 2024
af2c2f8
[feat] add more test;
duanjunwen Sep 18, 2024
1f5c725
Merge remote-tracking branch 'upstream/feature/zerobubble' into dev/z…
duanjunwen Sep 19, 2024
6ee9584
[fix] fix require_grad & deallocate call;
duanjunwen Sep 19, 2024
349272c
[fix] updatw bwd b&w input; dict --> list[torch.Tensor]
duanjunwen Sep 19, 2024
a115106
[fix] fix bwd w input;
duanjunwen Sep 19, 2024
4753bf7
[fix] fix mem assert;
duanjunwen Sep 19, 2024
2678377
[fix] fix input_tensors buffer append input_obj(dict) --> Tuple (mic…
duanjunwen Sep 20, 2024
c6d6ee3
[fix] use tree_flatten replace dict traverse;
duanjunwen Sep 20, 2024
b6616f5
[fix] rm comments;
duanjunwen Sep 20, 2024
1739df4
[fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'
duanjunwen Sep 20, 2024
da3220f
[fix] fix pipeline util func deallocate --> release_tensor_data; fix …
duanjunwen Sep 20, 2024
c114d14
[fix] fix detach clone release order;
duanjunwen Sep 23, 2024
a875212
[fix] fix ci --> oom in 4096 hidden dim;
duanjunwen Sep 23, 2024
6c1e155
[fix] fix dumb clone;
duanjunwen Sep 23, 2024
7e6f793
[fix] fix detach_output_obj clone;
duanjunwen Sep 24, 2024
8501202
Merge pull request #6065 from duanjunwen/dev/zero_bubble
duanjunwen Sep 24, 2024
fc8b016
[fix] fix stage_indices;
duanjunwen Sep 25, 2024
83163fa
[fix] fix traverse; traverse dict --> traverse tensor List;
duanjunwen Sep 25, 2024
a92e167
[fix] fix zerobubble; support shardformer model type;
duanjunwen Sep 26, 2024
45f17fc
[fix] rm comments;
duanjunwen Sep 26, 2024
c5503b0
[fix] fix test_pipeline_utils ci;
duanjunwen Sep 26, 2024
bb0390c
[fix] remove duplicate arg; rm comments;
duanjunwen Sep 26, 2024
64ceea7
[fix] remove chunk 0 stage 0 bwd b; u don't have to cal micrbatch's dx;
duanjunwen Sep 26, 2024
1342a98
[fix] rm print & comments;
duanjunwen Sep 26, 2024
b804fdc
Merge pull request #6069 from duanjunwen/dev/zero_bubble
duanjunwen Sep 27, 2024
af6aa9e
[plugin] hybrid support zero bubble pipeline (#6060)
flybird11111 Sep 27, 2024
d634795
[feat] zerobubble support moehybridplugin;
duanjunwen Sep 29, 2024
5c8bbf6
[feat] update optimizer bwd;
duanjunwen Sep 29, 2024
6975c50
[fix] fix build ci;
duanjunwen Sep 30, 2024
295dd2d
[zerobubble] rebase main (#6075)
flybird11111 Oct 8, 2024
f4d023c
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI …
duanjunwen Oct 8, 2024
292a504
[fix] fix mixtral policy;
duanjunwen Oct 8, 2024
cc500b3
[fix] fix mixtral policy;
duanjunwen Oct 8, 2024
531773f
Merge pull request #6077 from duanjunwen/dev/zero_bubble
duanjunwen Oct 9, 2024
3f5bec8
[feat] support zbv in mixtral benchmark;
duanjunwen Oct 9, 2024
9ee80fc
[fix] MixtralForCausalLMPolicy get_held_layer support zbv;
duanjunwen Oct 10, 2024
72b507a
[feat] update MixtralPipelineForwards --> mixtral_model_forward; supp…
duanjunwen Oct 10, 2024
e234dfa
[feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forwa…
duanjunwen Oct 10, 2024
dac0e07
[zero bubble] support zero (#6080)
flybird11111 Oct 11, 2024
0ca16d5
[fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral …
duanjunwen Oct 11, 2024
cfade4c
[feat] Linear1D_COL/ROW support zbv WeightGradStore;
duanjunwen Oct 14, 2024
a11b4b5
[feat] support use_zbv in llama, mixtral modeling; only replace Linea…
duanjunwen Oct 14, 2024
abd4551
[fix] fix test case; moe error in second iter
duanjunwen Oct 14, 2024
160e9a4
[feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;
duanjunwen Oct 14, 2024
9912cc8
[fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Ro…
duanjunwen Oct 15, 2024
52dcc73
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI …
duanjunwen Oct 15, 2024
90939b7
[fix] debug zbv llama test;
duanjunwen Oct 15, 2024
e76308c
[fix] rm use_zbv flag in Shardconfig; rm debug info;
duanjunwen Oct 16, 2024
705b18e
[fix] add & fix llama test
duanjunwen Oct 16, 2024
2eca112
[feat] support meta cache, meta_grad_send, meta_tensor_send; fix runt…
duanjunwen Oct 24, 2024
d0ec221
[fix\ fix fail case test_shard_llama
duanjunwen Oct 25, 2024
cc0dfdd
[fix] fix test_shard_llama
duanjunwen Oct 25, 2024
03fa79a
[fix] fix llama modeling policy;
duanjunwen Oct 25, 2024
6377aa0
[fix] fix test_shard_llama ci;
duanjunwen Oct 28, 2024
5aee426
[fix] fix test zerobubble
duanjunwen Oct 28, 2024
fafe049
[fix] fix handle name; rm useless comments;
duanjunwen Oct 29, 2024
fa3ccda
[fix] fix send recv signature;
duanjunwen Oct 29, 2024
982e4ee
[fix] fix comment in llama & benchmark
duanjunwen Oct 29, 2024
d2e05a9
[feat] support no tensor parallel Linear in shardformer; Add test for…
duanjunwen Oct 30, 2024
5f09243
[fix] fix linear (no tp) ops func name;
duanjunwen Oct 31, 2024
aed20fb
[feat] support zbv in mixtral benchmark; (#6083)
duanjunwen Oct 31, 2024
1d328ff
Merge branch 'main' into dev/zero_bubble
duanjunwen Nov 1, 2024
c82c75a
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI …
duanjunwen Nov 1, 2024
3b5c314
[fix] fix fp8 args in HybridParallel
duanjunwen Nov 1, 2024
5b5fbcf
[fix] fix hybridparall use_fp8 config
duanjunwen Nov 1, 2024
0218e67
[fix] fix use_fp8 flag
duanjunwen Nov 1, 2024
8e40087
[fix] fix model zoo init
duanjunwen Nov 1, 2024
4fc92aa
[feat] support no_tp Linear for sharderformer.llama
duanjunwen Nov 5, 2024
0d6d40c
[fix] fix zbv llama pp4
duanjunwen Nov 6, 2024
12919de
[fix] fix send_tensor_metadata & send_grad_metadata;
duanjunwen Nov 11, 2024
337debc
[feat] fix testcase;
duanjunwen Nov 11, 2024
80b04d7
[feat] support mixtral policy with zbv tp_Linear & non_tp_Linear
duanjunwen Nov 12, 2024
b6d5e61
[feat] update mixtral policy & bert policy for zerobubble
duanjunwen Nov 14, 2024
1bc4dba
[fix] fix p2p error in zbv
duanjunwen Nov 14, 2024
014afbd
[fix] fix attn
duanjunwen Nov 14, 2024
5c2ebbf
[fix] fix mixtral modeling & policy; update wait handles; doing bench…
duanjunwen Nov 15, 2024
cf86c1b
[fix] fix zbv wait_handle
duanjunwen Nov 15, 2024
0fb500c
[fix] rm debug info; update llama policy; update wait handle
duanjunwen Nov 15, 2024
2980da5
[fix] fix test_lora
duanjunwen Nov 15, 2024
f48a85e
[fix] fix test_lora in llama policy
duanjunwen Nov 15, 2024
9a21f87
[fix] fix wait handle in run_fwd_bwd
duanjunwen Nov 18, 2024
dafda0f
[fix] remove debug info;
duanjunwen Nov 18, 2024
41fdd21
[fix] rm unused comments
duanjunwen Nov 18, 2024
cb9e5cc
Merge branch 'main' into dev/zero_bubble
duanjunwen Nov 18, 2024
8a0bad9
[fix] fix fp8 overlap code
duanjunwen Nov 19, 2024
9aa4c67
[fix] fix yml file & v_schedule comments
duanjunwen Nov 19, 2024
e4488b1
[fix] rm fwd only meta cache comments;
duanjunwen Nov 19, 2024
Files changed
2 changes: 1 addition & 1 deletion colossalai/amp/naive_amp/mixed_precision_mixin/base.py
@@ -43,7 +43,7 @@ def zero_grad(self):
dtype: torch.dtype

@abstractmethod
def pre_backward(self, loss: Tensor) -> Tensor:
def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
"""Called before backward.

Args:
13 changes: 9 additions & 4 deletions colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -86,13 +86,18 @@ def __init__(
group["params"] = master_params
self._current_grad_norm: Optional[float] = None

def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
loss.backward(*args, **kwargs)
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)
torch.autograd.backward(
tensors=tensor,
grad_tensors=grad,
inputs=inputs,
retain_graph=retain_graph,
)

def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys():
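The switch above from `tensor.backward(grad)` to `torch.autograd.backward(..., inputs=..., retain_graph=...)` is what lets the zero-bubble schedule split one chunk's backward into a B step (activation gradient dx) and a W step (weight gradient dw). A minimal stand-alone sketch of that split; the toy shapes are assumptions for illustration, not values from this PR:

```python
import torch

x = torch.randn(4, 8, requires_grad=True)  # activation arriving from the previous stage
w = torch.randn(8, 8, requires_grad=True)  # a weight of the current stage
y = x @ w                                  # forward of the current chunk
dy = torch.randn_like(y)                   # gradient received from the next stage

# B step: accumulate only dx into x.grad; keep the graph so the W step can reuse it.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)

# W step: accumulate only dw into w.grad; the graph may be freed afterwards.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[w], retain_graph=False)

print(x.grad.shape, w.grad.shape)  # torch.Size([4, 8]) torch.Size([8, 8])
```

Passing `inputs=` restricts gradient accumulation to the listed tensors, so the B step leaves the weights' `.grad` untouched and the W step can be deferred until the schedule has bubble time to fill.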
4 changes: 2 additions & 2 deletions colossalai/booster/mixed_precision/fp16_torch.py
@@ -46,9 +46,9 @@ def __init__(
growth_interval=growth_interval,
)

def backward(self, loss: Tensor, *args, **kwargs) -> None:
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
scaled_loss.backward(*args, **kwargs)
scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

def step(self, *args, **kwargs) -> Optional[float]:
out = self.scaler.step(self.optim, *args, **kwargs)
64 changes: 43 additions & 21 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -28,7 +28,7 @@
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
@@ -296,7 +296,7 @@ def __init__(
self._current_grad_norm: Optional[float] = None
super().__init__(optim)

def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -315,7 +315,7 @@ def backward(self, loss: Tensor, *args, **kwargs):

# Call the superclass backward method to compute gradients.
with self.model._hook_context():
super().backward(loss, *args, **kwargs)
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -324,7 +324,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
# If gradient synchronization is is not required, return.
return

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -341,7 +341,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
"""

# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -525,7 +525,7 @@ def __init__(
max_norm=max_norm,
)

def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -543,7 +543,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
"""
# Call the superclass backward method to compute gradients.
with self.model._hook_context():
super().backward(loss, *args, **kwargs)
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -552,7 +552,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
# If gradient synchronization is is not required, return.
return

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -568,7 +568,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
None
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -785,7 +785,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
else:
return

def backward(self, loss, retain_graph=False):
def backward(self, loss, inputs=None, retain_graph=False):
"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -801,7 +801,7 @@ def backward(self, loss, retain_graph=False):
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)
super().backward(loss, inputs=inputs, retain_graph=retain_graph)

if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -810,7 +810,7 @@ def backward_by_grad(self, tensor, grad):
# If gradient synchronization is is not required, return.
return

def backward_by_grad(self, tensor, grad):
def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -826,7 +826,7 @@ def backward_by_grad(self, tensor, grad):
None
"""
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
@@ -1030,6 +1030,7 @@ def __init__(
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
scheduler_nodes: List = None,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
@@ -1048,6 +1049,9 @@ def __init__(
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

assert (
not pp_style == "zbv" or scheduler_nodes is not None
), f"scheduler_nodes must not be None when using zero bubble pipeline."
if enable_sequence_parallelism:
self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@@ -1109,29 +1113,39 @@ def __init__(
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

self.stage_manager = None
self.schedule = None
self.scheduler = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
assert (
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
), "num_model_chunks must be 1 when using 1f1b"
assert (
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
), "num_model_chunks must be 2 when using zero bubble pipeline"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert (
self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
if pp_style == "zbv":
self.logger.warning(
"""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
)
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
use_zbv=(pp_style == "zbv"),
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)

if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
@@ -1141,13 +1155,21 @@ def __init__(
fp8_communication=fp8_communication,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
fp8_communication=fp8_communication,
)
elif pp_style == "zbv":
self.scheduler = ZeroBubbleVPipeScheduler(
stage_manager=self.stage_manager,
schedule=scheduler_nodes,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
)
else:
raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn":
@@ -1263,7 +1285,6 @@ def configure(

# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
@@ -1278,6 +1299,7 @@ def configure(
self.dp_size == 1 and self.pp_size == 1
)
# sync gradients across DP * SP ranks
# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@@ -1380,7 +1402,7 @@ def execute_pipeline(
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

with ctx, model._hook_context():
outputs = self.schedule.forward_backward_step(
outputs = self.scheduler.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)

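Pulling the new arguments together, a hedged sketch of enabling the zero-bubble V schedule through `HybridParallelPlugin`; the parallel sizes are illustrative, and `scheduler_nodes` (the per-rank V schedule, e.g. built with the v_schedule utilities added by this PR) is treated as an already-computed input rather than constructed here:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def build_zbv_booster(scheduler_nodes):
    # scheduler_nodes: precomputed zero-bubble V schedule (assumed input).
    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=4,
        pp_style="zbv",                   # new pipeline style added by this PR
        num_model_chunks=2,               # zbv asserts exactly two model chunks
        num_microbatches=8,
        scheduler_nodes=scheduler_nodes,  # must not be None when pp_style == "zbv"
        zero_stage=1,                     # pipeline parallelism allows zero stage 0 or 1
    )
    return Booster(plugin=plugin)
```

As the assertions above require, `pp_style="zbv"` only accepts `num_model_chunks=2` and a non-None `scheduler_nodes`, and gradient checkpointing must be enabled with `use_reentrant=False`.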
29 changes: 23 additions & 6 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -29,6 +29,7 @@
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -212,6 +213,7 @@ def __init__(
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
scheduler_nodes: List = None,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
@@ -285,12 +287,17 @@ def __init__(
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)

self.stage_manager = None
self.schedule = None
self.scheduler = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
assert (
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
), "num_model_chunks must be 1 when using 1f1b"
assert (
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
), "num_model_chunks must be 2 when using zero bubble pipeline"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -300,14 +307,15 @@ def __init__(
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
use_zbv=(pp_style == "zbv"),
)

if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
@@ -316,12 +324,21 @@ def __init__(
overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
elif pp_style == "zbv":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZeroBubbleV"
self.scheduler = ZeroBubbleVPipeScheduler(
schedule=scheduler_nodes,
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
overlap_p2p=overlap_p2p,
)
else:
raise NotImplementedError()

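`MoeHybridParallelPlugin` gains the same switches; a short sketch under the same assumptions (illustrative sizes, `scheduler_nodes` precomputed elsewhere):

```python
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin


def build_moe_zbv_plugin(scheduler_nodes):
    # scheduler_nodes: precomputed zero-bubble V schedule (assumed input, as above).
    return MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=2,
        ep_size=2,                        # expert-parallel size, specific to the MoE plugin
        pp_style="zbv",
        num_model_chunks=2,               # the MoE plugin also asserts two chunks for zbv
        num_microbatches=4,
        scheduler_nodes=scheduler_nodes,
    )
```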
25 changes: 21 additions & 4 deletions colossalai/interface/optimizer.py
@@ -49,14 +49,31 @@ def zero_grad(self, *args, **kwargs):
"""
self.optim.zero_grad(*args, **kwargs)

def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
"""
Performs a backward pass on the loss.
"""
loss.backward(*args, **kwargs)
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
torch.autograd.backward(tensor, grad)
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Performs a backward pass for dx or dw,
for dx, we only calculate dx = w*dy here
for dw, we only calculate dw = x*dy here

Args:
tensor (Tensor): y or loss of current chunk;
grad_tensors (Tensor): dy of current chunk;
input_obj (Tensor): for dx, input_obj is x of current chunk;
for dw, input_obj is w of current chunk;
retain_graph (bool): default to be True, we retain graph in backward_b
"""
torch.autograd.backward(
tensors=tensor,
grad_tensors=grad,
inputs=inputs,
retain_graph=retain_graph,
)

def state_dict(self):
"""
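The same two-phase backward, driven through the wrapper API documented above instead of raw autograd; the tiny linear model and the SGD optimizer are assumptions for illustration, not code from this PR:

```python
import torch
from torch import nn

from colossalai.interface.optimizer import OptimizerWrapper

model = nn.Linear(8, 8, bias=False)
optim = OptimizerWrapper(torch.optim.SGD(model.parameters(), lr=1e-2))

x = torch.randn(4, 8, requires_grad=True)  # activation arriving from the previous stage
y = model(x)                               # forward of this chunk
dy = torch.randn_like(y)                   # gradient received from the next stage

# B step: accumulate only dx into x.grad; retain the graph for the W step.
optim.backward_by_grad(tensor=y, grad=dy, inputs=[x], retain_graph=True)

# W step: accumulate only dw into the parameters; the graph can now be freed.
optim.backward_by_grad(tensor=y, grad=dy, inputs=list(model.parameters()), retain_graph=False)

optim.step()
```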
3 changes: 2 additions & 1 deletion colossalai/pipeline/__init__.py
@@ -1,11 +1,12 @@
from .p2p import PipelineP2PCommunication
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler
from .stage_manager import PipelineStageManager

__all__ = [
"PipelineSchedule",
"OneForwardOneBackwardSchedule",
"InterleavedSchedule",
"ZeroBubbleVPipeScheduler",
"PipelineP2PCommunication",
"PipelineStageManager",
]
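With the new export, the scheduler is importable from the package root alongside the existing pipeline utilities; a minimal sketch of the import only, since constructing it needs a stage manager and a V schedule as shown in the plugin diffs above:

```python
from colossalai.pipeline import (
    PipelineP2PCommunication,
    PipelineStageManager,
    ZeroBubbleVPipeScheduler,
)
```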