diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 36973b240896..56405ed47e00 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -29,6 +29,7 @@
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -207,6 +208,7 @@ def __init__(
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
         num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
@@ -282,8 +284,10 @@ def __init__(
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style == "interleaved" or pp_style == "zbv"
+            ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -293,7 +297,7 @@ def __init__(
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
             )
@@ -315,6 +319,14 @@ def __init__(
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                 )
+            elif pp_style == "zbv":
+                self.schedule = ZeroBubbleVPipeScheduler(
+                    schedule=scheduler_nodes,
+                    stage_manager=self.stage_manager,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    overlap_p2p=overlap_p2p,
+                )
             else:
                 raise NotImplementedError()
diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index 271b3238f5c4..b641eb3645cd 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None:
         x.retain_grad()
 
 
+def require_grad(x: Any) -> None:
+    """Call require_grad on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+    """
+    if isinstance(x, torch.Tensor) and not x.requires_grad:
+        x.requires_grad_()
+
+
 def detach(x: Any) -> Any:
     """Call detach() on a tensor.
 
@@ -145,6 +155,34 @@ def detach(x: Any) -> Any:
     return x
 
 
+def clone(x: Any) -> Any:
+    """Call clone() on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The cloned object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.clone()
+    return x
+
+
+def release_tensor_data(x: Any) -> Any:
+    """Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The deallocated .data object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.data.untyped_storage().resize_(0)
+    return x
+
+
 def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.
 
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index c1c4f13c68c2..bbad921b2ab5 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -4,7 +4,7 @@
 import torch
 import torch.cuda
 from torch.nn import Module, ModuleList
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map
 
 from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
@@ -12,7 +12,18 @@
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
+from ._utils import (
+    clone,
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    release_tensor_data,
+    require_grad,
+    retain_grad,
+    to_device,
+)
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -24,21 +35,6 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
             req.wait()
 
 
-def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    if (out is None) or (not deallocate_pipeline_outputs):
-        return
-    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
-    # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
-    out.data.untyped_storage().resize_(0)
-
-
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
@@ -409,6 +405,7 @@ def forward_step(
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
+        micro_batch: Optional[dict],
         input_obj: Optional[dict],
         criterion: Callable,
         accum_loss: Optional[torch.Tensor] = None,
@@ -427,18 +424,18 @@ def forward_step(
             Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
         """
         # Load input ids, attention mask and labels
-        # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
-
-        # for the first stage, input_obj is None
+        # for the first stage, input_obj is None, so we use micro_batch as input_obj
         # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
         # Only attention_mask from micro_batch is used
-        with self.stage_manager.switch_model_chunk_id(model_chunk_id):
-            # fwd calculate
-            output_obj = model_chunk[model_chunk_id](input_obj)
+        # fwd calculate
+        internal_inputs = {} if input_obj is None else input_obj
+        # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+        output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
+
         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            loss = criterion(output_obj) / self.num_microbatch
+            loss = criterion(output_obj, micro_batch) / self.num_microbatch
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
             if outputs is not None:
@@ -452,6 +449,7 @@ def backward_b_step(
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
+        micro_batch: Optional[dict],
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -462,7 +460,7 @@ def backward_b_step(
             model_chunk (ModuleList or Module): Model Chunk to be run;
             model_chunk_id (int): The current model chunk idx;
             optimizer (OptimizerWrapper): Optimizer to update the model
-            input_obj (Optional[dict]): x.
+            input_obj (Optional[Tuple(dict)]): x. (micro_batch, input_obj)
             output_obj (Union[dict, torch.Tensor]): y.
             output_obj_grad (dict): dy.
 
@@ -471,20 +469,52 @@ def backward_b_step(
         """
         # calculate bwd b step ; only dx = w*dy;
 
-        # Retain the grad on the input_obj.
-        tree_map(retain_grad, input_obj)
+        # Retain the grad on the input_obj. No need to retain_grad on micro_batch.
+        if input_obj is not None:
+            tree_map(retain_grad, input_obj)
 
-        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # loss backward; output_obj is loss; so output_obj_grad should be None
+        # x, y, dy list for backward_by_grad; Type: list[tensor];
+        input_obj_ = []
+        output_obj_ = []
+        output_obj_grad_ = []
+
+        # For chunk 0 stage 0, use micro_batch as input_obj_
+        if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            input_obj_, _ = tree_flatten(micro_batch)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+
+        # For loss backward; output_obj is loss; output_obj_grad should be None
+        elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
+            input_obj_, _ = tree_flatten(input_obj)
+            output_obj_.append(output_obj)  # LOSS
+            output_obj_grad_.append(output_obj_grad)  # None
+
+        # For other chunks/stages, use input_obj as input_obj_;
+        else:
+            input_obj_, _ = tree_flatten(input_obj)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
 
         optimizer.backward_by_grad(
-            tensor=output_obj,
-            grad=output_obj_grad,
-            inputs=input_obj,
+            tensor=output_obj_,
+            grad=output_obj_grad_,
+            inputs=input_obj_,
             retain_graph=True,
         )
-        return input_obj.grad
+
+        # Format output_obj_grad
+        input_obj_grad = {}
+        if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            for k, v in micro_batch.items():
+                if isinstance(v, torch.Tensor) and v.grad is not None:
+                    input_obj_grad[k] = v.grad
+        else:
+            for k, v in input_obj.items():
+                if isinstance(v, torch.Tensor) and v.grad is not None:
+                    input_obj_grad[k] = v.grad
+        return input_obj_grad
 
     def backward_w_step(
         self,
@@ -508,12 +538,21 @@ def backward_w_step(
         """
         # calculate bwd w step ; only dw = x*dy;
 
+        # y, dy list for w backward_by_grad; Type: list[tensor];
+        output_obj_ = []
+        output_obj_grad_ = []
+
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # loss backward; output_obj is loss
-            output_obj_grad = None
+            # loss backward; output_obj is loss;
+            output_obj_.append(output_obj)  # LOSS
+            output_obj_grad_.append(None)  # None
+        else:
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+
         optimizer.backward_by_grad(
-            tensor=output_obj,
-            grad=output_obj_grad,
+            tensor=output_obj_,
+            grad=output_obj_grad_,
             inputs=list(model_chunk[model_chunk_id].parameters()),
             retain_graph=False,
         )
@@ -543,9 +582,9 @@ def schedule_f(
         micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
         # Step1: recv fwd
         if model_chunk_id == 0:
-            # is first stage; get input from func param
+            # is first stage; get input from microbatch
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                input_obj = micro_batch
+                input_obj = None
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
         else:
@@ -557,55 +596,75 @@ def schedule_f(
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
 
         # Here, let input_obj.requires_grad_()
-        tree_map(torch.Tensor.requires_grad_, input_obj)
+        # if input_obj is not None:
+        if not isinstance(input_obj, torch.Tensor):
+            tree_map(require_grad, input_obj)
+
+        # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
+        # tree_map(torch.Tensor.requires_grad_, micro_batch)
 
         # Step2: fwd step
         output_obj = self.forward_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
+            micro_batch=micro_batch,
             input_obj=input_obj,
             criterion=criterion,
             accum_loss=accum_loss,
             outputs=outputs,
         )
+
+        # Step3:
+        # 3-1: detach output; detach output for send fwd;
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             # We should not detach bwd LOSS
             pass
         else:
-            detached_output_obj = output_obj.clone().detach()
+            # detach output
+            detached_output_obj = tree_map(detach, output_obj)
+            # 3-2 clone detached_output_obj
+            detached_output_obj = tree_map(clone, detached_output_obj)
+
+        # 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # We should not release_tensor_data bwd LOSS
+            pass
+        else:
+            # release_tensor_data output
+            tree_map(release_tensor_data, output_obj)
+
+        # add input and output object for backward b
+        self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
+
+        # for bwd b&w, we only need the graph(grad_fn) of output_obj
+        # Do not release_tensor_data loss, release_tensor_data other output_obj;
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            self.output_tensors[model_chunk_id].append(output_obj)
+            self.output_tensors_dw[model_chunk_id].append(output_obj)
+        else:
+            self.output_tensors[model_chunk_id].append(output_obj)
+            self.output_tensors_dw[model_chunk_id].append(output_obj)
 
-        # Step3: send fwd
         # add output to send_fwd_buffer
-        if model_chunk_id == 0:
+        if model_chunk_id == 0:  # chunk 0
             # is last stage; send to local_send_forward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
                 self.local_send_forward_buffer.append(detached_output_obj)
             else:
                 self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
-        else:
-            # is first stage; end of fwd; append LOSS to local_send_backward_buffer
+        else:  # chunk 1
+            # is first stage; end of fwd; do nothing
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 pass
             else:
                 self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
 
-        # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append(input_obj)
-        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
-        deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
-        self.output_tensors[model_chunk_id].append(output_obj)
-        # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(output_obj)
-
     def schedule_b(
         self,
         scheduled_node,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
-        # input_obj: Optional[dict],
-        # output_obj: Union[dict, torch.Tensor],
-        # output_obj_grad: Optional[dict],
     ):
         """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
 
@@ -616,25 +675,24 @@ def schedule_b(
         Returns:
             Nothing.
         """
-
         # Step1: recv bwd
         if model_chunk_id == 0:
             # chunk0 is last stage; recv output_grad from local_send_backward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
                 output_tensor_grad = self.local_send_backward_buffer.pop(0)
-            # chunk 0 not last stage; recv output_grad from recv_backward_buffer
+            # chunk0 not last stage; recv output_grad from recv_backward_buffer
             else:
                 output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
         else:
             # chunk1, is first stage; recv LOSS from local send bwd buffer
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 output_tensor_grad = None
-            # chunk1, not first stage; recv output_grad from recv_backward_buffer
+            # chunk1, not first stage; recv output_grad from recv_backward_buffer
             else:
                 output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
 
         # get input and output object from buffer;
-        input_obj = self.input_tensors[model_chunk_id].pop(0)
+        micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)
 
         # save output_tensor_grad for dw
@@ -645,12 +703,12 @@ def schedule_b(
             # we save output_tensor_grad here
             self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
 
-        # _wait_p2p(recv_bwd_handles)
         # Step2: bwd step
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             optimizer=optimizer,
+            micro_batch=micro_batch,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
@@ -777,8 +835,7 @@ def run_forward_backward(
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 communication_func(scheduled_node.chunk)
-
-            if scheduled_node.type == "F":
+            elif scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 825c192d8fd5..14bc3475dac2 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -1,4 +1,6 @@
 from copy import deepcopy
+from functools import partial
+from types import MethodType
 from typing import Tuple
 
 import pytest
@@ -14,6 +16,7 @@
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -23,10 +26,32 @@ def __init__(self, in_dim, out_dim, num_layers):
         super().__init__()
         self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
 
-    def forward(self, x):
+    def forward(
+        self,
+        hidden_states,
+    ):
         for layer in self.layers:
-            x = layer(x)
-        return x
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+def pp_linear_fwd(
+    forward,
+    data: torch.Tensor = None,
+    hidden_states: torch.Tensor = None,
+    stage_mgr: PipelineStageManager = None,
+    model_chunk_id: int = None,
+):
+    with stage_mgr.switch_model_chunk_id(model_chunk_id):
+        # fwd end
+        if stage_mgr.is_first_stage() and model_chunk_id == 1:
+            return forward(hidden_states)
+        # fwd start
+        elif stage_mgr.is_first_stage() and model_chunk_id == 0:
+            return {"hidden_states": forward(data)}
+        # fwd middle
+        else:
+            return {"hidden_states": forward(hidden_states)}
 
 
 def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
@@ -561,19 +586,24 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     # init loss func
     def criterion(x, *args, **kwargs):
+        x = x["hidden_states"]
+        return (x * x).mean()
+
+    def criterion_base(x, *args, **kwargs):
         return (x * x).mean()
 
     # init model and input
     batch_size = test_config["batch_size"]
     num_layers = 8
     assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
-    in_dim = out_dim = 4096
+    in_dim = out_dim = 1024
     before_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
-    data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
-
-    input_base = [t.clone() for t in data_iter]
+    # data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
+    data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
+    # input_base = [t.clone() for t in data_iter]
+    input_base = {k: v.clone() for k, v in data_iter.items()}
     model_base = deepcopy(model)
 
     if rank == 0:
@@ -581,24 +611,44 @@ def criterion(x, *args, **kwargs):
         local_chunk = torch.nn.ModuleList().to(rank)
         for idx, sub_model in enumerate(model.layers):
             if idx == 0 or idx == 7:
+                sub_model._forward = sub_model.forward
+                sub_model.forward = MethodType(
+                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                    sub_model._forward,
+                )
                 local_chunk.append(sub_model)
     elif rank == 1:
         # layer 1 & 6 to chunk 1 on rank1
         local_chunk = torch.nn.ModuleList().to(rank)
         for idx, sub_model in enumerate(model.layers):
             if idx == 1 or idx == 6:
+                sub_model._forward = sub_model.forward
+                sub_model.forward = MethodType(
+                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                    sub_model._forward,
+                )
                 local_chunk.append(sub_model)
     elif rank == 2:
         # layer 2 & 5 to chunk 2 on rank2
         local_chunk = torch.nn.ModuleList().to(rank)
         for idx, sub_model in enumerate(model.layers):
             if idx == 2 or idx == 5:
+                sub_model._forward = sub_model.forward
+                sub_model.forward = MethodType(
+                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                    sub_model._forward,
+                )
                 local_chunk.append(sub_model)
     else:
         # layer 3 & 4 to chunk 3 on rank3
         local_chunk = torch.nn.ModuleList().to(rank)
         for idx, sub_model in enumerate(model.layers):
             if idx == 3 or idx == 4:
+                sub_model._forward = sub_model.forward
+                sub_model.forward = MethodType(
+                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
+                    sub_model._forward,
+                )
                 local_chunk.append(sub_model)
 
     # init optimizer
@@ -611,7 +661,7 @@ def criterion(x, *args, **kwargs):
     torch.cuda.synchronize()
     result = scheduler.forward_backward_step(
         model_chunk=local_chunk,
-        data_iter=iter(data_iter),
+        data_iter=iter([data_iter]),
         criterion=criterion,
         optimizer=optimizer_pp,
         return_loss=True,
@@ -624,26 +674,28 @@ def criterion(x, *args, **kwargs):
     # assert memory
     if rank != 0:
-        # w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
-        # output hid_dim * hid_dim * 4(fp32) / 1024**3
-        # optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
-        print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
-        assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
+        # w.grad: hid_dim * hid_dim * microbatch * 4(fp32) * 2 (2 layer in each stage) / 1024**3
+        # output: hid_dim * hid_dim * microbatch * 4(fp32) / 1024**3
+        # optim: state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
+        print(
+            f" num_microbatch {num_microbatch} rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 * batch_size / 1024**3)}"
+        )
+        assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 * batch_size / 1024**3)
     else:
         # rank0 will also hold output;
         print(
-            f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
+            f" num_microbatch {num_microbatch} rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
         )
         assert round((after_pp_step_memory - after_init_memory), 5) <= round(
-            (in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
+            (in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
         )
 
     ##########################
     # Fwd bwd for base
     ##########################
     # fwd & bwd
-    output_base = model_base(input_base[0])
-    loss_base = criterion(output_base)
+    output_base = model_base(input_base["data"])
+    loss_base = criterion_base(output_base)
     loss_base.backward()
     optimizer_base.step()
 
@@ -653,7 +705,7 @@ def criterion(x, *args, **kwargs):
     # only chunk 1 stage 0 hold loss and output
     if rank == 0:
         assert_close(result["loss"], loss_base)
-        assert_close(result["outputs"], output_base)
+        assert_close(result["outputs"]["hidden_states"], output_base)
     # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
 
     ##########################
@@ -724,28 +776,108 @@ def run_with_hybridplugin(test_config):
     "test_config",
     [
         {
-            "batch_size": 8,
+            "pp_style": "zbv",
             "tp_size": 1,
+            "ep_size": 1,
             "pp_size": 4,
             "num_microbatches": 4,
             "zero_stage": 1,
             "precision": "bf16",
-            "num_model_chunk": 2,
+            "num_model_chunks": 2,
         },
     ],
 )
 def run_with_moehybridplugin(test_config):
-    model_zoo.get_sub_registry("transformers_bert")
-    test_config["use_lazy_init"] = False
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
+    # test_config["use_lazy_init"] = False
     test_config["initial_scale"] = 2**16
     model_list = [
         "transformers_bert",
     ]
+    clear_layout_converter()
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name in model_list:
+            # base param
+            model = model_fn()
+            data = data_gen_fn()
+            print(f"data {data}")
+            criterion = loss_fn
+            optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
+
+            output = model(**data)
+            loss = criterion(output)
+            loss.backward()
+            optimizer.step()
+            print(f"output {output}")
+
+            # # pp param
+            # model_pp = deepcopy(model)
+            # data_pp = deepcopy(data)
+            # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
+
+            # # init pipeline graph
+            # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
+            # mem_f = 34 * h + 5 * a * s
+            # mem_w = -32 * h
+            # mem_b = -mem_w - mem_f
+            # graph = PipelineGraph(
+            #     n_stage=test_config["pp_size"],
+            #     n_micro=test_config["num_microbatches"],
+            #     f_cost=1,
+            #     b_cost=1,
+            #     w_cost=1,
+            #     c_cost=1,
+            #     f_mem=mem_f,
+            #     b_mem=mem_b,
+            #     w_mem=mem_w,
+            #     # max_mem=mem_f * (p * 2 + m_offset),
+            # )
+
+            # zbv_schedule = graph.get_v_schedule()
+
+            # test_config["scheduler_nodes"] = zbv_schedule
+            # plugin = MoeHybridParallelPlugin(
+            #     **test_config
+            # )
+            # model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
+            #     model = model_pp,
+            #     optimizer = optimizer_pp,
+            #     criterion = criterion,
+            #     dataloader = data_pp,
+            # )
+
+            # output_pp = plugin.execute_pipeline(
+            #     data_iter=iter(data),
+            #     model=model,
+            #     criterion=criterion,
+            #     optimizer=optimizer,
+            #     return_loss = True,
+            #     return_outputs = True,
+            # )
 
 
 # TODO:6) support booster & Hybrid base 4)
 # TODO:7) support booster & MoEHybrid base 4)
+@parameterize(
+    "test_config",
+    [
+        {
+            "pp_style": "zbv",
+            "tp_size": 1,
+            "ep_size": 1,
+            "pp_size": 4,
+            "num_microbatches": 4,
+            "zero_stage": 1,
+            "precision": "bf16",
+            "num_model_chunks": 2,
+        },
+    ],
+)
+def run_with_booster_moehybridplugin(test_config):
+    pass
 
 
 def run_dist(rank, world_size, port):
@@ -754,6 +886,7 @@ def run_dist(rank, world_size, port):
     # run_fwd_bwd_iter_input()
     run_fwd_bwd_vschedule_with_optim()
     # run_with_moehybridplugin()
+    # run_with_booster_moehybridplugin()
 
 
 @pytest.mark.dist
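
The memory trick behind the new release_tensor_data helper (and the deallocate_output_tensor function it replaces) is that, once a stage's output has been detached, cloned and sent onward in schedule_f, the original output tensor is only needed for its grad_fn: backward_b_step and backward_w_step seed it with an externally received dy via backward_by_grad, so its values never enter the backward computation. The following is a small standalone illustration of that behaviour in plain PyTorch (the shapes and names are made up for the example and are not part of the patch):

import torch

x = torch.randn(4, 8, requires_grad=True)  # stands in for the stage input
w = torch.randn(8, 8, requires_grad=True)  # stands in for the stage weights
y = x @ w                                  # stage output; its grad_fn is kept alive

# What schedule_f does: detach + clone a copy to send to the next stage ...
y_to_send = y.detach().clone()
# ... then free the original output's storage while keeping its autograd graph.
y.data.untyped_storage().resize_(0)

# Later, the backward passes seed the graph with the dy received from downstream;
# y's own values are never read, so the released storage is not a problem.
dy = torch.ones_like(y_to_send)
torch.autograd.backward(y, grad_tensors=dy)
print(x.grad.shape, w.grad.shape)  # torch.Size([4, 8]) torch.Size([8, 8])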
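
On the plugin side, the commented-out block in run_with_moehybridplugin sketches how the new pieces are meant to fit together: PipelineGraph(...).get_v_schedule() produces the scheduler_nodes, which are passed to MoeHybridParallelPlugin together with pp_style="zbv" and num_model_chunks=2. Below is a minimal, hypothetical wiring based on that block; the Booster-driven flow, the 4-rank launch, and the model/optimizer/dataloader are assumptions of the sketch, and the mem_f/mem_w/mem_b coefficients are the illustrative ones from the comment, not a tuned cost model.

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

# Assumes colossalai.launch_from_torch() has already initialized a 4-rank run.

# Cost/memory coefficients for the V-schedule solver; illustrative placeholders.
h, a, s = 1024, 16, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f

scheduler_nodes = PipelineGraph(
    n_stage=4,
    n_micro=4,
    f_cost=1,
    b_cost=1,
    w_cost=1,
    c_cost=1,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
).get_v_schedule()

plugin = MoeHybridParallelPlugin(
    pp_style="zbv",                   # new style added by this patch
    pp_size=4,
    num_model_chunks=2,               # the V schedule places two chunks on each stage
    num_microbatches=4,
    scheduler_nodes=scheduler_nodes,  # new argument, handed to ZeroBubbleVPipeScheduler
    tp_size=1,
    ep_size=1,
    zero_stage=1,
    precision="bf16",
)
booster = Booster(plugin=plugin)

# model, optimizer, criterion and dataloader are assumed to come from the caller; after
# boosting, a training step goes through the usual pipeline entry point, which now
# dispatches to ZeroBubbleVPipeScheduler when pp_style == "zbv":
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
# booster.execute_pipeline(iter(dataloader), model, criterion, optimizer, return_loss=True)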