Commit b804fdc
Merge pull request #6069 from duanjunwen/dev/zero_bubble
[HotFix] Fix stage_index in zerobubble test;
duanjunwen authored Sep 27, 2024
2 parents 8501202 + 1342a98 commit b804fdc
Showing 5 changed files with 115 additions and 133 deletions.
32 changes: 18 additions & 14 deletions colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -430,8 +430,8 @@ def forward_step(
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             # fwd calculate
             internal_inputs = {} if input_obj is None else input_obj
-            # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
-            output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
+            internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+            output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
 
             # last layer in model
             if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -449,7 +449,7 @@ def backward_b_step(
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
-        micro_batch: Optional[dict],
+        # micro_batch: Optional[dict],
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -478,11 +478,9 @@ def backward_b_step(
         output_obj_ = []
         output_obj_grad_ = []
 
-        # For chunk 0 stage 0, use micro_batch as input_obj_
+        # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten(micro_batch)
-            output_obj_, _ = tree_flatten(output_obj) # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
+            return None
 
         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -497,6 +495,11 @@ def backward_b_step(
             output_obj_, _ = tree_flatten(output_obj) # y
             output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
 
+        # filter item which is not torch.Tensor
+        input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
+
         optimizer.backward_by_grad(
             tensor=output_obj_,
             grad=output_obj_grad_,
@@ -507,9 +510,7 @@ def backward_b_step(
         # Format output_obj_grad
         input_obj_grad = {}
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            for k, v in micro_batch.items():
-                if isinstance(v, torch.Tensor) and v.grad is not None:
-                    input_obj_grad[k] = v.grad
+            pass
         else:
             for k, v in input_obj.items():
                 if isinstance(v, torch.Tensor) and v.grad is not None:
@@ -550,10 +551,14 @@ def backward_w_step(
             output_obj_, _ = tree_flatten(output_obj) # y
             output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
 
+        # filter item which is not torch.Tensor
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
+
         optimizer.backward_by_grad(
             tensor=output_obj_,
             grad=output_obj_grad_,
-            inputs=list(model_chunk[model_chunk_id].parameters()),
+            inputs=list(model_chunk.parameters()),
             retain_graph=False,
         )
 
@@ -634,7 +639,7 @@ def schedule_f(
             tree_map(release_tensor_data, output_obj)
 
         # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
+        self.input_tensors[model_chunk_id].append(input_obj)
 
         # for bwd b&w, we only need the graph(grad_fn) of output_obj
         # Do not release_tensor_data loss, release_tensor_data other output_obj;
@@ -692,7 +697,7 @@ def schedule_b(
         output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
 
         # get input and output object from buffer;
-        micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
+        input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)
 
         # save output_tensor_grad for dw
@@ -708,7 +713,6 @@ def schedule_b(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             optimizer=optimizer,
-            micro_batch=micro_batch,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
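
Note on the backward changes above: the new "filter item which is not torch.Tensor" lines keep non-tensor leaves (such as the stage_index list that forward_step now injects into the inputs) out of the flattened lists handed to optimizer.backward_by_grad. A minimal sketch of that pattern outside the scheduler, using plain torch.autograd.backward instead of the OptimizerWrapper and made-up dict contents:

    import torch
    from torch.utils._pytree import tree_flatten

    # Hypothetical chunk output: one tensor plus a non-tensor bookkeeping field.
    hidden = torch.randn(2, 4, requires_grad=True)
    output_obj = {"hidden_states": hidden * 2, "stage_index": [0, 2]}
    # Gradient coming back from the next stage; mirrors the structure of output_obj.
    output_obj_grad = {"hidden_states": torch.ones(2, 4), "stage_index": [0, 2]}

    output_obj_, _ = tree_flatten(output_obj)            # [tensor, 0, 2]
    output_obj_grad_, _ = tree_flatten(output_obj_grad)  # [tensor, 0, 2]

    # Keep only tensors (and None placeholders) so both lists stay aligned for autograd.
    output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
    output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

    torch.autograd.backward(tensors=output_obj_, grad_tensors=output_obj_grad_)
    print(hidden.grad)  # d(hidden * 2)/d(hidden) * ones = 2 * ones

Without the filtering, the flattened lists would contain plain Python ints, and the tensor/gradient pairing passed to autograd would no longer line up.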
12 changes: 12 additions & 0 deletions colossalai/pipeline/stage_manager.py
@@ -26,6 +26,7 @@ def __init__(
         pg_mesh: ProcessGroupMesh,
         pipeline_axis: int,
         enable_interleave: bool = False,
+        use_zbv: bool = False,
         num_model_chunks: int = 1,
         num_layers_per_stage: Optional[List[int]] = None,
     ) -> None:
@@ -49,6 +50,7 @@ def __init__(
         next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
         self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
         self.is_interleave = enable_interleave
+        self.use_zbv = use_zbv
         # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
         self.num_model_chunks: int = num_model_chunks
         # for shardformer, hold stage indices of model
@@ -85,6 +87,16 @@ def get_stage_index(
         num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
 
         stage_indices = []
+        if self.use_zbv:
+            stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
+            stage_indices.append(
+                [
+                    num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
+                    num_layers_per_stage_accumulated[2 * num_stages - stage],
+                ]
+            )
+            return stage_indices
+
         for model_chunk in range(num_model_chunks):
             start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
             end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
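
The use_zbv branch above is what backs self.stage_manager.stage_indices[model_chunk_id] in the forward_step change: under the zero-bubble V schedule each pipeline stage holds two layer chunks, one counted forward from the start of the model and one mirrored back from the end. A small self-contained re-derivation of that indexing (the helper name and the uniform 2-layers-per-chunk split are illustrative, not from the patch):

    import numpy as np

    def zbv_stage_indices(stage, num_stages, layers_per_chunk):
        # Accumulated layer boundaries, e.g. [0, 2, 4, ...] for 2 layers per chunk.
        acc = np.insert(np.cumsum(layers_per_chunk), 0, 0)
        return [
            # chunk 0: counts forward from the start of the model
            [int(acc[stage]), int(acc[stage + 1])],
            # chunk 1: mirrored slice counted back from the end ("V" shape)
            [int(acc[2 * num_stages - stage - 1]), int(acc[2 * num_stages - stage])],
        ]

    # 16 layers over 4 stages x 2 chunks (2 layers per chunk slot):
    for s in range(4):
        print(s, zbv_stage_indices(s, 4, [2] * 8))
    # 0 [[0, 2], [14, 16]]
    # 1 [[2, 4], [12, 14]]
    # 2 [[4, 6], [10, 12]]
    # 3 [[6, 8], [8, 10]]

Stage 0 ends up with the first and last slices while the last stage holds the two middle slices, which is the V-shaped placement the zero-bubble schedule relies on.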
(additional changed file; name not shown in this view)
@@ -15,6 +15,7 @@ def __init__(self):
         self.is_interleave = False
         self.num_layers_per_stage = None
         self.num_model_chunks = 1
+        self.use_zbv = False
 
     @property
     def num_stages(self):
(additional changed file; name not shown in this view)
@@ -15,6 +15,7 @@ def __init__(self):
         self.is_interleave = False
         self.num_layers_per_stage = None
         self.num_model_chunks = 1
+        self.use_zbv = False
 
     @property
     def num_stages(self):
(remaining changed file not shown in this view)