Merge branch 'main' into feature/shardformer
ver217 authored Sep 5, 2023
2 parents bd18678 + ac178ca commit fae6c92
Showing 113 changed files with 627 additions and 631 deletions.
@@ -1,5 +1,5 @@
class Registry:
# TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
# TODO: refactor the registry classes used in colossalai.legacy.registry, colossalai.fx and here

def __init__(self, name):
self.name = name
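
Many of the hunks below repeat the same one-line change: the registry now lives under colossalai.legacy.registry. A representative before/after sketch (the registered class is hypothetical, for illustration only; it is not part of the commit):

# before this commit
# from colossalai.registry import DIST_GROUP_INITIALIZER

# after this commit
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

# registration itself is unchanged, e.g.
@DIST_GROUP_INITIALIZER.register_module
class MyCustomInitializer:    # hypothetical example class
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs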
129 changes: 83 additions & 46 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -3,6 +3,7 @@
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
@@ -25,9 +26,9 @@
sharded_optimizer_loading_epilogue,
unwrap_optimizer,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import LowLevelZeroOptimizer

from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']

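The new wrapper below relies on tree_map to cast floating-point inputs to the target dtype. A standalone sketch of that casting step (the body of _convert_floating_point is truncated by this hunk, so the version here is an assumption; tree_map is taken from torch.utils._pytree):

import torch
from torch.utils._pytree import tree_map

def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # assumed behaviour: cast only floating-point tensors, leave everything else untouched
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x

batch = {'input_ids': torch.ones(2, 8, dtype=torch.long), 'pixel_values': torch.randn(2, 3)}
half_batch = tree_map(lambda t: _convert_floating_point(t, torch.float16), batch)
print(half_batch['input_ids'].dtype, half_batch['pixel_values'].dtype)    # torch.int64 torch.float16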

class LowLevelZeroModel(ModelWrapper, AMPModelMixin):

def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
self.dtype = None
if precision == 'fp16':
self.dtype = torch.float16
elif precision == 'bf16':
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)

def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)

def unwrap(self):
# TODO(ver217): this is a workaround for loading model
return self

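A minimal usage sketch of the wrapper defined above (toy model and shapes are illustrative; assumes a CUDA device so that fp16 matmul is supported):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
wrapped = LowLevelZeroModel(model, precision='fp16')     # module cast to fp16 and moved to the current device
x = torch.randn(2, 4, device='cuda')                     # fp32 input on the same device
out = wrapped(x)                                         # args are cast to fp16 via tree_map before the forward call
print(out.dtype)                                         # torch.float16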

class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -165,30 +194,36 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s

sharded_optimizer_loading_epilogue(optimizer)


class LowLevelZeroModel(ModelWrapper):

def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
super().__init__(module)
self.dtype = None
if precision == 'fp16':
self.dtype = torch.float16
elif precision == 'bf16':
self.dtype = torch.bfloat16
module = zero_model_wrapper(module, zero_stage=stage)
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)

def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
use_safetensors: bool):
assert isinstance(model, LowLevelZeroModel)
super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)

def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
assert isinstance(model, LowLevelZeroModel)
super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
use_safetensors)

def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_unsharded_model(model.module, checkpoint, strict)
model.update_master_params()

def load_sharded_model(self,
model: LowLevelZeroModel,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()

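How these checkpoint-IO overrides surface at the booster level, as a rough end-to-end sketch (assumes a distributed launch via torchrun and the public Booster/LowLevelZeroPlugin APIs of this release; the toy model and file name are not part of the diff):

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

model = nn.Linear(8, 8)
optimizer = Adam(model.parameters())

booster = Booster(plugin=LowLevelZeroPlugin(stage=2, precision='fp16'))
model, optimizer, *_ = booster.boost(model, optimizer)

booster.save_model(model, 'model.pt')    # delegates to save_unsharded_model above and unwraps to model.module
booster.load_model(model, 'model.pt')    # after loading, update_master_params() re-syncs the optimizer's fp32 copies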

class LowLevelZeroPlugin(DPPluginBase):
@@ -248,22 +283,24 @@ def __init__(
super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'

assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
self.stage = stage
self.precision = precision
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload)
self.optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
self.zero_optim_kwargs = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
clip_grad_norm=max_norm,
reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(stage == 2),
)
self.verbose = verbose

# set class name with stage, for better error message
@@ -294,15 +331,15 @@ def configure(
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.stage, self.precision)
model = LowLevelZeroModel(model, self.precision)

if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = zero_optim_wrapper(model.unwrap(),
optimizer,
optim_config=self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
**self.zero_optim_kwargs,
verbose=self.verbose)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)

return model, optimizer, criterion, dataloader, lr_scheduler

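The MethodType call above grafts the optimizer's update_master_params onto the model wrapper so that checkpoint loading can trigger it. A standalone sketch of that binding pattern (toy classes, unrelated to ColossalAI):

from types import MethodType

class ToyOptimizer:
    def update_master_params(self, model):
        print(f'syncing master params of {type(model).__name__}')

class ToyWrapper:
    pass

opt, wrapped = ToyOptimizer(), ToyWrapper()
# Bind the already-bound optimizer method to the wrapper: calling
# wrapped.update_master_params() now runs opt.update_master_params(wrapped).
wrapped.update_master_params = MethodType(opt.update_master_params, wrapped)
wrapped.update_master_params()    # prints: syncing master params of ToyWrapper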
2 changes: 1 addition & 1 deletion colossalai/context/parallel_context.py
@@ -15,8 +15,8 @@
from colossalai.context.config import Config
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from colossalai.logging import get_dist_logger
from colossalai.registry import DIST_GROUP_INITIALIZER

from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
@@ -2,8 +2,9 @@
# -*- encoding: utf-8 -*-

import torch.distributed as dist

from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@@ -3,7 +3,7 @@
import torch.distributed as dist

from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@@ -4,9 +4,10 @@
import math

import torch.distributed as dist

from colossalai.context import Config
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@@ -6,7 +6,7 @@
import torch.distributed as dist

from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@@ -3,7 +3,7 @@

from torch import distributed as dist

from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@@ -2,9 +2,11 @@
# -*- encoding: utf-8 -*-

import torch.distributed as dist
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer

from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer


@DIST_GROUP_INITIALIZER.register_module
@@ -3,7 +3,7 @@

from torch import distributed as dist

from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
import torch.distributed as dist

from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
from .initializer_tensor import Initializer_Tensor
@@ -3,9 +3,10 @@

import torch.distributed as dist

from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer


@DIST_GROUP_INITIALIZER.register_module
8 changes: 4 additions & 4 deletions colossalai/initialize.py
@@ -17,13 +17,13 @@

from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.engine.gradient_accumulation import accumulate_gradient
from colossalai.engine.schedule import (
from colossalai.legacy.builder.builder import build_gradient_handler
from colossalai.legacy.engine import Engine
from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
from colossalai.legacy.engine.schedule import (
InterleavedPipelineSchedule,
NonPipelineSchedule,
PipelineSchedule,
4 changes: 2 additions & 2 deletions colossalai/interface/__init__.py
@@ -1,4 +1,4 @@
from .model import ModelWrapper
from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper

__all__ = ['OptimizerWrapper', 'ModelWrapper']
__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
11 changes: 11 additions & 0 deletions colossalai/interface/model.py
@@ -23,3 +23,14 @@ def unwrap(self):

def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)


class AMPModelMixin:
"""This mixin class defines the interface for AMP training.
"""

def update_master_params(self):
"""
Update the master parameters for AMP training.
"""
pass
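
A sketch of how a plugin-specific wrapper is expected to combine the two interfaces (hypothetical subclass; the real implementation is LowLevelZeroModel in the plugin above):

from colossalai.interface import AMPModelMixin, ModelWrapper

class MyPrecisionWrapper(ModelWrapper, AMPModelMixin):
    def update_master_params(self):
        # a concrete wrapper would copy the fp16/bf16 working weights back into the
        # optimizer's fp32 master copies here; the base mixin is a no-op by default
        pass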
Empty file added colossalai/legacy/__init__.py
File renamed without changes.
@@ -3,7 +3,7 @@

import inspect

from colossalai.registry import *
from colossalai.legacy.registry import *


def build_from_config(module, config: dict):
@@ -71,7 +71,7 @@ def build_gradient_handler(config, model, optimizer):
optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler
Returns:
An object of :class:`colossalai.engine.BaseGradientHandler`
An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
"""
config_ = config.copy()
config_['model'] = model
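
A hedged usage sketch of build_gradient_handler under its new legacy path (assumes the DataParallelGradientHandler that ships with the legacy engine and an already-initialised distributed context; the toy model and optimizer are illustrative):

import torch.nn as nn
from torch.optim import SGD

from colossalai.legacy.builder.builder import build_gradient_handler

model = nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=0.1)

handler_cfg = dict(type='DataParallelGradientHandler')
handler = build_gradient_handler(handler_cfg, model, optimizer)   # model/optimizer are injected into the config, then built from the registry
handler.handle_gradient()                                         # all-reduces any computed gradients across the data-parallel group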
File renamed without changes.
@@ -8,11 +8,17 @@
from torch.nn import Module
from torch.nn.modules.loss import _Loss

from colossalai.engine.gradient_handler import BaseGradientHandler
from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.schedule import (
BaseSchedule,
InterleavedPipelineSchedule,
NonPipelineSchedule,
PipelineSchedule,
)
from colossalai.logging import get_dist_logger
from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively


class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
@@ -4,7 +4,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from colossalai.engine import BaseGradientHandler
from colossalai.legacy.engine import BaseGradientHandler

from ._gradient_accumulation import (
GradAccumDataloader,
@@ -33,7 +33,7 @@ def accumulate_gradient(model: nn.Module,
dataloader (:class:`torch.utils.data.DataLoader` or iterable objects):
your dataloader object, would be called like iter(dataloader)
accumulate_size (int): the number of steps to accumulate gradients
gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
gradient_handlers (List[:class:`colossalai.legacy.engine.BaseGradientHandler`]):
list of gradient handler objects. Default is None.
lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
your ``lr_scheduler`` object for gradient accumulation. Defaults to None.