From 660eed912495eb0f9473ba53dd191e4b44e1d31f Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Thu, 7 Sep 2023 10:42:59 +0800
Subject: [PATCH] [pipeline] set optimizer to optional in execute_pipeline
 (#4630)

* set optimizer to optional in execute_pipeline

* arrange device and mixed precision in booster init

* fix execute_pipeline in booster.py
---
 colossalai/booster/booster.py                       | 15 ++++++++++-----
 .../booster/plugin/hybrid_parallel_plugin.py        |  6 +++---
 colossalai/booster/plugin/pp_plugin_base.py         |  4 ++--
 colossalai/pipeline/schedule/base.py                |  6 +++---
 colossalai/pipeline/schedule/interleaved_pp.py      |  6 ++++--
 colossalai/pipeline/schedule/one_f_one_b.py         |  6 ++++--
 examples/language/bert/finetune.py                  | 10 ++--------
 .../test_schedule/test_interleaved.py               |  2 +-
 .../test_pipeline/test_schedule/test_oneF_oneB.py   |  2 +-
 9 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index adb8f62a5084..7acf164def69 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -49,7 +49,9 @@ class Booster:
         ```

     Args:
-        device (str or torch.device): The device to run the training. Default: 'cuda'.
+        device (str or torch.device): The device to run the training. Default: None.
+                                If plugin is not used or plugin doesn't control the device,
+                                this argument will be set as training device ('cuda' will be used if argument is None).
         mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
                                 If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
                                 'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
@@ -57,7 +59,7 @@ class Booster:
     """

     def __init__(self,
-                 device: str = 'cuda',
+                 device: Optional[str] = None,
                  mixed_precision: Union[MixedPrecision, str] = None,
                  plugin: Optional[Plugin] = None) -> None:
         if plugin is not None:
@@ -68,13 +70,16 @@ def __init__(self,
         # set accelerator
         if self.plugin and self.plugin.control_device():
             self.accelerator = None
-            warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
+            if device is not None:
+                warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
         else:
+            device = device or 'cuda'
             self.accelerator = Accelerator(device)

         # set precision
         if self.plugin and self.plugin.control_precision():
-            warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+            if mixed_precision is not None:
+                warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
             self.mixed_precision = None
         elif mixed_precision is None:
             self.mixed_precision = None
@@ -146,7 +151,7 @@ def execute_pipeline(self,
                          data_iter: Iterator,
                          model: nn.Module,
                          criterion: Callable[[Any, Any], torch.Tensor],
-                         optimizer: Optimizer,
+                         optimizer: Optional[Optimizer] = None,
                          return_loss: bool = True,
                          return_outputs: bool = False) -> dict:
         # run pipeline forward backward pass
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index d33e3485c39c..125a9ccca1b5 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -443,15 +443,15 @@ def execute_pipeline(self,
                          data_iter: Iterator,
                          model: HybridParallelModule,
                          criterion: Callable[[Any, Any], torch.Tensor],
-                         optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
-                                          HybridParallelZeroOptimizer],
+                         optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
+                                                    HybridParallelZeroOptimizer]] = None,
                          return_loss: bool = True,
                          return_outputs: bool = False) -> dict:
         assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
         # return loss or outputs if needed
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
         with ctx:
-            outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss,
+            outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
                                                           return_outputs)
         model.sync_shared_params()
         if isinstance(optimizer, HybridParallelZeroOptimizer):
diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py
index 67ade9330f5b..f52844db082f 100644
--- a/colossalai/booster/plugin/pp_plugin_base.py
+++ b/colossalai/booster/plugin/pp_plugin_base.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Any, Callable, Iterator
+from typing import Any, Callable, Iterator, Optional

 import torch

@@ -15,7 +15,7 @@ def execute_pipeline(self,
                          data_iter: Iterator,
                          model: ModelWrapper,
                          criterion: Callable[[Any, Any], torch.Tensor],
-                         optimizer: OptimizerWrapper,
+                         optimizer: Optional[OptimizerWrapper] = None,
                          return_loss: bool = True,
                          return_outputs: bool = False) -> dict:
         pass
diff --git a/colossalai/pipeline/schedule/base.py b/colossalai/pipeline/schedule/base.py
index 9cd9beded65a..b0fa6e6ad2b8 100644
--- a/colossalai/pipeline/schedule/base.py
+++ b/colossalai/pipeline/schedule/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, Optional

 from torch import Tensor
 from torch.nn import Module
@@ -14,18 +14,18 @@ def __init__(self, stage_manager: PipelineStageManager) -> None:

     def forward_backward_step(self,
                               model: Module,
-                              optimizer: OptimizerWrapper,
                               data_iter: Iterable,
                               criterion: Callable[[Any, Any], Tensor],
+                              optimizer: Optional[OptimizerWrapper] = None,
                               return_loss: bool = False,
                               return_outputs: bool = False) -> dict:
         """Forward and backward step for pipeline training.

         Args:
             model (Module): Model to be trained.
-            optimizer (OptimizerWrapper): Optimizer to be used.
             data_iter (Iterable): Data iterator.
             criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
             return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
             return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.

diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index 35a33491b03c..6fdb09be5f32 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -237,18 +237,18 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],

     def forward_backward_step(self,
                               model_chunk: Module,
-                              optimizer: OptimizerWrapper,
                               data_iter: Iterable,
                               criterion: Callable[..., Any],
+                              optimizer: Optional[OptimizerWrapper] = None,
                               return_loss: bool = False,
                               return_outputs: bool = False) -> dict:
         """Runs interleaved 1F1B schedule, with communication between pipeline stages.

         Args:
             model_chunk (List[Module]): Model Chunk to be trained.
-            optimizer (OptimizerWrapper): Optimizer to be used.
             data_iter (Iterable): Data iterator.
             criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
             return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
             return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.

@@ -256,6 +256,8 @@ def forward_backward_step(self,
             dict: A dict with keys: 'loss' and 'outputs'.
         """
         forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert forward_only, "Optimizer should be passed when doing backward."

         self.load_batch(data_iter)
         num_model_chunks = len(model_chunk)
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 5db1c7f30d7f..fbd0f9f0d4c0 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -210,18 +210,18 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],

     def forward_backward_step(self,
                               model: Module,
-                              optimizer: OptimizerWrapper,
                               data_iter: Iterable,
                               criterion: Callable[..., Any],
+                              optimizer: Optional[OptimizerWrapper] = None,
                               return_loss: bool = False,
                               return_outputs: bool = False) -> dict:
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.

         Args:
             model (Module): Model to be trained.
-            optimizer (OptimizerWrapper): Optimizer to be used.
             data_iter (Iterable): Data iterator.
             criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
             return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
             return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.

@@ -229,6 +229,8 @@ def forward_backward_step(self,
             dict: A dict with keys: 'loss' and 'outputs'.
         """
         forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert forward_only, "Optimizer should be passed when doing backward."

         self.load_batch(data_iter)

diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
index c4d541c978a8..8864776967ce 100644
--- a/examples/language/bert/finetune.py
+++ b/examples/language/bert/finetune.py
@@ -46,7 +46,6 @@ def move_to_cuda(batch):
 @torch.no_grad()
 def evaluate_model(
     model: nn.Module,
-    optimizer,
     criterion,
     test_dataloader: Union[DataLoader, List[DataLoader]],
     num_labels: int,
@@ -71,12 +70,7 @@ def evaluate_subset(dataloader: DataLoader):
             current_rank = dist.get_rank()
             #TODO pass dataloader to execute_pipeline directly
             batch = iter([batch])
-            outputs = booster.execute_pipeline(batch,
-                                               model,
-                                               criterion,
-                                               optimizer,
-                                               return_loss=True,
-                                               return_outputs=True)
+            outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)

             if booster.plugin.stage_manager.is_last_stage():
                 val_loss = outputs["loss"]
@@ -304,7 +298,7 @@ def _criterion(outputs, inputs):

     for epoch in range(NUM_EPOCHS):
         train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
-    results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
+    results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task,
                              data_builder.eval_splits, booster, coordinator)

     if coordinator.is_master():
diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py
index 2ac31c8ca0d1..a995d17e5da6 100644
--- a/tests/test_pipeline/test_schedule/test_interleaved.py
+++ b/tests/test_pipeline/test_schedule/test_interleaved.py
@@ -110,9 +110,9 @@ def examine_pp(num_micro_batches):
     torch_loss.backward()

     pp_ret = schedule.forward_backward_step(sharded_model,
-                                            pp_optimizer,
                                             iter(input_list),
                                             criterion,
+                                            pp_optimizer,
                                             return_loss=True,
                                             return_outputs=True)

diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
index d31eafd70e1a..41b535573c39 100644
--- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py
+++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
@@ -90,9 +90,9 @@ def examine_pp():
     torch_loss.backward()

     pp_ret = schedule.forward_backward_step(sharded_model,
-                                            pp_optimizer,
                                             iter(input_list),
                                             criterion,
+                                            pp_optimizer,
                                             return_loss=True,
                                             return_outputs=True)
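Reviewer note, appended after the patch (not part of the diff): the sketch below illustrates the call pattern this change enables, i.e. running `booster.execute_pipeline` without an optimizer for a forward-only evaluation pass, mirroring the updated `examples/language/bert/finetune.py`. It assumes a model, criterion and dataloader that have already been boosted with a pipeline-enabled `HybridParallelPlugin`; the helper name `evaluate` and its local variables are illustrative only, not part of the ColossalAI API.

```python
import torch


@torch.no_grad()  # grad mode is off, so the schedule runs forward only and the optimizer may be omitted
def evaluate(booster, model, criterion, dataloader):
    model.eval()
    for batch in dataloader:
        # execute_pipeline still expects an iterator over batches
        batch_iter = iter([batch])
        outputs = booster.execute_pipeline(batch_iter,
                                           model,
                                           criterion,
                                           return_loss=True,
                                           return_outputs=True)
        # only the last pipeline stage holds the loss
        if booster.plugin.stage_manager.is_last_stage():
            print(outputs["loss"])
```

Training code keeps passing the optimizer, now either as `optimizer=...` or as the fourth positional argument of `booster.execute_pipeline` after `criterion`; with this patch the schedules assert that an optimizer is present whenever grad mode is enabled.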