[zerobubble] Support ZeroBubble Pipeline (#6034)
* [feat] add zerobubble pp (just a skeleton for now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test for run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p & p.grad assert-close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert error still exists;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requires_grad for output;

* [fix] fix requires_grad position, detach position, and input & output local buffer append position;

* [feat] add memory assertion;

* [fix] fix mem check;

* [fix] mem assertion;

* [fix] fix mem assertion;

* [fix] fix mem; use a new model shape; only assert mem less than or equal to theoretical;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertion at the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.requires_grad_ with treemap;

* [fix] update optim state dict assert (including param group & state); fix mem assert after adding optim;

* [fix] add testcase with microbatch 4;
duanjunwen authored Sep 10, 2024
1 parent 7cf9df0 commit 11ae684
Showing 7 changed files with 2,131 additions and 4 deletions.
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -1103,7 +1103,7 @@ def __init__(
         self.stage_manager = PipelineStageManager(
             self.pg_mesh,
             pipeline_axis=self.pp_axis,
-            enable_interleave=pp_style == "interleaved",
+            enable_interleave=(pp_style == "interleaved"),
             num_model_chunks=num_model_chunks,
             num_layers_per_stage=num_layers_per_stage,
         )
21 changes: 19 additions & 2 deletions colossalai/interface/optimizer.py

@@ -55,8 +55,25 @@ def backward(self, loss: Tensor, *args, **kwargs):
         """
         loss.backward(*args, **kwargs)
 
-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
-        torch.autograd.backward(tensor, grad)
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
+        """
+        Performs a backward pass for dx or dw:
+        for dx, we only calculate dx = w * dy here;
+        for dw, we only calculate dw = x * dy here.
+
+        Args:
+            tensor (Tensor): y or loss of the current chunk.
+            grad (Tensor): dy of the current chunk.
+            inputs (Tensor): for dx, the input x of the current chunk;
+                for dw, the weights w of the current chunk.
+            retain_graph (bool): defaults to False; the dx step (backward_b) passes True so the graph is kept alive for the later dw step (backward_w).
+        """
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
 
     def state_dict(self):
         """
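
This split backward is the mechanism the ZeroBubble schedule builds on: the activation-gradient pass (B, producing dx) and the weight-gradient pass (W, producing dw) become separately schedulable steps, so W can be deferred to fill pipeline bubbles. Below is a minimal, self-contained sketch of the idea using plain torch.autograd.backward on a hypothetical single linear stage; the module, tensor shapes, and variable names are illustrative stand-ins, not the scheduler's actual internals:

import torch
import torch.nn as nn

# Hypothetical one-stage setup; in the real scheduler the module is a model
# chunk and dy arrives from the next pipeline stage.
stage = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
y = stage(x)                       # forward (F)
dy = torch.randn_like(y)           # upstream gradient from the next stage

# B step: compute only dx, retaining the graph for the later W step.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
dx = x.grad                        # would be sent to the previous stage

# W step: compute only dw; the graph may be freed afterwards.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=list(stage.parameters()), retain_graph=False)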
3 changes: 2 additions & 1 deletion colossalai/pipeline/__init__.py

@@ -1,11 +1,12 @@
 from .p2p import PipelineP2PCommunication
-from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
+from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler
 from .stage_manager import PipelineStageManager
 
 __all__ = [
     "PipelineSchedule",
     "OneForwardOneBackwardSchedule",
     "InterleavedSchedule",
+    "ZeroBubbleVPipeScheduler",
     "PipelineP2PCommunication",
     "PipelineStageManager",
 ]
2 changes: 2 additions & 0 deletions colossalai/pipeline/schedule/__init__.py

@@ -1,9 +1,11 @@
 from .base import PipelineSchedule
 from .interleaved_pp import InterleavedSchedule
 from .one_f_one_b import OneForwardOneBackwardSchedule
+from .zero_bubble_pp import ZeroBubbleVPipeScheduler
 
 __all__ = [
     "PipelineSchedule",
     "OneForwardOneBackwardSchedule",
     "InterleavedSchedule",
+    "ZeroBubbleVPipeScheduler",
 ]
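
With these re-exports, the new scheduler is importable from either package level. Both import paths below are added by this commit; anything beyond the import (constructor arguments, plugin wiring) is not part of this diff:

# Either import path works after this commit; they resolve to the same class.
from colossalai.pipeline import ZeroBubbleVPipeScheduler
from colossalai.pipeline.schedule import ZeroBubbleVPipeScheduler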