diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index aa17fa269ccf..0fb992f1da52 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,13 +1,11 @@ import gc import logging import os -import warnings from pathlib import Path -from typing import Callable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple import torch import torch.nn as nn -from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader @@ -16,7 +14,6 @@ from colossalai.checkpoint_io.utils import ( get_model_base_filenames, get_optimizer_base_filenames, - get_shard_filename, load_shard_state_dict, save_state_dict, save_state_dict_shards, @@ -24,8 +21,7 @@ from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device -from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper -from colossalai.zero.gemini import ZeroOptimizer +from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats from .dp_plugin_base import DPPluginBase @@ -132,11 +128,7 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_ As there is communication when getting state dict, this must be called on all processes. """ - # If optimizer is wrapped, unwrap it. - if isinstance(optimizer, OptimizerWrapper): - optimizer = optimizer.unwrap() - - assert isinstance(optimizer, ZeroOptimizer) + assert isinstance(optimizer, GeminiOptimizer) if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") @@ -183,11 +175,7 @@ def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Pa if not os.path.isfile(checkpoint_index_file): logging.error(f"Provided path ({checkpoint_index_file}) should be a file") - # If optimizer is wrapped, unwrap it. - if isinstance(optimizer, OptimizerWrapper): - optimizer = optimizer.unwrap() - - assert isinstance(optimizer, ZeroOptimizer) + assert isinstance(optimizer, GeminiOptimizer) # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) @@ -220,36 +208,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): super().save_lr_scheduler(lr_scheduler, checkpoint) -class GeminiOptimizer(OptimizerWrapper): - - def __init__(self, - module: GeminiDDP, - optimizer: Optimizer, - zero_optim_config: dict, - optim_kwargs: dict, - verbose: bool = False) -> None: - optimizer = zero_optim_wrapper(module, - optimizer, - optim_config=zero_optim_config, - **optim_kwargs, - verbose=verbose) - super().__init__(optimizer) - - def backward(self, loss: Tensor, *args, **kwargs): - self.optim.backward(loss) - - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> Tensor: - warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm') - - def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: - raise NotImplementedError('Gemini does not support clip_grad_by_value') - - class GeminiPlugin(DPPluginBase): """ Plugin for Gemini. 
@@ -299,7 +257,7 @@ class GeminiPlugin(DPPluginBase): def __init__( self, - device: Optional[torch.device] = None, + chunk_init_device: Optional[torch.device] = None, placement_policy: str = "cpu", precision: str = "fp16", pin_memory: bool = False, @@ -324,7 +282,7 @@ def __init__( super().__init__() assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported' self.gemini_config = dict( - device=(device or get_current_device()), + chunk_init_device=(chunk_init_device or get_current_device()), placement_policy=placement_policy, pin_memory=pin_memory, force_outputs_fp32=force_outputs_fp32, @@ -383,13 +341,14 @@ def configure( # wrap the model with Gemini model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose) - # TODO(ver217): remove this line - model._colo_zero_stage = 3 if optimizer is not None and \ not isinstance(optimizer, OptimizerWrapper): - optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs, - self.verbose) + optimizer = GeminiOptimizer(optimizer, + model.unwrap(), + **self.zero_optim_config, + **self.optim_kwargs, + verbose=self.verbose) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 3465079e4fbb..4991241b8df1 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -2,8 +2,7 @@ ColoInitContext, GeminiAdamOptimizer, GeminiDDP, - ZeroDDP, - ZeroOptimizer, + GeminiOptimizer, get_static_torch_model, post_process_colo_init_ctx, ) @@ -11,6 +10,6 @@ from .wrapper import zero_model_wrapper, zero_optim_wrapper __all__ = [ - 'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', + 'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model' ] diff --git a/colossalai/zero/gemini/__init__.py b/colossalai/zero/gemini/__init__.py index 60f85ca2f540..7ac6a9be4140 100644 --- a/colossalai/zero/gemini/__init__.py +++ b/colossalai/zero/gemini/__init__.py @@ -1,11 +1,11 @@ from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration from .colo_init_context import ColoInitContext, post_process_colo_init_ctx -from .gemini_ddp import GeminiDDP, ZeroDDP +from .gemini_ddp import GeminiDDP from .gemini_mgr import GeminiManager -from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer +from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer from .utils import get_static_torch_model __all__ = [ - 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP', - 'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx' + 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP', + 'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx' ] diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index d0a2896a8dd2..c8f66a52ff23 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -31,14 +31,13 @@ _EXTRA_STATE_KEY_SUFFIX = '_extra_state' __all__ = [ - 'ZeroDDP', 'GeminiDDP', ] -class ZeroDDP(ModelWrapper): +class GeminiDDP(ModelWrapper): """ZeRO DDP. - Warning: Nested ZeroDDP is not supported now. 
+ Warning: Nested GeminiDDP is not supported now. It is designed to be used with ChunkManager and GeminiManager. For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. @@ -55,20 +54,42 @@ class ZeroDDP(ModelWrapper): mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16. """ - def __init__(self, - module: torch.nn.Module, - gemini_manager: GeminiManager, - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False, - scatter_after_inference: bool = True, - mixed_precision: torch.dtype = torch.float16, - process_group: Optional[ProcessGroup] = None) -> None: + def __init__( + self, + module: torch.nn.Module, + chunk_config_dict: Optional[dict] = None, + chunk_init_device: torch.device = torch.device('cpu'), + placement_policy: str = "cpu", + search_range_m: int = 32, # chunk search options + hidden_dim: Optional[int] = None, # chunk search options + min_chunk_size_m: float = 32, # chunk search options + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + scatter_after_inference: bool = True, + mixed_precision: torch.dtype = torch.float16, + process_group: Optional[ProcessGroup] = None, + memstats: Optional[MemStats] = None, # gemini memory stats + verbose: bool = False) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) - self.gemini_manager = gemini_manager - self.chunk_manager: ChunkManager = gemini_manager.chunk_manager + if chunk_config_dict is not None: + self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device) + else: + # some ugly hotfix for the compatibility with Lightning + if search_range_m is None: + search_range_m = 32 + self.chunk_manager = init_chunk_manager(model=module, + init_device=chunk_init_device, + hidden_dim=hidden_dim, + search_range_m=search_range_m, + min_chunk_size_m=min_chunk_size_m, + strict_ddp_flag=strict_ddp_mode, + process_group=process_group, + verbose=verbose) + self.gemini_manager = GeminiManager(placement_policy, self.chunk_manager, memstats) + self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = GeminiZeROHook(gemini_manager) + self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.overflow_counter = 0 @@ -257,7 +278,7 @@ def _post_backward(self): error_params.append(self.param2name[param]) error_str = "\n\t".join(error_params) raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", - "The most possible reason is that the model is not compatible with ZeroDDP.\n", + "The most possible reason is that the model is not compatible with GeminiDDP.\n", f"{error_str}") self._setup_grads_ptr() self._logger.debug( @@ -772,70 +793,3 @@ def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict] self.current_block[name] = tensor self.current_block_size += tensor_size return ret_block, ret_block_size - - -class GeminiDDP(ZeroDDP): - - def __init__(self, - module: torch.nn.Module, - device: torch.device, - placement_policy: str = "cpu", - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False, - scatter_after_inference: bool = True, - search_range_m: int = 32, - hidden_dim: Optional[int] = None, - min_chunk_size_m: float = 32, - memstats: Optional[MemStats] = None, - mixed_precision: torch.dtype = torch.float16, -
process_group: Optional[ProcessGroup] = None, - verbose: bool = False) -> None: - """ - A torch.Module wrapper using ZeRO-DP and Gemini. - ZeRO is for parallel. Gemini is for memory management. - WARNING: The class will modify the module inline! - - Example: - model is initialized under the context of ColoInitContext - >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda") - >>> logits = model(x) - >>> loss = criterion(logits, labels) - >>> model.backward(loss) - - Args: - module (torch.nn.Module): the model to be wrapped. - device (torch.device): device to place the model. - placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". - pin_memory (bool, optional): use pin memory on CPU. Defaults to False. - force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. - search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32. - hidden_dim (int, optional): the hidden dimension of DNN. - Users can provide this argument to speed up searching. - If users do not know this argument before training, it is ok. We will use a default value 1024. - min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20. - If the aggregate size of parameters is still smaller than the minimum chunk size, - all parameters will be compacted into one small chunk. - memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. - """ - # some ugly hotfix for the compatibility with Lightning - if search_range_m is None: - search_range_m = 32 - - chunk_manager = init_chunk_manager(model=module, - init_device=device, - hidden_dim=hidden_dim, - search_range_m=search_range_m, - min_chunk_size_m=min_chunk_size_m, - strict_ddp_flag=strict_ddp_mode, - process_group=process_group, - verbose=verbose) - gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) - super().__init__(module, - gemini_manager, - pin_memory, - force_outputs_fp32, - strict_ddp_mode, - scatter_after_inference, - mixed_precision=mixed_precision, - process_group=process_group) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 7d0db6b1fa23..31cb5f671a00 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -3,7 +3,7 @@ import gc import math import warnings -from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple +from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union import torch import torch.distributed as dist @@ -12,15 +12,16 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam +from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.utils import disposable, get_current_device, is_ddp_ignored from .chunk import Chunk, ChunkManager -from .gemini_ddp import ZeroDDP +from .gemini_ddp import GeminiDDP -__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer'] +__all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer'] _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} @@ -28,7 +29,7 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): def __init__(self, - module: ZeroDDP, + module: GeminiDDP, 
initial_scale: float = 2**16, min_scale: float = 1, growth_factor: float = 2, @@ -47,11 +48,11 @@ def pre_zero_grad(self) -> None: self.module.overflow_counter = 0 -class ZeroOptimizer(ColossalaiOptimizer): - """A wrapper for optimizer. ``ZeroDDP`` and ``ZeroOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3). +class GeminiOptimizer(OptimizerWrapper): + """A wrapper for optimizer. ``GeminiDDP`` and ``GeminiOptimizer`` implement Zero Redundancy Optimizer (ZeRO stage-3). Note: - You must use ``ZeroDDP`` with ``ZeroOptimizer``. + You must use ``GeminiDDP`` with ``GeminiOptimizer``. Note: Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`, @@ -59,7 +60,7 @@ class ZeroOptimizer(ColossalaiOptimizer): Args: optim (Optimizer): An Optimizer instance. - module (ZeroDDP): A ``ZeroDDP`` instance. + module (GeminiDDP): A ``GeminiDDP`` instance. gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) which will be used when using hybrid CPU optimizer. This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". @@ -71,15 +72,15 @@ class ZeroOptimizer(ColossalaiOptimizer): growth_interval (float, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000. hysteresis (float, optional): Hysteresis used by DynamicGradScaler. Defaults to 2. max_scale (int, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32. - clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0. + max_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0. norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0) - is supported in ZeroOptimizer. Defaults to 2.0. + is supported in GeminiOptimizer. Defaults to 2.0. verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
""" def __init__(self, optim: Optimizer, - module: ZeroDDP, + module: GeminiDDP, gpu_margin_mem_ratio: float = 0.0, initial_scale: float = 2**32, min_scale: float = 1, @@ -88,12 +89,12 @@ def __init__(self, growth_interval: int = 1000, hysteresis: int = 2, max_scale: float = 2**32, - clipping_norm: float = 0.0, + max_norm: float = 0.0, norm_type: float = 2.0, verbose: bool = False, **defaults: Any): super().__init__(optim) - assert isinstance(module, ZeroDDP) + assert isinstance(module, GeminiDDP) assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \ f"{_AVAIL_OPTIM_LIST}" self.module = module @@ -102,8 +103,8 @@ def __init__(self, self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() self.param_to_chunk32: Dict[Parameter, Chunk] = dict() self.chunk16_set: Set[Chunk] = set() - self.clipping_flag = clipping_norm > 0.0 - self.max_norm = clipping_norm + self.clipping_flag = max_norm > 0.0 + self.max_norm = max_norm self.verbose = verbose self.param_groups_backup = list() @@ -112,7 +113,7 @@ def __init__(self, self.id_to_fake_params: Dict[int, Parameter] = dict() if self.clipping_flag: - assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" + assert norm_type == 2.0, "GeminiOptimizer only supports L2 norm now" ddp_param_list = [] for name, param in module.named_parameters(): @@ -741,8 +742,19 @@ def state_shard(self, yield current_block, current_block_size + def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: + raise NotImplementedError('Gemini does not support clip_grad_by_value') -class GeminiAdamOptimizer(ZeroOptimizer): + def clip_grad_by_norm(self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2, + error_if_nonfinite: bool = False, + *args, + **kwargs) -> torch.Tensor: + warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm') + + +class GeminiAdamOptimizer(GeminiOptimizer): def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: optimizer = HybridAdam(model.parameters(), **defaults) diff --git a/colossalai/zero/gemini/memory_tracer/memory_stats.py b/colossalai/zero/gemini/memory_tracer/memory_stats.py index 41d7e5754e96..02de6ecb97a9 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_stats.py +++ b/colossalai/zero/gemini/memory_tracer/memory_stats.py @@ -9,7 +9,7 @@ class MemStats(object): def __init__(self) -> None: """ - Store the non model data statistics used for Gemini and ZeroOptimizer. + Store the non model data statistics used for Gemini and GeminiOptimizer. """ # (preop_step, List[param]) self._step_param_dict = dict() diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index 6f4a253b504b..0d92d32e5603 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -64,13 +64,13 @@ def get_static_torch_model(zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True) -> torch.nn.Module: - """Get a static torch.nn.Module model from the given ZeroDDP module. - You should notice that the original ZeroDDP model is not modified. + """Get a static torch.nn.Module model from the given GeminiDDP module. + You should notice that the original GeminiDDP model is not modified. Thus, you can use the original model in further training. But you should not use the returned torch model to train, this can cause unexpected errors. 
Args: - zero_ddp_model (ZeroDDP): a zero ddp model + zero_ddp_model (GeminiDDP): a zero ddp model device (torch.device): the device of the final torch model dtype (torch.dtype): the dtype of the final torch model only_rank_0 (bool): if True, only rank0 has the converted torch model @@ -78,8 +78,8 @@ def get_static_torch_model(zero_ddp_model, Returns: torch.nn.Module: a static torch model used for saving checkpoints or numeric checks """ - from colossalai.zero.gemini.gemini_ddp import ZeroDDP - assert isinstance(zero_ddp_model, ZeroDDP) + from colossalai.zero.gemini.gemini_ddp import GeminiDDP + assert isinstance(zero_ddp_model, GeminiDDP) state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0) colo_model = zero_ddp_model.module diff --git a/colossalai/zero/wrapper.py b/colossalai/zero/wrapper.py index 3e48f49fa305..90325fe0a704 100644 --- a/colossalai/zero/wrapper.py +++ b/colossalai/zero/wrapper.py @@ -109,6 +109,6 @@ def zero_optim_wrapper(model: nn.Module, config_dict['clip_grad_norm'] = max_norm return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose) else: - from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer + from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer config_dict['clipping_norm'] = max_norm - return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose) + return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose) diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index 9fae4bef227a..53fa9f489c10 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -22,7 +22,7 @@ from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer +from colossalai.zero import GeminiOptimizer def main(): @@ -46,7 +46,7 @@ def main(): args.local_rank = -1 args.log_interval = 1 else: - colossalai.launch_from_torch(config={}) #args.colossal_config + colossalai.launch_from_torch(config={}) # args.colossal_config args.local_rank = int(os.environ["LOCAL_RANK"]) logger.info( f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + @@ -123,7 +123,8 @@ def main(): get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) # 144003367 is is the length of the entire dataset - steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) + # len(dataloader) + steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size total_steps = steps_per_epoch * args.epoch lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index a0600db1bc5b..01862a02608b 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -1,5 +1,4 @@ import gzip -import random from functools import partial from time import time @@ -8,20 +7,17 @@ import torch.nn as nn import torch.optim as optim import tqdm -from packaging import version - -from colossalai.nn import HybridAdam from palm_pytorch import PaLM from 
palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper from torch.utils.data import DataLoader, Dataset import colossalai -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec -from colossalai.utils import MultiTimer, get_current_device -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn import HybridAdam +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.zero import ColoInitContext # constants @@ -111,8 +107,6 @@ def get_model_size(model: nn.Module): return total_numel - - # Parameter Sharding Strategies for Tensor Parallelism def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) @@ -212,9 +206,9 @@ def __len__(self): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5) + plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': - plugin = LowLevelZeroPlugin(initial_scale=2 ** 5) + plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"plugin: {plugin}") booster = Booster(plugin=plugin, **booster_kwargs) diff --git a/examples/tutorial/opt/opt/requirements.txt b/examples/tutorial/opt/opt/requirements.txt index d0ed2c717aee..ae290080d13a 100644 --- a/examples/tutorial/opt/opt/requirements.txt +++ b/examples/tutorial/opt/opt/requirements.txt @@ -3,5 +3,5 @@ torch >= 1.8.1 datasets >= 1.8.0 sentencepiece != 0.1.92 protobuf -accelerate == 0.13.2 +accelerate transformers diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index fdc86adab665..9f2aa7e645f3 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -30,7 +30,7 @@ import datasets import torch import torch.distributed as dist -import transformers +import transformers.utils.logging as logging from accelerate.utils import set_seed from context import barrier_context from datasets import load_dataset @@ -57,7 +57,7 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.tensor import ProcessGroup from colossalai.utils import get_current_device, get_dataloader -from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero import GeminiOptimizer require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") @@ -292,10 +292,10 @@ def main(): if is_main_process: datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() + logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() + logging.set_verbosity_error() if args.mem_cap > 0: colo_memory_cap(args.mem_cap) @@ -391,16 +391,28 @@ def main(): else: init_dev = get_current_device() + cai_version = colossalai.__version__ + logger.info(f'using Colossal-AI version {cai_version}') # build model + if version.parse(cai_version) >= 
version.parse("0.3.1"): + from contextlib import nullcontext + + from colossalai.lazy import LazyInitContext + ctx = LazyInitContext( + default_device=init_dev + ) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext() + else: + from colossalai.zero import ColoInitContext + ctx = ColoInitContext(device=init_dev) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': # currently, there has a bug in pretrained opt-13b # we can not import it until huggingface fix it logger.info("Train a new model from scratch", ranks=[0]) - with ColoInitContext(device=init_dev): + with ctx: model = OPTForCausalLM(config) else: logger.info("Finetune a pre-trained model", ranks=[0]) - with ColoInitContext(device=init_dev): + with ctx: model = OPTForCausalLM.from_pretrained(args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, @@ -410,9 +422,13 @@ def main(): model.gradient_checkpointing_enable() PLACEMENT_POLICY = 'auto' - cai_version = colossalai.__version__ - logger.info(f'using Colossal-AI version {cai_version}') - if version.parse(cai_version) > version.parse("0.1.10"): + if version.parse(cai_version) >= version.parse("0.3.1"): + from colossalai.zero import GeminiDDP + model = GeminiDDP(model, + chunk_init_device=get_current_device(), + placement_policy=PLACEMENT_POLICY, + pin_memory=True) + elif version.parse(cai_version) > version.parse("0.1.10"): try: from colossalai.nn.parallel import GeminiDDP except ImportError: @@ -536,7 +552,6 @@ def group_texts(examples): ] optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate) - optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -551,6 +566,7 @@ def group_texts(examples): num_warmup_steps=args.num_warmup_steps, num_training_steps=args.max_train_steps, ) + optimizer = GeminiOptimizer(optimizer, model, initial_scale=2**14) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 7b664419b405..43fdcb21df2e 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -46,7 +46,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b dist.barrier() new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) - check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32), + check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False) @@ -87,12 +87,11 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_sha dist.barrier() booster.load_model(new_model, model_ckpt_path) - check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False), - new_model.unwrap().state_dict(only_rank_0=False), False) + check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), - new_optimizer.unwrap().state_dict(only_rank_0=False), False) + check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), + False) # Check the new model/optimizer can successfully run. data = data_gen_fn() diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index 464fccb39103..4569ea12d82d 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -60,12 +60,11 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): new_booster.load_model(new_model, model_ckpt_path, strict=True) # Add prefix to get aligned with pytorch parameter names. - check_state_dict_equal( - model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), - new_model.state_dict(), False) + check_state_dict_equal(model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), + new_model.state_dict(), False) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False) + check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False) # Check the new model/optimizer can successfully run. data = data_gen_fn() @@ -124,13 +123,12 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): new_booster.load_model(new_model, model_ckpt_path, strict=True) # Add prefix to get aligned with pytorch parameter names. 
- check_state_dict_equal( - new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), - model.state_dict(), False) + check_state_dict_equal(new_model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), + model.state_dict(), False) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) old_state_dict = optimizer.state_dict() - new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False) + new_state_dict = new_optimizer.state_dict(only_rank_0=False) # Comparison of param_groups needs special care here, # since not all hyperparameters in Adam are used by HybridAdam @@ -138,7 +136,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']): for k in hyperparameters_to_examine: assert k in old_group and k in new_group, \ - f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}" + f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}" assert old_group[k] == new_group[k] check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False) diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index d84a6e0fecbc..57d12c55b9b6 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -9,15 +9,14 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device -from colossalai.zero import ZeroDDP, ZeroOptimizer -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd, run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed -def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): +def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): chunk_manager = model.chunk_manager param_list = [p for p in model.parameters()] chunk_list = chunk_manager.get_chunks(param_list) @@ -54,11 +53,9 @@ def exam_gpt_fwd_bwd( config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = keep_gather - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = GeminiDDP(model, config_dict, placement_policy=placement_policy, pin_memory=True) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1) rank = dist.get_rank() amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) @@ -114,9 +111,7 @@ def exam_gpt_inference( config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = keep_gather - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = 
ZeroDDP(model, gemini_manager, pin_memory=True, scatter_after_inference=scatter_after_inference) + model = GeminiDDP(model, config_dict, pin_memory=True, scatter_after_inference=scatter_after_inference) rank = dist.get_rank() amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index b10be4753d20..a80a2f62de22 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -4,9 +4,8 @@ import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero import ZeroDDP -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP +from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs @@ -58,9 +57,11 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = keep_gather - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = GeminiDDP(model, + chunk_config_dict=config_dict, + placement_policy=placement_policy, + pin_memory=True, + memstats=memstats) set_seed(dist.get_rank()) for i, (input_ids, label) in enumerate(train_dataloader): @@ -74,7 +75,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ set_seed(42) loss = run_fwd_bwd(model, input_ids, label, criterion, model) - gemini_non_model_data = gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') # print('gemini non model data:', gemini_non_model_data) diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 621cafabf447..69058256ae47 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -9,15 +9,14 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device -from colossalai.zero import ZeroDDP, ZeroOptimizer -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed -def check_param(model: ZeroDDP, torch_model: torch.nn.Module): +def check_param(model: GeminiDDP, torch_model: torch.nn.Module): zero_dict = model.state_dict(only_rank_0=False) torch_dict = torch_model.state_dict() @@ -43,7 +42,6 @@ def exam_grad_clipping(placement_policy, model_name: str): 
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - init_dev = get_current_device() model = model_builder() for torch_p, p in zip(torch_model.parameters(), model.parameters()): @@ -57,12 +55,15 @@ def exam_grad_clipping(placement_policy, model_name: str): init_device = torch.device('cpu') else: init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + + model = GeminiDDP(model, + chunk_config_dict=config_dict, + chunk_init_device=init_device, + placement_policy=placement_policy, + pin_memory=True) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, max_norm=1.0) model.train() torch_model.train() diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 585f93b8b34f..74f51601cb23 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -11,15 +11,14 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device -from colossalai.zero import ZeroDDP, ZeroOptimizer, zero_model_wrapper -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed -def check_param(model: ZeroDDP, torch_model: torch.nn.Module): +def check_param(model: GeminiDDP, torch_model: torch.nn.Module): zero_dict = model.state_dict(only_rank_0=False) torch_dict = torch_model.state_dict() @@ -41,19 +40,12 @@ def multi_chunk_init(model: torch.nn.Module, placement_policy: str): init_device = torch.device('cpu') else: init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = GeminiDDP(model, config_dict, init_device, placement_policy=placement_policy, pin_memory=True) return model def single_chunk_init(model: torch.nn.Module, placement_policy: str): - gemini_config = dict( - device=get_current_device(), - placement_policy=placement_policy, - pin_memory=True, - ) - model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config) + model = GeminiDDP(model, chunk_init_device=get_current_device(), placement_policy=placement_policy, pin_memory=True) return model @@ -79,7 +71,7 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call model = model_init_func(model, placement_policy) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128) model.eval() torch_model.eval() diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index
df118a764a2d..5eb3e4e4ea66 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -9,9 +9,8 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device -from colossalai.zero import ZeroDDP, ZeroOptimizer -from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed @@ -29,7 +28,7 @@ ] -def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype): +def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype): zero_dict = model.state_dict(only_rank_0=False, dtype=dtype) torch_dict = torch_model.state_dict() @@ -78,12 +77,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt init_device = torch.device('cpu') else: init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision) + model = GeminiDDP(model, config_dict, init_device, placement_policy, mixed_precision=mixed_precision) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128) model.eval() torch_model.eval() @@ -126,11 +123,13 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch. 
for torch_p, p in zip(torch_model.parameters(), model.parameters()): p.data.copy_(torch_p.data) - chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision) + model = GeminiDDP(model, + chunk_init_device=get_current_device(), + search_range_m=1, + pin_memory=True, + mixed_precision=mixed_precision) optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) + zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2) model.eval() torch_model.eval() diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index fb30b0d84fcf..12c195efd6ed 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -4,9 +4,8 @@ import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero import ZeroDDP -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed @@ -14,7 +13,7 @@ def ignore_the_first_parameter(model: torch.nn.Module): for name, param in model.named_parameters(): print(f"parameter `{name}` is set ignored") - ZeroDDP.set_params_to_ignore([param]) + GeminiDDP.set_params_to_ignore([param]) return @@ -36,9 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str): config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['keep_gathered'] = keep_gathered - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = GeminiDDP(model, config_dict, placement_policy=placement_policy, pin_memory=True) model.train() zero_dict = model.state_dict(only_rank_0=False) @@ -72,9 +69,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): init_device = torch.device('cpu') else: init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = GeminiDDP(model, config_dict, init_device, placement_policy, pin_memory=True) torch_dict = torch_model.state_dict() model.load_state_dict(torch_dict, strict=False) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py index 0ea876e10849..c8ac8a8502c0 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py @@ -3,9 +3,8 @@ import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero import ZeroDDP -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP 
+from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs @@ -20,9 +19,7 @@ def exam_state_dict(placement_policy, model_name: str): model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager) + model = GeminiDDP(model, config_dict, placement_policy=placement_policy) model.train() zero_dict = model.state_dict(only_rank_0=False) diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index 2908538f94de..80e8821c1bf7 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -5,9 +5,8 @@ import colossalai from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero import ZeroDDP, ZeroOptimizer -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed @@ -33,12 +32,10 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered): init_device = torch.device('cpu') else: init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = GeminiDDP(model, config_dict, init_device, placement_policy, pin_memory=True) optimizer = HybridAdam(model.parameters()) - optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 + optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 set_seed(dist.get_rank() * 3 + 128) model.train()