[gemini] refactor gemini optimizer and gemini ddp (#4398)
* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example
ver217 authored Aug 14, 2023
1 parent e3d732c commit cf98b30
Showing 22 changed files with 193 additions and 284 deletions.
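Taken together, the commits above rename ZeroOptimizer to GeminiOptimizer, fold the old ZeroDDP base class into GeminiDDP, and drop the thin GeminiOptimizer wrapper that previously lived in the Gemini plugin. The sketch below illustrates roughly what plugin-free usage looks like after this refactor. It is a minimal sketch, not code from this PR: it assumes a single-GPU run launched with torchrun, the colossalai.launch_from_torch entry point, and the HybridAdam optimizer from colossalai.nn.optimizer.

import torch
import torch.nn as nn

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiDDP, GeminiOptimizer

# Assumed launch: torchrun --nproc_per_node=1 train.py
colossalai.launch_from_torch(config={})

model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 10)).cuda()

# GeminiDDP is now constructed directly from the module; the ChunkManager and
# GeminiManager wiring happens inside its __init__ (see gemini_ddp.py below).
model = GeminiDDP(model,
                  chunk_init_device=torch.device('cuda'),
                  placement_policy='cpu',
                  pin_memory=True)

# GeminiOptimizer takes (optimizer, module), matching the updated plugin code.
optimizer = GeminiOptimizer(HybridAdam(model.parameters(), lr=1e-3), model)

criterion = nn.CrossEntropyLoss()
data = torch.randn(8, 1024, device='cuda')
labels = torch.randint(0, 10, (8,), device='cuda')

loss = criterion(model(data), labels)
optimizer.backward(loss)  # Gemini drives backward through the optimizer wrapper
optimizer.step()
optimizer.zero_grad()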
63 changes: 11 additions & 52 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -1,13 +1,11 @@
import gc
import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
@@ -16,16 +14,14 @@
from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
load_shard_state_dict,
save_state_dict,
save_state_dict_shards,
)
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini import ZeroOptimizer
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats

from .dp_plugin_base import DPPluginBase
@@ -132,11 +128,7 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
As there is communication when getting state dict, this must be called on all processes.
"""

# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.unwrap()

assert isinstance(optimizer, ZeroOptimizer)
assert isinstance(optimizer, GeminiOptimizer)

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -183,11 +175,7 @@ def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Pa
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")

# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.unwrap()

assert isinstance(optimizer, ZeroOptimizer)
assert isinstance(optimizer, GeminiOptimizer)

# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
@@ -220,36 +208,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
super().save_lr_scheduler(lr_scheduler, checkpoint)


class GeminiOptimizer(OptimizerWrapper):

def __init__(self,
module: GeminiDDP,
optimizer: Optimizer,
zero_optim_config: dict,
optim_kwargs: dict,
verbose: bool = False) -> None:
optimizer = zero_optim_wrapper(module,
optimizer,
optim_config=zero_optim_config,
**optim_kwargs,
verbose=verbose)
super().__init__(optimizer)

def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)

def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')

def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')


class GeminiPlugin(DPPluginBase):
"""
Plugin for Gemini.
@@ -299,7 +257,7 @@ class GeminiPlugin(DPPluginBase):

def __init__(
self,
device: Optional[torch.device] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "cpu",
precision: str = "fp16",
pin_memory: bool = False,
@@ -324,7 +282,7 @@ def __init__(
super().__init__()
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
self.gemini_config = dict(
device=(device or get_current_device()),
chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
@@ -383,13 +341,14 @@ def configure(

# wrap the model with Gemini
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
# TODO(ver217): remove this line
model._colo_zero_stage = 3

if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)
optimizer = GeminiOptimizer(optimizer,
model.unwrap(),
**self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)

return model, optimizer, criterion, dataloader, lr_scheduler

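With the wrapper class removed, configure() now instantiates GeminiOptimizer directly, and the plugin's device argument has been renamed to chunk_init_device. The following is a hedged sketch of driving the plugin through the Booster API after this change; the toy model, data, and launch call are placeholders rather than code from the PR, and assume the Booster and HybridAdam APIs of this release.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 10).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# `chunk_init_device` replaces the old `device` keyword of GeminiPlugin.
plugin = GeminiPlugin(chunk_init_device=torch.device('cuda'),
                      placement_policy='cpu',
                      precision='fp16')
booster = Booster(plugin=plugin)

# boost() hands back the GeminiDDP-wrapped model and a GeminiOptimizer.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

data = torch.randn(8, 1024, device='cuda')
labels = torch.randint(0, 10, (8,), device='cuda')
loss = criterion(model(data), labels)
booster.backward(loss, optimizer)
optimizer.step()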
5 changes: 2 additions & 3 deletions colossalai/zero/__init__.py
@@ -2,15 +2,14 @@
ColoInitContext,
GeminiAdamOptimizer,
GeminiDDP,
ZeroDDP,
ZeroOptimizer,
GeminiOptimizer,
get_static_torch_model,
post_process_colo_init_ctx,
)
from .low_level import LowLevelZeroOptimizer
from .wrapper import zero_model_wrapper, zero_optim_wrapper

__all__ = [
'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
]
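Since ZeroDDP and ZeroOptimizer disappear from the package's __all__, downstream code has to switch to the new names. A likely migration, written as a guess rather than a verified excerpt from any particular caller:

# Before this commit:
# from colossalai.zero import ZeroDDP, ZeroOptimizer

# After this commit:
from colossalai.zero import GeminiDDP, GeminiOptimizer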
8 changes: 4 additions & 4 deletions colossalai/zero/gemini/__init__.py
@@ -1,11 +1,11 @@
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .gemini_ddp import GeminiDDP, ZeroDDP
from .gemini_ddp import GeminiDDP
from .gemini_mgr import GeminiManager
from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
from .utils import get_static_torch_model

__all__ = [
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP',
'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
]
120 changes: 37 additions & 83 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -31,14 +31,13 @@
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'

__all__ = [
'ZeroDDP',
'GeminiDDP',
]


class ZeroDDP(ModelWrapper):
class GeminiDDP(ModelWrapper):
"""ZeRO DDP.
Warning: Nested ZeroDDP is not supported now.
Warning: Nested GeminiDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
@@ -55,20 +54,42 @@ class ZeroDDP(ModelWrapper):
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
"""

def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None) -> None:
def __init__(
self,
module: torch.nn.Module,
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device('cpu'),
placement_policy: str = "cpu",
search_range_m: int = 32, # chunk search options
hidden_dim: Optional[int] = None, # chunk search options
min_chunk_size_m: float = 32, # chunk search options
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None,  # gemini memory stats
verbose: bool = False) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
else:
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
search_range_m = 32
self.chunk_manager = init_chunk_manager(model=module,
init_device=chunk_init_device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
verbose=verbose)
self.gemini_manager = GeminiManager(placement_policy, self.chunk_manager, memstats)

self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
@@ -257,7 +278,7 @@ def _post_backward(self):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
"The most possible reason is that the model is not compatible with GeminiDDP.\n",
f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
@@ -772,70 +793,3 @@ def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict]
self.current_block[name] = tensor
self.current_block_size += tensor_size
return ret_block, ret_block_size


class GeminiDDP(ZeroDDP):

def __init__(self,
module: torch.nn.Module,
device: torch.device,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
search_range_m: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_m: float = 32,
memstats: Optional[MemStats] = None,
mixed_precision: torch.dtype = torch.float16,
process_group: Optional[ProcessGroup] = None,
verbose: bool = False) -> None:
"""
A torch.Module wrapper using ZeRO-DP and Gemini.
ZeRO is for parallel. Gemini is for memory management.
WARNING: The class will modify the module inline!
Example:
model is initialized under the context of ColoInitContext
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): the model to be wrapped.
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
If the aggregate size of parameters is still smaller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
"""
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
search_range_m = 32

chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
verbose=verbose)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module,
gemini_manager,
pin_memory,
force_outputs_fp32,
strict_ddp_mode,
scatter_after_inference,
mixed_precision=mixed_precision,
process_group=process_group)
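The removed subclass above shows the old two-step construction: build a ChunkManager and GeminiManager, then hand them to ZeroDDP. In the merged class that wiring lives inside __init__: pass chunk_config_dict to supply an explicit chunk configuration, or leave it as None and let init_chunk_manager search one, steered by hidden_dim, search_range_m, and min_chunk_size_m (both _m values are sizes divided by 2**20, per the old docstring). A minimal sketch of the two paths, assuming an initialized process group and purely illustrative values:

import torch
import torch.nn as nn

from colossalai.zero import GeminiDDP

model = nn.Transformer(d_model=512, num_encoder_layers=2, num_decoder_layers=2).cuda()

# Path 1: no chunk_config_dict -> GeminiDDP runs the chunk search itself.
gemini_model = GeminiDDP(model,
                         chunk_init_device=torch.device('cuda'),
                         placement_policy='cpu',
                         hidden_dim=512,       # optional hint that narrows the search
                         search_range_m=32,    # search range / 2**20
                         min_chunk_size_m=32)  # minimum chunk size / 2**20

# Path 2: supply a precomputed configuration and skip the search. The expected
# structure of `chunk_config_dict` is whatever ChunkManager accepts (e.g. the
# result of an earlier search); it is not spelled out in this diff.
# gemini_model = GeminiDDP(model,
#                          chunk_config_dict=saved_chunk_config,
#                          chunk_init_device=torch.device('cuda'))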