From cc3cbe9f6f291af172252f097952bfe247200195 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 4 Jul 2023 18:11:46 +0800
Subject: [PATCH 01/64] [workflow] show test duration (#4159)
---
.github/workflows/build_on_pr.yml | 2 +-
.github/workflows/build_on_schedule.yml | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index 5f4e4feaa230..380c8e9f882c 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -208,7 +208,7 @@ jobs:
- name: Execute Unit Testing
run: |
- CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. tests/
+ CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. --durations=10 tests/
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml
index 0589cd617b80..03b47e6cb5b6 100644
--- a/.github/workflows/build_on_schedule.yml
+++ b/.github/workflows/build_on_schedule.yml
@@ -3,7 +3,7 @@ name: Build on Schedule
on:
schedule:
# run at 00:00 of every Sunday
- - cron: '0 0 * * *'
+ - cron: "0 0 * * *"
workflow_dispatch:
jobs:
@@ -60,7 +60,7 @@ jobs:
- name: Unit Testing
if: steps.check-avai.outputs.avai == 'true'
run: |
- PYTHONPATH=$PWD pytest tests
+ PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
From 190a6ea9c2d1c318779c68786e342daced2f8ac8 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 4 Jul 2023 18:21:11 +0800
Subject: [PATCH 02/64] [dtensor] fixed readme file name and removed deprecated
file (#4162)
---
.../tensor/d_tensor/{RAEDME.md => README.md} | 0
colossalai/tensor/d_tensor/d_tensor.py | 142 ------------------
2 files changed, 142 deletions(-)
rename colossalai/tensor/d_tensor/{RAEDME.md => README.md} (100%)
delete mode 100644 colossalai/tensor/d_tensor/d_tensor.py
diff --git a/colossalai/tensor/d_tensor/RAEDME.md b/colossalai/tensor/d_tensor/README.md
similarity index 100%
rename from colossalai/tensor/d_tensor/RAEDME.md
rename to colossalai/tensor/d_tensor/README.md
diff --git a/colossalai/tensor/d_tensor/d_tensor.py b/colossalai/tensor/d_tensor/d_tensor.py
deleted file mode 100644
index c1fe9d50a048..000000000000
--- a/colossalai/tensor/d_tensor/d_tensor.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from typing import Optional
-
-import torch
-from torch.utils._pytree import tree_map
-
-from .layout import Layout
-from .layout_converter import LayoutConverter, to_global
-from .sharding_spec import ShardingSpec
-
-layout_converter = LayoutConverter()
-
-
-class DTensor(torch.Tensor):
-
- def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
- self.local_tensor = local_tensor
- self.data_type = local_tensor.dtype
- self.entire_shape = local_tensor.shape
- self.dist_layout = dist_layout
- self._apply_layout()
-
- @staticmethod
- def __new__(cls, local_tensor, layout):
- return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
-
- def __repr__(self):
- return f"DTensor({self.to_global()}, {self.dist_layout})"
-
- def __str__(self):
- return self.__repr__()
-
- def layout_convert(self, target_layout):
- '''
- Convert the layout of the tensor from source_spec to target_spec.
- '''
- self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
- self.dist_layout = target_layout
-
- def _apply_layout(self):
- '''
- Apply the layout to the local tensor during initializing process.
- '''
- source_spec = construct_default_sharding_spec(self.local_tensor)
- source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
- device_type=self.dist_layout.device_type,
- sharding_spec=source_spec,
- entire_shape=self.entire_shape)
- self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
-
- @classmethod
- def __torch_function__(cls, func, types, args=(), kwargs=None):
- if kwargs is None:
- kwargs = {}
-
- def filter_arg(arg):
- if isinstance(arg, DTensor):
- return arg.local_tensor
- else:
- return arg
-
- args = tree_map(filter_arg, args)
- kwargs = tree_map(filter_arg, kwargs)
- # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
- # and op type.
-
- return func(*args, **kwargs)
-
- @property
- def device_mesh(self):
- '''
- Return the device mesh of the tensor.
- '''
- return self.dist_layout.device_mesh
-
- @property
- def sharding_spec(self):
- '''
- Return the sharding specification of the tensor.
- '''
- return self.dist_layout.sharding_spec
-
- def to(self, *args, **kwargs):
- '''
- Move the tensor to a new device or convert the tensor to a new dtype.
- '''
- self.local_tensor = self.local_tensor.to(*args, **kwargs)
- self.data_type = self.local_tensor.dtype
- self.dist_layout.device_type = self.local_tensor.device
- # TODO: update the device mesh process groups or we should just cache
- # both the cpu process groups and the cuda process groups?
- return self
-
- def to_local(self):
- '''
- Return the local tensor in this rank.
- '''
- return self.local_tensor
-
- def to_global(self):
- '''
- Recover the global tensor from the distributed tensor.
-
- Note: This function will all_gather the local tensor to the global tensor and it
- will not change the layout of the DTensor. This function is mainly used for debugging or
- check the correctness of the distributed tensor.
- '''
- return to_global(self.local_tensor, self.dist_layout)
-
-
-def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
- '''
- Distribute the local tensor to the distributed tensor according to the dist_layout specified.
-
- Args:
- local_tensor: tensor to be distributed.
- dist_layout: the layout specification of the distributed tensor.
-
- Returns:
- A 'DTensor' object.
- '''
- return DTensor(local_tensor, dist_layout)
-
-
-def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
- '''
- This function converts all the parameters in the module to DTensor(DParam).
-
- Note: This function is subject to future change as the DParam has not been implemented yet.
- '''
- for name, param in module.named_parameters():
- if param is not None and not isinstance(param, DTensor):
- # TODO: we could convert the parameter to DParam here,
- # the type of the parameter could be an optional argument.
- setattr(module, name, torch.nn.Parameter(partition_fn(name, param.data)))
- return module
-
-
-def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
- '''
- Construct the default sharding specification for the tensor.
- '''
- return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
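For context on the class removed above: a standalone sketch (the `LocalView` name and the script body are illustrative, not ColossalAI API) of the `__torch_function__` pattern it relied on, where `tree_map` strips the wrapper from every argument before the underlying torch op runs.

```python
from typing import Any

import torch
from torch.utils._pytree import tree_map


class LocalView(torch.Tensor):
    """Toy tensor subclass that forwards torch ops to an inner local tensor."""

    @staticmethod
    def __new__(cls, local_tensor: torch.Tensor):
        return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)

    def __init__(self, local_tensor: torch.Tensor):
        self.local_tensor = local_tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(arg: Any) -> Any:
            # Replace every LocalView argument with the tensor it wraps.
            return arg.local_tensor if isinstance(arg, LocalView) else arg

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))


if __name__ == "__main__":
    x = LocalView(torch.ones(2, 2))
    print(torch.add(x, 1))  # dispatched through __torch_function__, runs on the plain tensor
```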
From fee32a3b785327d63ff9cdaea4451b4cfe071d2a Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 7 Jul 2023 15:31:51 +0800
Subject: [PATCH 03/64] [docker] added ssh and rdma support for docker (#4192)
---
docker/Dockerfile | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 2c7bafd9604c..97399c939376 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -5,6 +5,18 @@ LABEL org.opencontainers.image.source = "https://github.com/hpcaitech/ColossalAI
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
LABEL org.opencontainers.image.base.name = "docker.io/library/hpcaitech/cuda-conda:11.3"
+# enable passwordless ssh
+RUN mkdir ~/.ssh && \
+ printf "Host * \n ForwardAgent yes\nHost *\n StrictHostKeyChecking no" > ~/.ssh/config && \
+ ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa && \
+ cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys
+
+# enable RDMA support
+RUN apt-get update && \
+ apt-get install -y infiniband-diags perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
# install torch
RUN conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
From 58913441a1bd5df3848a4766e2f75a8ae0942121 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 7 Jul 2023 16:33:06 +0800
Subject: [PATCH 04/64] Next commit [checkpointio] Unsharded Optimizer
Checkpoint for Gemini Plugin (#4141)
* [checkpointio] unsharded optimizer checkpoint for Gemini plugin
* [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather
---
colossalai/booster/plugin/gemini_plugin.py | 75 ++--
.../checkpoint_io/checkpoint_io_base.py | 2 +
.../checkpoint_io/general_checkpoint_io.py | 14 +-
colossalai/checkpoint_io/utils.py | 24 +-
colossalai/interface/optimizer.py | 6 +
colossalai/testing/comparison.py | 64 +++-
colossalai/zero/gemini/gemini_optimizer.py | 340 +++++++++++++++++-
.../test_gemini_checkpoint_io.py | 69 ++--
.../test_gemini_torch_compability.py | 171 +++++++++
9 files changed, 683 insertions(+), 82 deletions(-)
create mode 100644 tests/test_checkpoint_io/test_gemini_torch_compability.py
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 1173589fcd49..6191f271c318 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -33,44 +33,40 @@ def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
- def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
- """
- Load model from checkpoint with automatic unwrapping.
- """
- # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
- return super().load_unsharded_model(model, checkpoint, strict=strict)
-
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
- Save model to checkpoint but only on master process.
+ Save sharded model to checkpoint but only on master process.
+ The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
+ As there is communication when getting state dict, this must be called on all processes.
"""
- # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
- # as there is communication when get state dict, this must be called on all processes
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
- def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
- Save optimizer to checkpoint but only on master process.
+ Load model from checkpoint with automatic unwrapping.
+ The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
- # TODO(ver217): optimizer state dict is sharded
- warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
- checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
- super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
-
- def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
- warnings.warn(
- 'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
- checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
- super().load_optimizer(optimizer, checkpoint)
-
- def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ super().load_unsharded_model(model, checkpoint, strict=strict)
+
+ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
- Save model to checkpoint but only on master process.
+ Save unsharded optimizer state dict to checkpoint.
+ After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
+ As there is communication when getting state dict, this must be called on all processes.
+ The saving process will only be executed by master rank.
"""
+ state_dict = optimizer.state_dict()
if self.coordinator.is_master():
- super().save_lr_scheduler(lr_scheduler, checkpoint)
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+ """
+ Loading unsharded optimizer from checkpoint file.
+ For each process, only loading optimizer states of parameters it controls.
+ """
+ super().load_unsharded_optimizer(optimizer, checkpoint)
def save_sharded_model(self,
model: GeminiDDP,
@@ -82,6 +78,12 @@ def save_sharded_model(self,
"""
Save sharded model
"""
+ if os.path.isfile(checkpoint_path):
+ logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+ return
+
+ Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
@@ -117,6 +119,23 @@ def load_sharded_model(self,
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
+ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
+ size_per_shard: int):
+ """
+ Save sharded optimizer state dict to checkpoint folder.
+ As there is communication when getting state dict, this must be called on all processes.
+ """
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+ super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+
+ def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
+ """
+ Loading sharded optimizer from checkpoint folder, with index file given.
+ For each process, only loading optimizer states of parameters it controls.
+ """
+ # TODO(Baizhou): To be implemented.
+ pass
+
class GeminiModel(ModelWrapper):
@@ -193,7 +212,7 @@ class GeminiPlugin(DPPluginBase):
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
Defaults to 0.0.
- initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
+ initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
@@ -219,7 +238,7 @@ def __init__(
min_chunk_size_m: float = 32,
memstats: Optional[MemStats] = None,
gpu_margin_mem_ratio: float = 0.0,
- initial_scale: float = 2**32,
+ initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
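The reshuffled checkpoint methods above all follow one rule: every rank must call `state_dict()` because collecting the full state may involve communication, while only the master rank writes to disk. A generic sketch of that rule, assuming an already-initialized `torch.distributed` process group (the helper name `save_unsharded` is made up for illustration):

```python
import torch
import torch.distributed as dist


def save_unsharded(optimizer: torch.optim.Optimizer, path: str) -> None:
    # state_dict() must be called on every rank, because gathering the full
    # state may involve collective communication across the process group.
    state_dict = optimizer.state_dict()
    if dist.get_rank() == 0:
        # Only the master rank touches the filesystem.
        torch.save(state_dict, path)
    # Keep ranks in sync so nobody tries to load a half-written file.
    dist.barrier()
```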
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index 8ff9d87c288e..baff24e1cb25 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -152,6 +152,7 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
+
index_file_exists, index_file_path = has_index_file(checkpoint)
if Path(checkpoint).is_dir() and not index_file_exists:
@@ -186,6 +187,7 @@ def save_optimizer(self,
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
+
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
else:
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 26cafcada2c5..e1d9066948dd 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -28,6 +28,7 @@
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
+ unwrap_optimizer,
)
__all__ = ['GeneralCheckpointIO']
@@ -59,7 +60,7 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
- optimizer = optimizer.optim
+ optimizer = unwrap_optimizer(optimizer)
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@@ -96,6 +97,11 @@ def save_sharded_optimizer(
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
"""
+
+ # If optimizer is wrapped, unwrap it.
+ if isinstance(optimizer, OptimizerWrapper):
+ optimizer = unwrap_optimizer(optimizer)
+
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -121,9 +127,8 @@ def save_sharded_optimizer(
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
- for param_id in shard.keys():
- index_file.append_weight_map(str(param_id), shard_file)
-
+ for key in shard.keys():
+ index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
@@ -177,7 +182,6 @@ def save_sharded_model(self,
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
-
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 485577b9650c..19e28c3f7068 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -10,6 +10,8 @@
import torch.nn as nn
from torch.optim import Optimizer
+from colossalai.interface import OptimizerWrapper
+from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor.d_tensor import is_distributed_tensor
SAFE_WEIGHTS_NAME = "model.safetensors"
@@ -88,6 +90,19 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# ======================================
# Helper functions for saving shard file
# ======================================
+def unwrap_optimizer(optimizer: OptimizerWrapper):
+ '''
+ Unwrap a wrapped optimizer.
+ This method should be used before saving/loading it to/from sharded checkpoints.
+ '''
+
+ # TODO(Baizhou): ColossalaiOptimizer will be replaced with OptimizerWrapper in the future
+ unwrapped_optim = optimizer.optim
+ if isinstance(unwrapped_optim, ColossalaiOptimizer):
+ unwrapped_optim = unwrapped_optim.optim
+ return unwrapped_optim
+
+
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -103,7 +118,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
weight_size = calculate_tensor_size(weight)
# If this weight is going to tip up over the maximal size, we split.
- if current_block_size + weight_size > max_shard_size:
+ if current_block_size + weight_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
@@ -140,9 +155,10 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
isDTensor = False
for state_tensor in state.values():
- # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
+ # When state_tensor is not of Tensor class,
+ # e.g., a SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
- if state_tensor is None:
+ if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
@@ -152,7 +168,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
if not isDTensor:
- if current_block_size + state_size > max_shard_size:
+ if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
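The `and current_block_size > 0` guards added above change when a new shard is opened. A small self-contained sketch of the rule, operating on plain integer sizes instead of state dicts purely for illustration:

```python
def split_into_shards(sizes, max_shard_size):
    """Group sizes into shards, starting a new shard only when the current one is non-empty."""
    shards, current, current_size = [], [], 0
    for size in sizes:
        if current_size + size > max_shard_size and current_size > 0:
            shards.append(current)
            current, current_size = [], 0
        current.append(size)
        current_size += size
    if current:
        shards.append(current)
    return shards


# Without the `current_size > 0` guard, a first weight larger than the limit
# would flush an empty leading shard; with it, the oversized weight simply
# occupies its own shard.
print(split_into_shards([900, 100], max_shard_size=512))  # [[900], [100]]
```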
diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index dd9acab17584..0eaf2e1ef8ba 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -119,3 +119,9 @@ def unscale_grad(self):
"""
raise NotImplementedError(
"The method unscale_grad is only available for optimizers with mixed precision training")
+
+ def unwrap(self):
+ """
+ Unwrap the optimizer for checkpoint saving/loading.
+ """
+ return self.optim
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index 5cbfb936b144..8d9ec8ab5f35 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -5,6 +5,7 @@
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.testing import assert_close
+from torch.utils._pytree import tree_flatten
def assert_equal(a: Tensor, b: Tensor):
@@ -16,7 +17,12 @@ def assert_not_equal(a: Tensor, b: Tensor):
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
- assert_close(a, b, rtol=rtol, atol=atol)
+ assert_close(a,
+ b,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
+ dtype: {a.dtype} vs {b.dtype}")
def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
@@ -33,25 +39,51 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
- for k, v in d1.items():
- if isinstance(v, dict):
- check_state_dict_equal(v, d2[k])
- elif isinstance(v, list):
- for i in range(len(v)):
- if isinstance(v[i], torch.Tensor):
+ assert len(list(d1.keys())) == len(list(d2.keys())), \
+ f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
+ for k, v1 in d1.items():
+ assert k in d2
+ v2 = d2[k]
+ if isinstance(v1, dict):
+ assert isinstance(v2, dict)
+ check_state_dict_equal(v1, v2, ignore_device)
+ elif isinstance(v1, list):
+ assert isinstance(v2, list)
+ for v1_i, v2_i in zip(v1, v2):
+ if isinstance(v1_i, torch.Tensor):
+ assert isinstance(v2_i, torch.Tensor)
if not ignore_device:
- v[i] = v[i].to("cpu")
- d2[k][i] = d2[k][i].to("cpu")
- assert torch.equal(v[i], d2[k][i])
+ v1_i = v1_i.to("cpu")
+ v2_i = v2_i.to("cpu")
+ assert_close_loose(v1_i, v2_i)
+ elif isinstance(v1_i, dict):
+ assert isinstance(v2_i, dict)
+ check_state_dict_equal(v1_i, v2_i, ignore_device)
else:
- assert v[i] == d2[k][i]
- elif isinstance(v, torch.Tensor):
+ assert v1_i == v2_i, f"{v1_i} not equals to {v2_i}"
+ elif isinstance(v1, torch.Tensor):
+ assert isinstance(v2, torch.Tensor)
if not ignore_device:
- v = v.to("cpu")
- d2[k] = d2[k].to("cpu")
- assert torch.equal(v, d2[k])
+ v1 = v1.to("cpu")
+ v2 = v2.to("cpu")
+ assert_close_loose(v1, v2)
else:
- assert v == d2[k]
+ assert v1 == v2, f"{v1} not equals to {v2}"
+
+
+def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
+ flat_d1, _ = tree_flatten(d1)
+ flat_d2, _ = tree_flatten(d2)
+ assert len(flat_d1) == len(flat_d2)
+ for v1, v2 in zip(flat_d1, flat_d2):
+ if isinstance(v1, torch.Tensor):
+ assert isinstance(v2, torch.Tensor)
+ if not ignore_device:
+ v1 = v1.to("cpu")
+ v2 = v2.to("cpu")
+ assert_close_loose(v1, v2)
+ else:
+ assert v1 == v2, f"{v1} not equals to {v2}"
def assert_hf_output_close(out1: Any,
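The new `check_state_dict_equal_pytree` relies on `tree_flatten` to turn nested state dicts into flat leaf lists that can be zipped and compared pairwise. A short usage sketch with toy data:

```python
import torch
from torch.utils._pytree import tree_flatten

# Toy nested optimizer-style state dict (made-up data, for illustration only).
state_dict = {"state": {0: {"exp_avg": torch.zeros(2)}}, "param_groups": [{"lr": 1e-3}]}

leaves, spec = tree_flatten(state_dict)
# `leaves` holds the two leaf values (the exp_avg tensor and the lr float), and
# `spec` records the nesting, so two dicts with the same layout flatten to leaf
# lists that can be compared element by element.
print(leaves)
```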
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 267deb1e8699..99aff6f1c527 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -1,4 +1,6 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
+import copy
+import gc
import math
import warnings
from typing import Any, Dict, Set, Tuple
@@ -101,6 +103,11 @@ def __init__(self,
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm
self.verbose = verbose
+ self.param_groups_backup = list()
+
+ # Mapping from integer id to real/fake param tensor, used for checkpointing.
+ self.id_to_real_params: Dict[int, Parameter] = dict()
+ self.id_to_fake_params: Dict[int, Parameter] = dict()
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
@@ -301,25 +308,352 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter):
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end
+ param_id = -1
for group in self.optim.param_groups:
fake_params_list = list()
-
+ group_backup = {k: v for k, v in group.items() if k != 'params'}
+ group_ids = []
for param in group['params']:
+
+ # Record the mapping of id to current param.
+ param_id += 1
+ self.id_to_real_params[param_id] = param
+ group_ids.append(param_id)
+
+ # If current param is controlled by current process, add it to fake_param.
if is_ddp_ignored(param):
continue
chunk16 = self.chunk_manager.get_chunk(param)
range_pair = get_range_pair(chunk16, param)
if range_pair[0] >= range_pair[1]:
continue
-
grad_device = self.module.grads_device[param]
fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
self.param_to_range[fake_param] = range_pair
-
+ self.id_to_fake_params[param_id] = fake_param
fake_params_list.append(fake_param)
+ # Update self.optim.param_groups as well as backup group.
group['params'] = fake_params_list
+ group_backup['params'] = group_ids
+ self.param_groups_backup.append(group_backup)
+
+ def get_offsets(self, param_id: int) -> tuple:
+ '''
+ Args:
+ param_id(int): The id of parameter.
+
+ Returns:
+ chunk_offset(int): Offset of parameter inside the chunk.
+ shard_offset(int): Offset of its optimizer state shard
+ relative to the whole optimizer state.
+ shard_size(int): Length of parameter shard owned by current process.
+ '''
+
+ if param_id not in self.id_to_fake_params:
+ return -1, -1, -1
+ fake_param = self.id_to_fake_params[param_id]
+ chunk = self.param_to_chunk32[fake_param].paired_chunk
+ param = self.id_to_real_params[param_id]
+ param_info = chunk.tensors_info[param]
+
+ begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
+ chunk_offset = begin_in_chunk
+ shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
+ shard_size = end_in_chunk - begin_in_chunk
+ assert chunk_offset >= 0 and shard_offset >= 0
+
+ return chunk_offset, shard_offset, shard_size
+
+ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
+ """
+ Args:
+ param_id (int): id of the parameter whose state is to be gathered at master rank.
+ only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank.
+
+ Returns:
+ collected_states(dict): the gathered optimizer state of parameter with given id
+ if this method is called by master rank, otherwise an empty dict.
+
+ This method can work only when called by all processes simultaneously.
+ """
+
+ # Get param & chunk & process group.
+ param = self.id_to_real_params[param_id]
+ fake_param = self.id_to_fake_params.get(param_id, None)
+ chunk = self.chunk_manager.get_chunk(param)
+ process_group = chunk.torch_pg
+ rank = dist.get_rank(process_group)
+ master_rank = 0
+ collected_states = {}
+
+ # Fetch names of states through all_gather.
+ local_state_names = None
+ if fake_param is not None:
+ local_state_names = list(self.optim.state[fake_param].keys())
+ gathered_state_names = [None for _ in range(dist.get_world_size(process_group))]
+ dist.barrier()
+ dist.all_gather_object(gathered_state_names, local_state_names)
+ state_names = None
+ for names in gathered_state_names:
+ if names is not None:
+ # Assume different devices share the same set of state names if they have.
+ state_names = copy.deepcopy(names)
+ break
+
+ # Directly return if this parameter doesn't have optimizer states.
+ # e.g. parameter frozen / layer dropped
+ if state_names is None:
+ return collected_states
+
+ # Boolean variable is_collector indicates that whether the current rank
+ # needs to gather the whole optimizer states.
+ # Only master rank is collector when only_rank_0 is True.
+ # Every rank is collector when only_rank_0 is False.
+ is_collector = (rank == master_rank) or (not only_rank_0)
+
+ # If the chunk is kept gathered,
+ # the parameters are treated the same as those in strict DDP during training.
+ # So states can be directly fetched from current device.
+ if chunk.keep_gathered:
+ assert param_id in self.id_to_fake_params
+ if is_collector:
+ states = self.optim.state[fake_param]
+ for state_name in state_names:
+ if state_name == 'step':
+ # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
+ collected_states[state_name] = torch.tensor(states['step'],
+ dtype=torch.float32,
+ requires_grad=False).cpu()
+ else:
+ collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
+ return collected_states
+
+ # Check whether the param with given id is managed by current process.
+ own_param = param_id in self.id_to_fake_params
+
+ # Collector gets prepared for state collecting.
+ if is_collector:
+ for state_name in state_names:
+ if state_name == 'step':
+ # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
+ collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu()
+ else:
+ collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32,
+ requires_grad=False).cpu()
+
+ # Materials for gathering, including compacted state tensors, and the offset of shard inside each state.
+ compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
+ _, shard_offset, shard_size = self.get_offsets(param_id)
+
+ # Collectors gather state shards through all_gathering.
+ gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))]
+
+ dist.barrier()
+ dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
+
+ if is_collector:
+ for state_shard in gathered_state_shards:
+ compacted_states = state_shard[0]
+ shard_offset = state_shard[1]
+ shard_size = state_shard[2]
+ if compacted_states is None:
+ continue
+ self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
+ shard_size)
+
+ # Clean gathered states
+ for state_shard in gathered_state_shards:
+ del state_shard[0]
+ gc.collect()
+
+ # Reshape tensors
+ if is_collector:
+ for state_name, state_tensor in collected_states.items():
+ if state_tensor.numel() == param.numel():
+ collected_states[state_name] = torch.reshape(state_tensor, param.shape)
+
+ return collected_states
+
+ def pack_optimizer_states_to_tensor(self,
+ param_id: int,
+ state_names: list,
+ device: torch.device = torch.device('cuda'),
+ dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ '''
+ With param id given, pack its optimizer states into a compact tensor and return.
+ '''
+ if param_id not in self.id_to_fake_params:
+ return None
+
+ fake_param = self.id_to_fake_params[param_id]
+ param_range = self.param_to_range[fake_param]
+ states = self.optim.state[fake_param]
+ shard_size = param_range[1] - param_range[0]
+ compacted_size = 0
+ for name in state_names:
+ if name == 'step':
+ compacted_size += 1
+ else:
+ compacted_size += shard_size
+ compacted_states = torch.zeros(compacted_size, dtype=dtype, device=device, requires_grad=False)
+
+ next_state_offset = 0
+ for state_name, state_tensor in states.items():
+ # State 'step' needs special operation.
+ if state_name == 'step':
+ if isinstance(state_tensor, torch.Tensor):
+ compacted_states[next_state_offset] = state_tensor[0].item()
+ else:
+ assert isinstance(state_tensor, int)
+ compacted_states[next_state_offset] = state_tensor
+ next_state_offset += 1
+ else:
+ assert state_tensor.numel() == shard_size
+ compacted_states[next_state_offset:next_state_offset + shard_size].copy_(state_tensor)
+ next_state_offset += shard_size
+
+ return compacted_states
+
+ def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list,
+ shard_start: int, shard_size: int):
+ '''
+ Given a tensor carrying compacted optimizer states,
+ update these states to collected_states.
+ '''
+ shard_end = shard_start + shard_size
+ next_state_offset = 0
+
+ for state_name in state_names:
+ if state_name == 'step':
+ collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(),
+ dtype=torch.float32,
+ requires_grad=False).cpu()
+ next_state_offset += 1
+ else:
+ target_segment = collected_states[state_name][shard_start:shard_end]
+ target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
+ next_state_offset += shard_size
+
+ def state_dict(self, only_rank_0: bool = True) -> dict:
+ """
+ Args:
+ only_rank_0 (bool): a boolean value indicating whether the state_dict is collected
+ only on rank 0, defaults to True.
+
+ Returns:
+ The complete state of the optimizer as a :class:`dict`.
+ It contains two entries:
+
+ * state - a dict holding current optimization state. Its content
+ differs between optimizer classes.
+ * param_groups - a list containing all parameter groups where each
+ parameter group is a dict.
+
+ Warning: This method will gather and return the whole optimizer state_dict,
+ so it should be called only when memory resources are abundant.
+ """
+ state_dict = {}
+ state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
+
+ torch_special_hyperparameters = {
+ 'amsgrad': False,
+ 'maximize': False,
+ 'foreach': None,
+ 'capturable': False,
+ 'differentiable': False,
+ 'fused': False
+ }
+
+ for group in state_dict['param_groups']:
+ for k, v in torch_special_hyperparameters.items():
+ if k not in group:
+ group[k] = v
+
+ # Collect optimizer states.
+ state_dict['state'] = dict()
+ for param_id in self.id_to_real_params.keys():
+ dist.barrier()
+ state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
+ return state_dict
+
+ def load_param_groups(self, saved_param_groups: list):
+ """
+ Load saved_param_groups into
+ self.param_groups and self.param_groups_backup
+ """
+ self.param_groups_backup = copy.deepcopy(saved_param_groups)
+
+ # discard the older param_groups
+ self.optim.param_groups = []
+
+ for group in saved_param_groups:
+ fake_params_list = list()
+ updated_group = {k: v for k, v in group.items() if k != 'params'}
+ for param_id in group['params']:
+ if param_id not in self.id_to_fake_params:
+ continue
+ fake_param = self.id_to_fake_params[param_id]
+ fake_params_list.append(fake_param)
+ updated_group['params'] = fake_params_list
+ self.optim.param_groups.append(updated_group)
+
+ def load_single_param_states(self, param_id: int, saved_states: dict):
+ """
+ Load saved optimizer states into parameter with given id.
+ """
+
+ def cast(param, state_range, value, key=None):
+ """
+ Make a copy of the needed segment of value and cast it to device of param.
+ """
+ assert isinstance(value, torch.Tensor)
+ ret_val = value
+ if (key == "step"):
+ assert value.numel() == 1
+ ret_val = int(value.item())
+ else:
+ state_start, state_end = state_range
+ ret_val = torch.zeros(state_end - state_start,
+ dtype=torch.float32,
+ device=param.device,
+ requires_grad=False)
+ ret_val.copy_(value.flatten()[state_start:state_end])
+ return ret_val
+
+ assert param_id in self.id_to_fake_params
+ fake_param = self.id_to_fake_params[param_id]
+ _, state_offset, param_size = self.get_offsets(param_id)
+ state_range = (state_offset, state_offset + param_size)
+
+ # Copy states assigned to param (and cast tensors to appropriate types).
+ updated_states = dict()
+ for k, v in saved_states.items():
+ updated_states[k] = cast(fake_param, state_range, v, k)
+ del v # clean loaded states
+ self.optim.state[fake_param].update(updated_states)
+
+ def load_state_dict(self, state_dict: dict):
+ """Loads optimizer state from whole optimizer state_dict.
+ During loading, filter out the part of states not considered by current process.
+
+ Args:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ assert 'param_groups' in state_dict
+ self.load_param_groups(state_dict['param_groups'])
+
+ state = state_dict['state']
+
+ for param_id, param_states in state.items():
+ if param_id in self.id_to_fake_params:
+ self.load_single_param_states(param_id, param_states)
+
+ # Epilogue for pytorch optimizer.
+ self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
+ self.optim.defaults.setdefault('differentiable', False)
class GeminiAdamOptimizer(ZeroOptimizer):
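The `collect_states` and `pack_optimizer_states_to_tensor` additions above share one core move: pack each rank's slice of a state into a flat tensor, exchange `(tensor, offset, size)` triples via `all_gather_object`, and let the collecting rank copy each slice into a full-size buffer. A stripped-down sketch of that move, assuming an already-initialized `torch.distributed` process group (the function name `gather_state` is illustrative):

```python
import torch
import torch.distributed as dist


def gather_state(local_shard: torch.Tensor, shard_offset: int, full_numel: int) -> torch.Tensor:
    """Gather one flat optimizer-state tensor from per-rank shards onto rank 0."""
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    # Each rank contributes its packed shard plus the metadata needed to place it.
    dist.all_gather_object(gathered, (local_shard.cpu(), shard_offset, local_shard.numel()))

    full_state = torch.zeros(full_numel, dtype=torch.float32)
    if dist.get_rank() == 0:
        # Only the collecting rank copies every shard into the full-size buffer.
        for shard, offset, size in gathered:
            full_state[offset:offset + size].copy_(shard)
    return full_state
```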
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index 602cf468c944..0235ff2e2c81 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -8,15 +8,18 @@
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
-from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero import ZeroDDP
-from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
-from colossalai.zero.gemini.gemini_mgr import GeminiManager
+from colossalai.testing import (
+ check_state_dict_equal,
+ clear_cache_before_run,
+ parameterize,
+ rerun_if_address_is_in_use,
+ spawn,
+)
from tests.kit.model_zoo import model_zoo
+@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['transformers_bert_for_sequence_classification'])
@parameterize('use_safetensors', [False, True])
@@ -29,33 +32,33 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
pretrained_path = os.path.join(tempdir, 'pretrained')
bert_model.config.save_pretrained(save_directory=pretrained_path)
- # TODO(ver217): use boost api
- config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100)
- chunk_manager = ChunkManager(config_dict)
- gemini_manager = GeminiManager(placement_policy, chunk_manager)
- bert_model = ZeroDDP(bert_model, gemini_manager)
- bert_model.train()
-
- ckpt_io = GeminiCheckpointIO()
+ plugin = GeminiPlugin(placement_policy=placement_policy)
+ booster = Booster(plugin=plugin)
+ bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
- ckpt_io.save_model(bert_model, (pretrained_path),
+
+ booster.save_model(bert_model,
+ pretrained_path,
True,
True,
'', (model_size / 3),
use_safetensors=use_safetensors)
dist.barrier()
+
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
- check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
+ check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32),
new_bert_model.state_dict(), False)
+@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('shard', [True, False])
+@parameterize('shard', [False])
@parameterize('model_name', ['transformers_gpt'])
-def exam_state_dict(placement_policy, shard: bool, model_name: str):
+@parameterize('size_per_shard', [32])
+def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
- plugin = GeminiPlugin(placement_policy=placement_policy)
+ plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
booster = Booster(plugin=plugin)
model = model_fn()
@@ -78,18 +81,32 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str):
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
- booster.save_model(model, model_ckpt_path)
- if not shard:
- # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
- booster.save_optimizer(optimizer, optimizer_ckpt_path)
+ booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+
+ booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
new_model.unwrap().state_dict(only_rank_0=False), False)
- if not shard:
- booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
- check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+
+ booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+ check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
+ new_optimizer.unwrap().state_dict(only_rank_0=False), False)
+
+ # Check the new model/optimizer can successfully run.
+ data = data_gen_fn()
+ data = {
+ k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+ }
+ output = new_model(**data)
+ output = output_transform_fn(output)
+ output_key = list(output.keys())[0]
+ loss = criterion(output[output_key])
+ booster.backward(loss, new_optimizer)
+ new_optimizer.step()
+ booster.save_model(new_model, model_ckpt_path, shard=shard)
+ booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
def run_dist(rank, world_size, port):
@@ -100,7 +117,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py
new file mode 100644
index 000000000000..b34e3e3a1310
--- /dev/null
+++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py
@@ -0,0 +1,171 @@
+import pytest
+import torch
+import torch.distributed as dist
+from torch.optim import Adam
+from utils import shared_tempdir
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import (
+ check_state_dict_equal,
+ clear_cache_before_run,
+ parameterize,
+ rerun_if_address_is_in_use,
+ spawn,
+)
+from tests.kit.model_zoo import model_zoo
+
+
+@clear_cache_before_run()
+@parameterize('shard', [False])
+@parameterize('model_name', ['transformers_gpt'])
+def exam_torch_load_from_gemini(shard: bool, model_name: str):
+
+ (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+ criterion = lambda x: x.mean()
+ plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14))
+ booster = Booster(plugin=plugin)
+
+ model = model_fn()
+ optimizer = HybridAdam(model.parameters(), lr=0.001)
+ model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+ data = data_gen_fn()
+ data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
+ output = model(**data)
+ output = output_transform_fn(output)
+ output_key = list(output.keys())[0]
+ loss = criterion(output[output_key])
+
+ booster.backward(loss, optimizer)
+ optimizer.step()
+
+ with shared_tempdir() as tempdir:
+ model_ckpt_path = f"{tempdir}/model"
+ optimizer_ckpt_path = f"{tempdir}/optimizer"
+
+ booster.save_model(model, model_ckpt_path, shard=shard)
+ booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+ dist.barrier()
+
+ new_model = model_fn()
+ new_optimizer = Adam(new_model.parameters(), lr=0.001)
+ new_plugin = TorchDDPPlugin()
+ new_booster = Booster(plugin=new_plugin)
+ new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
+
+ # Loading HybridAdam states to torch.Adam
+ new_booster.load_model(new_model, model_ckpt_path, strict=True)
+
+ # Add prefix to get aligned with pytorch parameter names.
+ check_state_dict_equal(
+ model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
+ new_model.state_dict(), False)
+
+ new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+ check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
+
+ # Check the new model/optimizer can successfully run.
+ data = data_gen_fn()
+ data = {
+ k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+ }
+ output = new_model(**data)
+ output = output_transform_fn(output)
+ output_key = list(output.keys())[0]
+ loss = criterion(output[output_key])
+ new_booster.backward(loss, new_optimizer)
+ new_optimizer.step()
+ new_booster.save_model(new_model, model_ckpt_path, shard=shard)
+ new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
+
+
+@clear_cache_before_run()
+@parameterize('shard', [False])
+@parameterize('model_name', ['transformers_gpt'])
+def exam_gemini_load_from_torch(shard: bool, model_name: str):
+
+ (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+ criterion = lambda x: x.mean()
+ plugin = TorchDDPPlugin()
+ booster = Booster(plugin=plugin)
+
+ model = model_fn()
+ optimizer = Adam(model.parameters(), lr=0.001)
+ model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+ data = data_gen_fn()
+ data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
+ output = model(**data)
+ output = output_transform_fn(output)
+ output_key = list(output.keys())[0]
+ loss = criterion(output[output_key])
+
+ booster.backward(loss, optimizer)
+ optimizer.step()
+
+ with shared_tempdir() as tempdir:
+ model_ckpt_path = f"{tempdir}/model"
+ optimizer_ckpt_path = f"{tempdir}/optimizer"
+
+ booster.save_model(model, model_ckpt_path, shard=shard)
+ booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+ dist.barrier()
+
+ new_model = model_fn()
+ new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+ new_plugin = GeminiPlugin()
+ new_booster = Booster(plugin=new_plugin)
+ new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
+
+ # Loading torch.Adam states to HybridAdam
+ new_booster.load_model(new_model, model_ckpt_path, strict=True)
+
+ # Add prefix to get aligned with pytorch parameter names.
+ check_state_dict_equal(
+ new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32),
+ model.state_dict(), False)
+
+ new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+ old_state_dict = optimizer.state_dict()
+ new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
+
+ # Comparison of param_groups needs special care here,
+ # since not all hyperparameters in Adam are used by HybridAdam
+ hyperparameters_to_examine = ['params', 'lr', 'betas', 'eps', 'weight_decay']
+ for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
+ for k in hyperparameters_to_examine:
+ assert k in old_group and k in new_group, \
+ f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
+ assert old_group[k] == new_group[k]
+ check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)
+
+ # Check the new model/optimizer can successfully run.
+ data = data_gen_fn()
+ data = {
+ k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+ }
+ output = new_model(**data)
+ output = output_transform_fn(output)
+ output_key = list(output.keys())[0]
+ loss = criterion(output[output_key])
+ new_booster.backward(loss, new_optimizer)
+ new_optimizer.step()
+ new_booster.save_model(new_model, model_ckpt_path, shard=shard)
+ new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
+
+
+def run_dist(rank, world_size, port):
+ config = {}
+ colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ exam_torch_load_from_gemini()
+ exam_gemini_load_from_torch()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO(world_size):
+ spawn(run_dist, world_size)
From c1cf752021f3f9e6f578eca5827e3e87450d575b Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Mon, 10 Jul 2023 11:48:27 +0800
Subject: [PATCH 05/64] [docker] fixed ninja build command (#4203)
* [docker] fixed ninja build command
* polish code
---
docker/Dockerfile | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 97399c939376..a1e136ee58a5 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -21,7 +21,10 @@ RUN apt-get update && \
RUN conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
# install ninja
-RUN apt-get install -y --no-install-recommends ninja-build
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends ninja-build && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
# install apex
RUN git clone https://github.com/NVIDIA/apex && \
@@ -31,7 +34,7 @@ RUN git clone https://github.com/NVIDIA/apex && \
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./
# install colossalai
-ARG VERSION=1
+ARG VERSION=main
RUN git clone -b ${VERSION} https://github.com/hpcaitech/ColossalAI.git \
&& cd ./ColossalAI \
&& CUDA_EXT=1 pip install -v --no-cache-dir .
From 4e9b09c222c9b3f78c4ad48eb55f09e3aaba10e1 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Wed, 12 Jul 2023 17:35:58 +0800
Subject: [PATCH 06/64] Automated submodule synchronization (#4217)
Co-authored-by: github-actions
---
examples/tutorial/fastfold/FastFold | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/tutorial/fastfold/FastFold b/examples/tutorial/fastfold/FastFold
index 05681304651b..eba496808a91 160000
--- a/examples/tutorial/fastfold/FastFold
+++ b/examples/tutorial/fastfold/FastFold
@@ -1 +1 @@
-Subproject commit 05681304651b1b29d7d887db169045ea3dd28fce
+Subproject commit eba496808a91bbcd9661cf832349a418b197015f
From 9a4842c571cd63e6a660182a234bc6ff60991ba0 Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Mon, 17 Jul 2023 17:30:57 +0800
Subject: [PATCH 07/64] revise shardformer readme (#4246)
---
colossalai/shardformer/README.md | 25 ++++++++++---------------
1 file changed, 10 insertions(+), 15 deletions(-)
diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 6ae32e4fbd42..bf4215c52980 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -22,7 +22,6 @@
- [System Performance](#system-performance)
- [Convergence](#convergence)
-
## 🔗 Introduction
**Shardformer** is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. This module aims to make parallelization hassle-free for users who are not from the system background.
@@ -33,7 +32,7 @@
The sample API usage is given below:
-``` python
+```python
from colossalai.shardformer import ShardConfig, Shard
from transformers import BertForMaskedLM
@@ -74,6 +73,7 @@ shard_former.optimize(model, my_policy)
```
+
## 🗺 Roadmap
We will follow this roadmap to develop Shardformer:
@@ -117,15 +117,13 @@ Please refer to the code for more details.
-
-
### Distributed Modules
`ShardFormer` replaces the original PyTorch module with a distributed module.
The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation.
Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module.
-```python
+````python
class ParallelModule(torch.nn.Module):
@abstractmethod
@@ -140,7 +138,7 @@ class ParallelModule(torch.nn.Module):
my_linear = Linear1D_Col.from_native_module(my_linear, process_group)
```
"""
-```
+````
### Shard Config
@@ -169,7 +167,7 @@ We abstract the policy into four stages:
2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted.
3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
-``` python
+```python
@dataclass
class ModulePolicyDescription:
r"""
@@ -238,7 +236,6 @@ class Policy(ABC):
...
```
-
### Model Sharder
`ModelSharder` is the class in charge of sharding the model based on the given policy.
@@ -324,21 +321,20 @@ You can create a new file in the `colossalai/shardformer/policies` folder and na
Please follow the following protocols when writing your policy:
- You have to make a clear decision what you want to replace exactly in the original PyTorch module
- - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes
- - Use `ModulePolicyDescription.param_replacement` to replace the module parameters
- - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement the `from_native_module` for the .
- - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in the `shardformer/modeling/.py`**.
+ - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes
+ - Use `ModulePolicyDescription.param_replacement` to replace the module parameters
+ - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement the `from_native_module` for the replacement.
+ - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in the `shardformer/modeling/.py`**.
- You can implement the `ParallelModule` for primitive modules in the `shardformer/layer/.py` file. Primitive modules refer to modules which are not composed of other modules. For example, the `torch.nn.Linear` module is a primitive module, while a module such as `BertEncoder` in the `transformers` library is a composite module. Primitive modules do not nest inner `nn.Module` members. For composite modules, you should consider using `ModulePolicyDescription` to implement your replacement.
- `ParallelModule` is meant to be used in two ways: `ParallelModule.from_native_module` to convert native PyTorch module to the `ParallelModule` and `ParallelModule(...)` to instantiate the module directly just like a normal PyTorch module. `ParallelModule` should be only implemented for modules whose weights are sharded. If you want to make your module compatible with the `ModulePolicyDescription.sub_module_replacement` and there is no weight sharding in your module, you can just implement the `from_native_module` method without inheriting the `ParallelModule` like `colossalai/shardformer/layer/normalization.py`.
- **Do not import any file in the `colossalai/shardformer/policies` and `colossalai/shardformer/modeling` to avoid unwanted import error**. For example, a file in these folders accidentally imports `transformers` library at the top of the file, then the user will have to install `transformers` library even if they do not use this file. Any file in the `modeling` folder should be only imported by the policy file. A policy implementation should be only imported dynamically via the autopolicy or manually via the `ShardFormer` module.
- Try to keep your import statement on third-party libraries such as `transformers` within the function body instead of the header section of the file. This is because we do not want to import the library when we do not use the policy.
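A minimal sketch of how these replacement fields can be combined inside a policy's `module_policy` is shown below. All concrete names here (the `MyAttention` key, the attribute value, the import path) are illustrative assumptions rather than the real BERT policy; consult the existing policies in `colossalai/shardformer/policies` for authoritative examples.

```python
# Hedged sketch only: "MyModelPolicy"/"MyAttention", the attribute value, and the
# import path are illustrative assumptions, not a real policy shipped with Colossal-AI.
from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription, Policy  # path is an assumption


class MyModelPolicy(Policy):

    def module_policy(self):
        return {
            "MyAttention": ModulePolicyDescription(
                # shrink attributes that depend on the tensor parallel size
                attribute_replacement={"num_heads": 16 // 2},
                # replace module parameters in place (left empty in this sketch)
                param_replacement=[],
                # swap whole child modules for their parallel counterparts
                sub_module_replacement=[],
                # redirect module methods to functions kept in shardformer/modeling/
                method_replacement={},
            )
        }
```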
-
- Step 2. Register your policy to the autopolicy
Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.
-For example, if we register the policy for the BERT model, we just add a key-value in the `_POLICY_LIST` dictionary. The key if the `qualname` of the model object (you can get it by model.__class__.__qualname__). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy.
+For example, if we register the policy for the BERT model, we just add a key-value pair in the `_POLICY_LIST` dictionary. The key is the `qualname` of the model object (you can get it by model.\_\_class\_\_.\_\_qualname\_\_). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy.
```python
_POLICY_LIST = {
@@ -360,7 +356,6 @@ Add your model to the `tests/kit/model_zoo` file. This allows you to define test
Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.
-
- Step 3. Execute your test
When you run tests locally, you should run tests for both your newly-added test file and the whole `shardformer` module tests.
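To make the registration step (Step 2) above concrete, a hypothetical entry in `_POLICY_LIST` could look like the sketch below; the qualname key and the `file_name`/`class_name` values are placeholders for illustration only.

```python
# Hedged sketch of an autopolicy registration entry. The key is the model's
# qualname and the value records where the policy lives, so the policy module is
# only imported lazily when that model is actually sharded. PolicyLocation is
# assumed to be defined alongside _POLICY_LIST in autopolicy.py.
_POLICY_LIST = {
    # key: model.__class__.__qualname__ (placeholder value)
    "BertModel": PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
}
```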
From 7ff11b5537123b50d8b1b3b0fbaca0fa31d9481b Mon Sep 17 00:00:00 2001
From: binmakeswell
Date: Mon, 17 Jul 2023 21:07:44 +0800
Subject: [PATCH 08/64] [example] add llama pretraining (#4257)
---
README.md | 11 +++++++++++
docs/README-zh-Hans.md | 10 ++++++++++
examples/language/llama/README.md | 11 +++++++++++
3 files changed, 32 insertions(+)
create mode 100644 examples/language/llama/README.md
diff --git a/README.md b/README.md
index 34c8a6b730a3..21670e1e59fb 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
## Latest News
+* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs)
@@ -49,6 +50,7 @@
Parallel Training Demo
+ - LLaMA
- GPT-3
- GPT-2
- BERT
@@ -216,6 +218,15 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
## Parallel Training Demo
+### LLaMA
+
+
+
+
+- 65-billion-parameter large model pretraining accelerated by 38%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
+
### GPT-3
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 1dde7a816676..e229c65d890c 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,7 @@
## 新闻
+* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs)
@@ -49,6 +50,7 @@
-
并行训练样例展示
+ - LLaMA
- GPT-3
- GPT-2
- BERT
@@ -209,6 +211,14 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
(返回顶端)
## 并行训练样例展示
+### LLaMA
+
+
+
+
+- 650亿参数大模型预训练加速38%
+[[代码]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[博客]](https://www.hpc-ai.tech/blog/large-model-pretraining)
### GPT-3
diff --git a/examples/language/llama/README.md b/examples/language/llama/README.md
new file mode 100644
index 000000000000..871804f2ca86
--- /dev/null
+++ b/examples/language/llama/README.md
@@ -0,0 +1,11 @@
+# Pretraining LLaMA: best practices for building LLaMA-like base models
+
+
+
+
+
+- 65-billion-parameter large model pretraining accelerated by 38%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
+
+> Since the main branch is being updated, in order to maintain the stability of the code, this example is temporarily kept as an [independent branch](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama).
From 4b977541a86c90946badc77a6a77fee64fdc8cce Mon Sep 17 00:00:00 2001
From: Cuiqing Li
Date: Tue, 18 Jul 2023 23:53:38 +0800
Subject: [PATCH 09/64] [Kernels] added triton-implemented of self attention
for colossal-ai (#4241)
* added softmax kernel
* added qkv_kernel
* added ops
* adding tests
* upload tets
* fix tests
* debugging
* debugging tests
* debugging
* added
* fixed errors
* added softmax kernel
* clean codes
* added tests
* update tests
* update tests
* added attention
* add
* fixed pytest checking
* add cuda check
* fix cuda version
* fix typo
---
colossalai/kernel/triton/ops.py | 209 ++++++++++++++++++
colossalai/kernel/triton/qkv_matmul_kernel.py | 109 +++++++++
colossalai/kernel/triton/softmax_kernel.py | 44 ++++
tests/test_kernels/test_self_attention.py | 136 ++++++++++++
tests/test_kernels/test_softmax.py | 27 +++
5 files changed, 525 insertions(+)
create mode 100644 colossalai/kernel/triton/ops.py
create mode 100644 colossalai/kernel/triton/qkv_matmul_kernel.py
create mode 100644 colossalai/kernel/triton/softmax_kernel.py
create mode 100644 tests/test_kernels/test_self_attention.py
create mode 100644 tests/test_kernels/test_softmax.py
diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py
new file mode 100644
index 000000000000..5e8d4ba3ec99
--- /dev/null
+++ b/colossalai/kernel/triton/ops.py
@@ -0,0 +1,209 @@
+import torch
+from torch import nn
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ from .qkv_matmul_kernel import qkv_gemm_4d_kernel
+ from .softmax_kernel import softmax_kernel
+
+ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
+ r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
+ Args:
+ q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
+ k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
+ v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
+        input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_len, seq_len)
+ scale: the float scale value which is used to multiply with Q*K^T before doing softmax
+
+ Return:
+ output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size)
+ """
+ assert len(q.shape) == 4, "the shape of q val must be 4"
+ batches, M, H, K = q.shape
+ assert q.shape == k.shape, "the shape of q and the shape of k must be equal"
+ assert q.shape == v.shape, "the shape of q and the shape of v must be equal"
+ assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal"
+
+ N = k.shape[1]
+
+ # head_size * num_of_head
+ d_model = q.shape[-1] * q.shape[-2]
+
+ score_output = torch.empty(
+ (batches, H, M, N), device=q.device, dtype=q.dtype)
+
+ grid = lambda meta: (
+ batches,
+ H,
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
+ triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ )
+
+ qkv_gemm_4d_kernel[grid](
+ q, k, score_output,
+ M, N, K,
+ q.stride(0), q.stride(2), q.stride(1), q.stride(3),
+ k.stride(0), k.stride(2), k.stride(3), k.stride(1),
+ score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
+ scale=scale,
+ # currently manually setting, later on we can use auto-tune config to match best setting
+ BLOCK_SIZE_M=64,
+ BLOCK_SIZE_N=32,
+ BLOCK_SIZE_K=32,
+ GROUP_SIZE_M=8,
+ )
+
+ softmax_output = torch.empty(
+ score_output.shape, device=score_output.device, dtype=score_output.dtype)
+ score_output_shape = score_output.shape
+
+ score_output = score_output.view(-1, score_output.shape[-1])
+ n_rows, n_cols = score_output.shape
+
+ if n_rows <= 350000:
+
+ block_size = max(triton.next_power_of_2(n_cols), 2)
+ num_warps = 4
+ if block_size >= 4096:
+ num_warps = 16
+ elif block_size >= 2048:
+ num_warps = 8
+ else:
+ num_warps = 4
+
+ softmax_kernel[(n_rows, )](
+ softmax_output,
+ score_output,
+ score_output.stride(0),
+ n_cols,
+ mask_ptr = input_mask,
+ num_warps=num_warps,
+ BLOCK_SIZE=block_size,
+ )
+
+ else:
+ #TODO: change softmax kernel functions to make it suitable for large size dimension
+ softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
+ softmax_output = softmax_output.view(*score_output_shape)
+
+ batches, H, M, K = softmax_output.shape
+ N = v.shape[-1]
+
+ output = torch.empty(
+ (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
+
+ grid = lambda meta: (
+ batches,
+ H,
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
+ triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ )
+
+ qkv_gemm_4d_kernel[grid](
+ softmax_output, v, output,
+ M, N, K,
+ softmax_output.stride(0),
+ softmax_output.stride(1),
+ softmax_output.stride(2),
+ softmax_output.stride(3),
+ v.stride(0),
+ v.stride(2),
+ v.stride(1),
+ v.stride(3),
+ output.stride(0),
+ output.stride(2),
+ output.stride(1),
+ output.stride(3),
+ BLOCK_SIZE_M=128,
+ BLOCK_SIZE_N=64,
+ BLOCK_SIZE_K=64,
+ GROUP_SIZE_M=8,
+ scale=-1,
+ )
+ return output.view(batches, -1, d_model)
+
+
+ def self_attention_compute_using_triton(qkv,
+ input_mask,
+ layer_past,
+ alibi,
+ scale,
+ head_size,
+ triangular=False,
+ use_flash=False):
+
+ assert qkv.is_contiguous()
+ assert alibi is None, "current triton self-attention does not support alibi"
+ batches = qkv.shape[0]
+ d_model = qkv.shape[-1] // 3
+ num_of_heads = d_model // head_size
+
+ q = qkv[:, :, :d_model]
+ k = qkv[:, :, d_model:d_model * 2]
+ v = qkv[:, :, d_model * 2:]
+ q = q.view(batches, -1, num_of_heads, head_size)
+ k = k.view(batches, -1, num_of_heads, head_size)
+ v = v.view(batches, -1, num_of_heads, head_size)
+
+ data_output_triton = self_attention_forward_without_fusion(
+ q, k, v, input_mask, scale)
+
+ return data_output_triton
+
+
+ def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
+ if mask is not None:
+            assert input.shape[-1] == mask.shape[-1], "the last dimensions of input and mask should be the same"
+        assert dim == -1 or dim == len(input.shape)-1, "currently the softmax layer only supports the last dimension"
+
+ hidden_dim = input.shape[-1]
+ output = torch.empty_like(input)
+ input = input.view(-1, hidden_dim)
+ if mask is not None:
+ mask = mask.view(-1, hidden_dim)
+            assert input.shape[0] == mask.shape[0], "the first dimension of mask and input should be the same"
+
+ num_rows, num_cols = input.shape
+ block_size = max(triton.next_power_of_2(num_cols), 2)
+ num_warps = 16
+ if block_size >= 4096:
+ num_warps = 16
+ elif block_size >= 2048:
+ num_warps = 8
+ else:
+ num_warps = 4
+
+ if num_rows <= 350000:
+ grid = (num_rows,)
+ softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
+ else:
+
+ grid = lambda meta: (
+ triton.cdiv(num_rows, meta["BLOCK_M"]),
+ )
+
+ BLOCK_M = 32
+ if block_size >= 4096:
+ BLOCK_M = 4
+ elif block_size >= 2048:
+ BLOCK_M = 8
+
+ softmax_kernel_2[grid](output_ptr = output,
+ input_ptr = input,
+ row_stride = input.stride(0),
+ n_rows = num_rows,
+ n_cols = num_cols,
+ mask_ptr = mask,
+ # currently manually setting up size
+ BLOCK_M = 32,
+ BLOCK_SIZE = block_size)
+
+ return output
\ No newline at end of file
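For orientation, here is a hedged usage sketch of the fused-QKV entry point added above. It assumes a CUDA device with a working triton install; the shapes follow the docstrings, and the mask/alibi arguments are simply left unset, mirroring the unit test added below.

```python
# Hedged usage sketch of self_attention_compute_using_triton; requires CUDA and triton.
import torch

from colossalai.kernel.triton.ops import self_attention_compute_using_triton

batch, seq_len, num_heads, head_size = 4, 24, 2, 32
qkv = torch.randn(batch, seq_len, 3 * num_heads * head_size, device="cuda", dtype=torch.float16)

out = self_attention_compute_using_triton(
    qkv,
    input_mask=None,              # optional additive mask applied inside the softmax kernel
    layer_past=None,              # accepted but unused by this non-fused path
    alibi=None,                   # alibi is asserted to be None by the op
    scale=1.0 / head_size ** 0.5,
    head_size=head_size,
    triangular=False,
    use_flash=False,
)
print(out.shape)  # torch.Size([4, 24, 64]) == (batch, seq_len, num_heads * head_size)
```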
diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py
new file mode 100644
index 000000000000..62fc6bba0360
--- /dev/null
+++ b/colossalai/kernel/triton/qkv_matmul_kernel.py
@@ -0,0 +1,109 @@
+import torch
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+
+if HAS_TRITON:
+ '''
+ this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+ '''
+ @triton.jit
+ def qkv_gemm_4d_kernel(
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ M,
+ N,
+ K,
+ stride_ab,
+ stride_ah,
+ stride_am,
+ stride_ak,
+ stride_bb,
+ stride_bh,
+ stride_bk,
+ stride_bn,
+ stride_cb,
+ stride_ch,
+ stride_cm,
+ stride_cn,
+ scale,
+ # Meta-parameters
+ BLOCK_SIZE_M : tl.constexpr = 64,
+ BLOCK_SIZE_N : tl.constexpr = 32,
+ BLOCK_SIZE_K : tl.constexpr = 32,
+ GROUP_SIZE_M : tl.constexpr = 8,
+ ):
+        r""" A kernel function used to do the batched matmul of Q*K^T or score_matrix * V for the attention layer,
+            where score_matrix is softmax(Q*K^T/sqrt(hidden_size))
+ Args:
+ a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K)
+ b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K)
+ c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N)
+        stride_ab(tl.constexpr): stride for the bs-dimension of tensor array A
+        stride_ah(tl.constexpr): stride for the h-dimension of tensor array A
+        stride_am(tl.constexpr): stride for the m-dimension of tensor array A
+        stride_ak(tl.constexpr): stride for the k-dimension of tensor array A
+        stride_bb(tl.constexpr): stride for the bs-dimension of tensor array B
+        stride_bh(tl.constexpr): stride for the h-dimension of tensor array B
+        stride_bk(tl.constexpr): stride for the k-dimension of tensor array B
+        stride_bn(tl.constexpr): stride for the n-dimension of tensor array B
+        stride_cb(tl.constexpr): stride for the bs-dimension of the output tensor array
+        stride_ch(tl.constexpr): stride for the h-dimension of the output tensor array
+        stride_cm(tl.constexpr): stride for the m-dimension of the output tensor array
+        stride_cn(tl.constexpr): stride for the n-dimension of the output tensor array
+        BLOCK_SIZE_M : tiling size for the M-dimension of tensor array A
+        BLOCK_SIZE_N : tiling size for the N-dimension of tensor array B
+        BLOCK_SIZE_K : tiling size for the K-dimension of A and B
+        GROUP_SIZE_M : group size for reducing cache misses; more details can be found in the Triton matrix-multiplication tutorial linked below
+ """
+
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ batch = tl.program_id(axis = 0)
+ head = tl.program_id(axis = 1)
+ pid = tl.program_id(axis = 2)
+
+ # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
+ (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
+ b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, K, BLOCK_SIZE_K):
+ a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
+ b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
+ a = tl.load(a_ptrs, mask=a_mask, other=0.)
+ b = tl.load(b_ptrs, mask=b_mask, other=0.)
+ accumulator += tl.dot(a, b)
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk
+
+ accumulator = accumulator.to(c_ptr.dtype.element_ty)
+ if scale > 0:
+ accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)
+
+
+ offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
+ stride_cn * offs_accumu_n[None, :])
+ accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
+ tl.store(c_ptrs, accumulator, mask=accumulator_mask)
diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py
new file mode 100644
index 000000000000..c215890badff
--- /dev/null
+++ b/colossalai/kernel/triton/softmax_kernel.py
@@ -0,0 +1,44 @@
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ '''
+ softmax kernel is modified based on
+ https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
+ '''
+ @triton.jit
+ def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
+ r""" the kernel function for implementing softmax operator
+ Args:
+ output_ptr: the output after finishing softmax operation, (N, hidden_dim)
+ input_ptr: the tensor of input, shape should be (N, hidden_dim)
+ n_cols(tl.constexpr): the number of cols of input
+ BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
+ """
+ row_idx = tl.program_id(0)
+ row_start_ptr = input_ptr + row_idx * row_stride
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ input_ptrs = row_start_ptr + col_offsets
+ row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+ row_minus_max = row - tl.max(row, axis=0)
+
+ if mask_ptr is not None:
+ # load mask into SRAM
+            mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
+ mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
+
+ # update
+ row_minus_max = row_minus_max + mask
+
+ numerator = tl.exp(row_minus_max)
+ denominator = tl.sum(numerator, axis=0)
+ softmax_output = numerator / denominator
+ output_row_start_ptr = output_ptr + row_idx * row_stride
+ output_ptrs = output_row_start_ptr + col_offsets
+ # Write back output to DRAM
+ tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
\ No newline at end of file
diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py
new file mode 100644
index 000000000000..b316404a58db
--- /dev/null
+++ b/tests/test_kernels/test_self_attention.py
@@ -0,0 +1,136 @@
+import pytest
+from packaging import version
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from colossalai.kernel.triton.ops import self_attention_compute_using_triton
+from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+def test_qkv_matmul():
+ qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
+ scale = 1.2
+ head_size = 32
+ batches = qkv.shape[0]
+ d_model = qkv.shape[-1] // 3
+ num_of_heads = d_model // head_size
+
+ q = qkv[:, :, :d_model]
+ k = qkv[:, :, d_model:d_model * 2]
+
+ q = q.view(batches, -1, num_of_heads, head_size)
+ k = k.view(batches, -1, num_of_heads, head_size)
+ q_copy = q.clone()
+ k_copy = k.clone()
+ q = torch.transpose(q, 1, 2).contiguous()
+ k = torch.transpose(k, 1, 2).contiguous()
+ k = torch.transpose(k, 2, 3).contiguous()
+
+    torch_output = torch.einsum('bnij,bnjk->bnik', q, k)
+    torch_output *= scale
+
+ q, k = q_copy, k_copy
+ batches, M, H, K = q.shape
+ N = k.shape[1]
+ score_output = torch.empty(
+ (batches, H, M, N), device=q.device, dtype=q.dtype)
+
+ grid = lambda meta: (
+ batches,
+ H,
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
+ triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ )
+
+ K = q.shape[3]
+ qkv_gemm_4d_kernel[grid](
+ q, k, score_output,
+ M, N, K,
+ q.stride(0), q.stride(2), q.stride(1), q.stride(3),
+ k.stride(0), k.stride(2), k.stride(3), k.stride(1),
+ score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
+ scale=scale,
+ # currently manually setting, later on we can use auto-tune config to match best setting
+ BLOCK_SIZE_M=64,
+ BLOCK_SIZE_N=32,
+ BLOCK_SIZE_K=32,
+ GROUP_SIZE_M=8,
+ )
+
+    check = torch.allclose(torch_output.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
+ assert check is True, "the outputs of triton and torch are not matched"
+
+
+def self_attention_compute_using_torch(qkv,
+ input_mask,
+ scale,
+ head_size
+ ):
+
+ batches = qkv.shape[0]
+ d_model = qkv.shape[-1] // 3
+ num_of_heads = d_model // head_size
+
+ q = qkv[:, :, :d_model]
+ k = qkv[:, :, d_model:d_model * 2]
+ v = qkv[:, :, d_model * 2:]
+ q = q.view(batches, -1, num_of_heads, head_size)
+ k = k.view(batches, -1, num_of_heads, head_size)
+ v = v.view(batches, -1, num_of_heads, head_size)
+
+ q = torch.transpose(q, 1, 2).contiguous()
+ k = torch.transpose(k, 1, 2).contiguous()
+ v = torch.transpose(v, 1, 2).contiguous()
+
+ k = torch.transpose(k, -1, -2).contiguous()
+
+ score_output = torch.einsum('bnij,bnjk->bnik', q, k)
+ score_output *= scale
+
+ softmax_output = F.softmax(score_output, dim = -1)
+ res = torch.einsum('bnij,bnjk->bnik', softmax_output, v)
+ res = torch.transpose(res, 1, 2)
+ res = res.contiguous()
+
+
+ return res.view(batches, -1, d_model), score_output, softmax_output
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+def test_self_attention():
+
+ qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
+ data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
+ qkv.clone(),
+ input_mask = None,
+ scale = 1.2,
+ head_size = 32
+ )
+
+ data_output_triton = self_attention_compute_using_triton(
+ qkv.clone(),
+ alibi=None,
+ head_size=32,
+ scale=1.2,
+ input_mask=None,
+ layer_past=None,
+ use_flash=False,
+ triangular=True)
+
+ check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
+ assert check is True, "the triton output is not matched with torch output"
+
+
+if __name__ == "__main__":
+ test_qkv_matmul()
+    test_self_attention()
\ No newline at end of file
diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py
new file mode 100644
index 000000000000..843d811d019c
--- /dev/null
+++ b/tests/test_kernels/test_softmax.py
@@ -0,0 +1,27 @@
+import pytest
+from packaging import version
+import torch
+from torch import nn
+
+from colossalai.kernel.triton.ops import softmax
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4")
+def test_softmax_op():
+ data_samples = [
+ torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32),
+ torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32),
+ torch.randn((2345, 4, 5, 64), device = "cuda", dtype = torch.float16)
+ ]
+
+ for data in data_samples:
+ module = nn.Softmax(dim = -1)
+ data_torch_out = module(data)
+ data_triton_out = softmax(data)
+ check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)
+ assert check is True, "softmax outputs from triton and torch are not matched"
+
+
+if __name__ == "__main__":
+ test_softmax_op()
\ No newline at end of file
From fc5cef2c79265e36b585ef22c5e1d7f18be52a4e Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Wed, 19 Jul 2023 16:43:01 +0800
Subject: [PATCH 10/64] [lazy] support init on cuda (#4269)
* [lazy] support init on cuda
* [test] update lazy init test
* [test] fix transformer version
---
colossalai/lazy/lazy_init.py | 28 ++++++++++++++++++++--------
requirements/requirements-test.txt | 2 +-
tests/test_lazy/lazy_init_utils.py | 10 +++++++---
tests/test_lazy/test_models.py | 5 +++--
4 files changed, 31 insertions(+), 14 deletions(-)
diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index 8b911407307c..1f5345015bf2 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
from types import MethodType
from typing import Callable, Dict, Optional, Union
@@ -61,12 +62,15 @@ class _MyTensor(Tensor):
"""
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+ default_device: Optional[torch.device] = None
+
def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
cls._pre_op_fn()
if concrete_data is not None:
# uniform api as LazyTensor
data = concrete_data
else:
+ kwargs['device'] = cls.default_device
data = func(*args, **kwargs)
return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
@@ -142,6 +146,8 @@ class LazyTensor(torch.Tensor):
_meta_data: Optional[MetaTensor] = None # shape, dtype, device
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+ default_device: Optional[torch.device] = None
+
@staticmethod
def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
if concrete_data is not None:
@@ -159,6 +165,8 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
return r
def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
+ if func.__name__ in _NORMAL_FACTORY:
+ kwargs = {**kwargs, 'device': LazyTensor.default_device}
self._factory_method = (func, args, kwargs) # (func, args, kwargs)
self._op_buffer = [] # (func, args, kwargs, replace)
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
@@ -206,16 +214,11 @@ def _materialize_data(self) -> torch.Tensor:
if self._materialized_data is None:
# apply factory method
func, args, kwargs = self._factory_method
-
# apply cached sequence
self._pre_op_fn()
- try:
- init_val = func(*tree_map(self._replace_with_materialized, args),
- **tree_map(self._replace_with_materialized, kwargs))
- except TypeError as e:
- print(f'init fn: {func.__name__}')
- raise e
+ init_val = func(*tree_map(self._replace_with_materialized, args),
+ **tree_map(self._replace_with_materialized, kwargs))
self._materialized_data = self._rerun_ops(init_val)
return self._materialized_data
@@ -305,6 +308,7 @@ def wrap(y, i=None):
else:
# out of place op, create new lazy tensor
fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
+ fn.__name__ = func.__name__
lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
return lazy_y
elif type(y) is Tensor:
@@ -435,14 +439,21 @@ class LazyInitContext:
"""
_replaced: bool = False
- def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor):
+ def __init__(self,
+ tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor,
+ default_device: Optional[Union[torch.device, str, int]] = None):
+ assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
self.overrides = {}
self.tensor_cls = tensor_cls
+ self.old_default_device = LazyTensor.default_device
+ self.default_device = default_device
def __enter__(self):
if LazyInitContext._replaced:
raise RuntimeError(f'LazyInitContext is not reentrant')
LazyInitContext._replaced = True
+ self.old_default_device = self.tensor_cls.default_device
+ self.tensor_cls.default_device = self.default_device
def wrap_factory_method(target):
# factory functions (eg. torch.empty())
@@ -518,6 +529,7 @@ def wrapper(*args, **kwargs):
setattr(torch, name, wrapper)
def __exit__(self, exc_type, exc_val, exc_tb):
+ self.tensor_cls.default_device = self.old_default_device
LazyInitContext._replaced = False
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, orig)
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 50121a9283f2..9f6580c72d1b 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -4,7 +4,7 @@ pytest
coverage==7.2.3
git+https://github.com/hpcaitech/pytest-testmon
torchvision
-transformers
+transformers==4.30.2
timm
titans
torchaudio
diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py
index 73c3c5422d8a..9d9e9a3a5c76 100644
--- a/tests/test_lazy/lazy_init_utils.py
+++ b/tests/test_lazy/lazy_init_utils.py
@@ -61,14 +61,18 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn:
f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}'
-def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
+def check_lazy_init(entry: TestingEntry,
+ seed: int = 42,
+ verbose: bool = False,
+ check_forward: bool = False,
+ default_device: str = 'cpu') -> None:
model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
_MyTensor._pre_op_fn = lambda *args: set_seed(seed)
LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
- ctx = LazyInitContext(tensor_cls=_MyTensor)
+ ctx = LazyInitContext(tensor_cls=_MyTensor, default_device=default_device)
with ctx:
model = model_fn()
- ctx = LazyInitContext()
+ ctx = LazyInitContext(default_device=default_device)
with ctx:
deferred_model = model_fn()
copied_deferred_model = deepcopy(deferred_model)
diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py
index 4b7aeed73a69..e37184125d21 100644
--- a/tests/test_lazy/test_models.py
+++ b/tests/test_lazy/test_models.py
@@ -6,13 +6,14 @@
@pytest.mark.skipif(not SUPPORT_LAZY, reason='requires torch >= 1.12.0')
@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
-def test_torchvision_models_lazy_init(subset):
+@pytest.mark.parametrize('default_device', ['cpu', 'cuda'])
+def test_torchvision_models_lazy_init(subset, default_device):
sub_model_zoo = model_zoo.get_sub_registry(subset)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
continue
- check_lazy_init(entry, verbose=True)
+ check_lazy_init(entry, verbose=True, default_device=default_device)
if __name__ == '__main__':
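A hedged usage sketch of the new `default_device` argument follows; the model is an arbitrary `torch.nn` stack, and the later materialization/distribution of the lazily initialized parameters is omitted since it depends on the chosen plugin.

```python
# Hedged sketch: build a model lazily while recording that its tensors should be
# created on CUDA once materialized, via the default_device argument added in this patch.
import torch.nn as nn

from colossalai.lazy.lazy_init import LazyInitContext

with LazyInitContext(default_device='cuda'):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 10))
```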
From c6f6005990b182d7ee34c1fb84762d31ce7d3616 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 21 Jul 2023 14:39:01 +0800
Subject: [PATCH 11/64] [checkpointio] Sharded Optimizer Checkpoint for Gemini
Plugin (#4302)
* sharded optimizer checkpoint for gemini plugin
* modify test to reduce testing time
* update doc
* fix bug when keep_gatherd is true under GeminiPlugin
---
colossalai/booster/plugin/gemini_plugin.py | 131 +++++++++++++---
.../checkpoint_io/general_checkpoint_io.py | 38 +++--
colossalai/checkpoint_io/utils.py | 38 +++++
colossalai/zero/gemini/gemini_optimizer.py | 140 ++++++++++++++----
docs/source/en/basics/booster_api.md | 5 +-
docs/source/en/basics/booster_checkpoint.md | 2 -
docs/source/en/basics/booster_plugins.md | 2 -
docs/source/zh-Hans/basics/booster_api.md | 5 +-
.../zh-Hans/basics/booster_checkpoint.md | 1 -
docs/source/zh-Hans/basics/booster_plugins.md | 1 -
.../test_gemini_checkpoint_io.py | 4 +-
.../test_gemini_torch_compability.py | 6 +-
12 files changed, 289 insertions(+), 84 deletions(-)
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 6191f271c318..7b6e17337d36 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -1,3 +1,4 @@
+import gc
import logging
import os
import warnings
@@ -12,11 +13,19 @@
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
-from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
+from colossalai.checkpoint_io.utils import (
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
+ get_shard_filename,
+ load_shard_state_dict,
+ save_state_dict,
+ save_state_dict_shards,
+)
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero.gemini import ZeroOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
from .dp_plugin_base import DPPluginBase
@@ -37,7 +46,7 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor
"""
Save sharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
- As there is communication when getting state dict, this must be called on all processes.
+ As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
@@ -54,7 +63,7 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
- As there is communication when getting state dict, this must be called on all processes.
+ As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
The saving process will only be executed by master rank.
"""
state_dict = optimizer.state_dict()
@@ -76,7 +85,8 @@ def save_sharded_model(self,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
- Save sharded model
+ Save sharded model.
+ As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
@@ -86,28 +96,24 @@ def save_sharded_model(self,
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
- total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
- for idx, shard_pair in enumerate(state_dict_shard):
- if not self.coordinator.is_master():
- continue
- shard = shard_pair[0]
- shard_file = get_shard_filename(weights_name, idx)
- total_size = total_size + shard_pair[1]
- for key in shard.keys():
- index_file.append_weight_map(key, shard_file)
-
- checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
- save_state_dict(shard, checkpoint_file_path, use_safetensors)
- index_file.append_meta_data("total_size", total_size)
+        # Save shards of model states.
+ is_master = self.coordinator.is_master()
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=is_master,
+ use_safetensors=use_safetensors)
# only save the index file on the master rank
if self.coordinator.is_master():
+ index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
- logging.info(f"The model is split into checkpoint shards. "
- f"You can find where each parameters has been saved in the "
- f"index located at {save_index_file}.")
+ logging.info(f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}.")
def load_sharded_model(self,
model: GeminiDDP,
@@ -115,7 +121,7 @@ def load_sharded_model(self,
strict: bool = False,
use_safetensors: bool = False):
"""
- load shard model, load model from multiple files
+ Load shard model, load model from multiple files.
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
@@ -125,16 +131,93 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""
+
+ # If optimizer is wrapped, unwrap it.
+ if isinstance(optimizer, OptimizerWrapper):
+ optimizer = optimizer.unwrap()
+
+ assert isinstance(optimizer, ZeroOptimizer)
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
Path(checkpoint).mkdir(parents=True, exist_ok=True)
- super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+
+ # Preparing file paths and index file.
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+
+ # Store the information of param groups to param_group_file.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ param_groups = optimizer.get_param_groups_for_saving()
+ torch.save(param_groups, group_file_path)
+
+ # States are broken into shards within max_shard_size.
+ state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
+
+ # Save shards of optimizer states.
+ is_master = self.coordinator.is_master()
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=is_master,
+ use_safetensors=False)
+
+ # Wrap up index file. Only save it on master rank.
+ if self.coordinator.is_master():
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+                logging.info(f"The optimizer is split into checkpoint shards. "
+                             f"You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}.")
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
"""
- # TODO(Baizhou): To be implemented.
- pass
+
+ if not os.path.isfile(checkpoint_index_file):
+ logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+
+ # If optimizer is wrapped, unwrap it.
+ if isinstance(optimizer, OptimizerWrapper):
+ optimizer = optimizer.unwrap()
+
+ assert isinstance(optimizer, ZeroOptimizer)
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+
+ # Load param_groups.
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory.')
+ saved_param_groups = torch.load(param_group_path)
+ optimizer.load_param_groups(saved_param_groups)
+
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+ # Load optimizer states from shard files under checkpoint path.
+ # For each file, only load the states managed by current process.
+ for shard_file in checkpoint_files:
+ state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+ optimizer.load_param_states(state_dict_shard)
+ del state_dict_shard
+ gc.collect()
+
+ optimizer.optimizer_loading_epilogue()
+
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+ Save model to checkpoint but only on master process.
+ """
+ if self.coordinator.is_master():
+ super().save_lr_scheduler(lr_scheduler, checkpoint)
class GeminiModel(ModelWrapper):
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index e1d9066948dd..83e4bdcc863b 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -5,6 +5,7 @@
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple
+import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
@@ -16,7 +17,6 @@
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
- has_index_file,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
@@ -25,6 +25,7 @@
load_states_into_optimizer,
save_param_groups,
save_state_dict,
+ save_state_dict_shards,
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
@@ -122,15 +123,13 @@ def save_sharded_optimizer(
save_param_groups(state_dict, group_file_path)
# Save shards of optimizer states.
- total_size = 0
- for idx, shard_pair in enumerate(sharded_state):
- shard, current_size = shard_pair
- shard_file = get_shard_filename(states_name, idx)
- total_size = total_size + current_size
- for key in shard.keys():
- index_file.append_weight_map(key, shard_file)
- checkpoint_file_path = os.path.join(checkpoint, shard_file)
- save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+ # In general cases, is_master is set to True to get the right behavior.
+ total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=True,
+ use_safetensors=False)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@@ -172,18 +171,17 @@ def save_sharded_model(self,
# shard checkpoint
state_dict = model.state_dict()
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
-
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
- total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
- for idx, shard_pair in enumerate(state_dict_shard):
- shard = shard_pair[0]
- shard_file = get_shard_filename(weights_name, idx)
- total_size = total_size + shard_pair[1]
- for key in shard.keys():
- index_file.append_weight_map(key, shard_file)
- checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
- save_state_dict(shard, checkpoint_file_path, use_safetensors)
+
+    # Save shards of model states.
+ # In general cases, is_master is set to True to get the right behavior.
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=True,
+ use_safetensors=use_safetensors)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 19e28c3f7068..8837776aee4d 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -1,4 +1,5 @@
# coding=utf-8
+import os
import re
from collections import abc as container_abcs
from collections import defaultdict
@@ -103,6 +104,43 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
return unwrapped_optim
+def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+ checkpoint: str,
+ index_file: "CheckpointIndexFile",
+ base_filename: str,
+ is_master: bool,
+ use_safetensors: bool = False) -> int:
+ '''
+    Save a sharded state dict only on the master rank. This method can be used for both model and optimizer states.
+ Args:
+ sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
+ checkpoint (str): The path of checkpoint directory as string.
+ index_file (CheckpointIndexFile): The index file object to be updated.
+ base_filename (str): Decides the prefix of filenames of shards.
+ is_master (bool): Whether current rank is master.
+ use_safetensors (bool): Whether to use safetensors to save checkpoint.
+
+ Returns:
+ int: the total size of shards
+ '''
+
+ total_size = 0
+ for idx, shard_pair in enumerate(sharded_state_dict):
+ if not is_master:
+ continue
+ shard, current_size = shard_pair
+ shard_file = get_shard_filename(base_filename, idx)
+ total_size = total_size + current_size
+ for key in shard.keys():
+ index_file.append_weight_map(key, shard_file)
+ checkpoint_file_path = os.path.join(checkpoint, shard_file)
+
+ # Only save on master rank.
+ save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+
+ return total_size
+
+
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 99aff6f1c527..7d0db6b1fa23 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -3,7 +3,7 @@
import gc
import math
import warnings
-from typing import Any, Dict, Set, Tuple
+from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
import torch
import torch.distributed as dist
@@ -11,8 +11,10 @@
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
+from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
+from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
@@ -360,10 +362,12 @@ def get_offsets(self, param_id: int) -> tuple:
begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
chunk_offset = begin_in_chunk
- shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
+ if chunk.keep_gathered:
+ shard_offset = 0
+ else:
+ shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
shard_size = end_in_chunk - begin_in_chunk
assert chunk_offset >= 0 and shard_offset >= 0
-
return chunk_offset, shard_offset, shard_size
def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
@@ -427,7 +431,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
dtype=torch.float32,
requires_grad=False).cpu()
else:
- collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
+ state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
+ collected_states[state_name] = torch.reshape(state_tensor, param.shape)
return collected_states
# Check whether the param with given id is managed by current process.
@@ -536,6 +541,31 @@ def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_s
target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
next_state_offset += shard_size
+ def get_param_groups_for_saving(self) -> list:
+ '''
+ Return the param_groups in Pytorch format when saving to checkpoint.
+ '''
+
+ param_groups = copy.deepcopy(self.param_groups_backup)
+
+ # To be compatible with pytorch checkpointing,
+ # store extra hyperparameters used by pytorch Adam optimizer.
+ torch_special_hyperparameters = {
+ 'amsgrad': False,
+ 'maximize': False,
+ 'foreach': None,
+ 'capturable': False,
+ 'differentiable': False,
+ 'fused': False
+ }
+
+ for group in param_groups:
+ for k, v in torch_special_hyperparameters.items():
+ if k not in group:
+ group[k] = v
+
+ return param_groups
+
def state_dict(self, only_rank_0: bool = True) -> dict:
"""
Args:
@@ -555,21 +585,7 @@ def state_dict(self, only_rank_0: bool = True) -> dict:
so it should be called only when memory resources are abundant.
"""
state_dict = {}
- state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
-
- torch_special_hyperparameters = {
- 'amsgrad': False,
- 'maximize': False,
- 'foreach': None,
- 'capturable': False,
- 'differentiable': False,
- 'fused': False
- }
-
- for group in state_dict['param_groups']:
- for k, v in torch_special_hyperparameters.items():
- if k not in group:
- group[k] = v
+ state_dict['param_groups'] = self.get_param_groups_for_saving()
# Collect optimizer states.
state_dict['state'] = dict()
@@ -634,8 +650,24 @@ def cast(param, state_range, value, key=None):
del v # clean loaded states
self.optim.state[fake_param].update(updated_states)
+ def load_param_states(self, param_states: dict):
+ """Loads param states from a state_dict. The param_states can be complete or sharded.
+ During loading, filter out the part of states not considered by current process.
+
+ Args:
+ param_states (dict): A mapping from param_id to its states.
+ """
+ for param_id, states in param_states.items():
+ if param_id in self.id_to_fake_params:
+ self.load_single_param_states(param_id, states)
+
+ def optimizer_loading_epilogue(self):
+ # Epilogue when loading state_dict to pytorch optimizer.
+ self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
+ self.optim.defaults.setdefault('differentiable', False)
+
def load_state_dict(self, state_dict: dict):
- """Loads optimizer state from whole optimizer state_dict.
+ """Loads optimizer state from complete optimizer state_dict.
During loading, filter out the part of states not considered by current process.
Args:
@@ -643,17 +675,71 @@ def load_state_dict(self, state_dict: dict):
from a call to :meth:`state_dict`.
"""
assert 'param_groups' in state_dict
+ assert 'state' in state_dict
self.load_param_groups(state_dict['param_groups'])
+ self.load_param_states(state_dict['state'])
+ self.optimizer_loading_epilogue()
- state = state_dict['state']
+ def state_shard(self,
+ prefix: str = '',
+ max_shard_size: int = 1024,
+ only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]:
+ """Returns dictionaries containing shards of optimizer states one by one.
+ The max size of each dictionary shard is specified by ``max_shard_size``.
- for param_id, param_states in state.items():
- if param_id in self.id_to_fake_params:
- self.load_single_param_states(param_id, param_states)
+ Args:
+            prefix (str, optional): the prefix for states. Defaults to ''.
+ max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
+ only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected
+                                          only on rank 0, defaults to True.
- # Epilogue for pytorch optimizer.
- self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
- self.optim.defaults.setdefault('differentiable', False)
+ Yields:
+ Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
+ """
+
+ current_block = {}
+ current_block_size = 0
+
+ for param_id in self.id_to_real_params.keys():
+
+ dist.barrier()
+ state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
+
+ ret_block = None
+ ret_block_size = 0
+
+ # A state might contain more than one tensors.
+ # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
+ state_size = 0
+ isDTensor = False
+ for state_tensor in state.values():
+
+                # When state_tensor is not a Tensor, e.g., an SGD optimizer with
+                # momentum set to 0 can have None as its state, the calculation of
+                # tensor size should be skipped to avoid errors.
+ if not isinstance(state_tensor, torch.Tensor):
+ continue
+
+ # If the states are stored as DTensors, mark isDTensor as true.
+ if is_distributed_tensor(state_tensor):
+ isDTensor = True
+ state_size += calculate_tensor_size(state_tensor)
+
+ if not isDTensor:
+
+ if current_block_size + state_size > max_shard_size and current_block_size > 0:
+ ret_block = current_block
+ ret_block_size = current_block_size
+ current_block = {}
+ current_block_size = 0
+
+ current_block[param_id] = state
+ current_block_size += state_size
+
+            if ret_block is not None:
+ yield ret_block, ret_block_size
+
+ yield current_block, current_block_size
class GeminiAdamOptimizer(ZeroOptimizer):
diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md
index 22d5ee818019..1e75c343c14f 100644
--- a/docs/source/en/basics/booster_api.md
+++ b/docs/source/en/basics/booster_api.md
@@ -21,10 +21,13 @@ Plugin is an important component that manages parallel configuration (eg: The ge
**_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that is, ZeRO with chunk-based memory management.
-**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution, it implements data parallelism at the module level which can run across multiple machines.
+**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallelism at the module level which can run across multiple machines.
**_LowLevelZeroPlugin:_** This plugin wraps the 1/2 stage of Zero Redundancy Optimizer. Stage 1 : Shards optimizer states across data parallel workers/GPUs. Stage 2 : Shards optimizer states + gradients across data parallel workers/GPUs.
+
+**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of Pytorch and can be used to train models with zero-dp.
+
### API of booster
{{ autodoc:colossalai.booster.Booster }}
diff --git a/docs/source/en/basics/booster_checkpoint.md b/docs/source/en/basics/booster_checkpoint.md
index adc0af60b7de..b2840fe87441 100644
--- a/docs/source/en/basics/booster_checkpoint.md
+++ b/docs/source/en/basics/booster_checkpoint.md
@@ -21,8 +21,6 @@ Model must be boosted by `colossalai.booster.Booster` before loading. It will de
## Optimizer Checkpoint
-> ⚠ Saving optimizer checkpoint in a sharded way is not supported yet.
-
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
Optimizer must be boosted by `colossalai.booster.Booster` before saving.
diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md
index 5e2586b836ad..c5c45abce8f7 100644
--- a/docs/source/en/basics/booster_plugins.md
+++ b/docs/source/en/basics/booster_plugins.md
@@ -51,8 +51,6 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme
{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
-> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. This will be fixed in the future.
-
### Torch DDP Plugin
More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
diff --git a/docs/source/zh-Hans/basics/booster_api.md b/docs/source/zh-Hans/basics/booster_api.md
index 1df821ce7d6e..b2235b73bca1 100644
--- a/docs/source/zh-Hans/basics/booster_api.md
+++ b/docs/source/zh-Hans/basics/booster_api.md
@@ -24,10 +24,13 @@ Booster 插件是管理并行配置的重要组件(eg:gemini 插件封装了
**_GeminiPlugin:_** GeminiPlugin 插件封装了 gemini 加速解决方案,即基于块内存管理的 ZeRO 优化方案。
-**_TorchDDPPlugin:_** TorchDDPPlugin 插件封装了 DDP 加速方案,实现了模型级别的数据并行,可以跨多机运行。
+**_TorchDDPPlugin:_** TorchDDPPlugin 插件封装了Pytorch的DDP加速方案,实现了模型级别的数据并行,可以跨多机运行。
**_LowLevelZeroPlugin:_** LowLevelZeroPlugin 插件封装了零冗余优化器的 1/2 阶段。阶段 1:切分优化器参数,分发到各并发进程或并发 GPU 上。阶段 2:切分优化器参数及梯度,分发到各并发进程或并发 GPU 上。
+**_TorchFSDPPlugin:_** TorchFSDPPlugin封装了 Pytorch的FSDP加速方案,可以用于零冗余优化器数据并行(ZeroDP)的训练。
+
+
### Booster 接口
diff --git a/docs/source/zh-Hans/basics/booster_checkpoint.md b/docs/source/zh-Hans/basics/booster_checkpoint.md
index d75f18c908ba..4ed049dcf44f 100644
--- a/docs/source/zh-Hans/basics/booster_checkpoint.md
+++ b/docs/source/zh-Hans/basics/booster_checkpoint.md
@@ -21,7 +21,6 @@
## 优化器 Checkpoint
-> ⚠ 尚不支持以分片方式保存优化器 Checkpoint。
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md
index 5bd88b679000..0f355c43901c 100644
--- a/docs/source/zh-Hans/basics/booster_plugins.md
+++ b/docs/source/zh-Hans/basics/booster_plugins.md
@@ -51,7 +51,6 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累
{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
-> ⚠ 该插件现在只能加载自己保存的且具有相同进程数的优化器 Checkpoint。这将在未来得到解决。
### Torch DDP 插件
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index 0235ff2e2c81..7b664419b405 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -52,7 +52,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
@@ -117,7 +117,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py
index b34e3e3a1310..464fccb39103 100644
--- a/tests/test_checkpoint_io/test_gemini_torch_compability.py
+++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py
@@ -19,7 +19,7 @@
@clear_cache_before_run()
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
def exam_torch_load_from_gemini(shard: bool, model_name: str):
@@ -83,7 +83,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
@clear_cache_before_run()
-@parameterize('shard', [False])
+@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
def exam_gemini_load_from_torch(shard: bool, model_name: str):
@@ -165,7 +165,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
From 02192a632e6c6f965d93ec79937f97e10e121307 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Fri, 21 Jul 2023 18:36:35 +0800
Subject: [PATCH 12/64] [ci] support testmon core pkg change detection (#4305)
---
.github/workflows/build_on_pr.yml | 1 +
1 file changed, 1 insertion(+)
diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index 380c8e9f882c..8a1bc8e113de 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -213,6 +213,7 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt
- name: Store Testmon Cache
run: |
From b366f1d99fd77dac73403b15eb590144e3bc2fdf Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Mon, 17 Jul 2023 18:01:30 +0800
Subject: [PATCH 13/64] [NFC] Fix format for mixed precision (#4253)
* [NFC] polish colossalai/booster/mixed_precision/mixed_precision_base.py code style
---
.../booster/mixed_precision/mixed_precision_base.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/colossalai/booster/mixed_precision/mixed_precision_base.py b/colossalai/booster/mixed_precision/mixed_precision_base.py
index 8caa34e505e1..a86fdfc17eaf 100644
--- a/colossalai/booster/mixed_precision/mixed_precision_base.py
+++ b/colossalai/booster/mixed_precision/mixed_precision_base.py
@@ -13,10 +13,11 @@ class MixedPrecision(ABC):
"""
@abstractmethod
- def configure(self,
- model: nn.Module,
- optimizer: Optional[Optimizer] = None,
- criterion: Optional[Callable] = None,
- ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+ def configure(
+ self,
+ model: nn.Module,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
# TODO: implement this method
pass
From 86cf6aed5b0822cdb539c54888f00553dfd12209 Mon Sep 17 00:00:00 2001
From: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Date: Tue, 18 Jul 2023 10:23:46 +0800
Subject: [PATCH 14/64] Fix/format (#4261)
* revise shardformer readme (#4246)
* [example] add llama pretraining (#4257)
* [NFC] polish colossalai/communication/p2p.py code style
---------
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: binmakeswell
Co-authored-by: Qianran Ma
---
colossalai/communication/p2p.py | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py
index 1f20fca4f74d..d28d140168fd 100644
--- a/colossalai/communication/p2p.py
+++ b/colossalai/communication/p2p.py
@@ -1,16 +1,18 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import operator
+from functools import reduce
from typing import List, Tuple, Union
+
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
-from functools import reduce
-import operator
-from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
+
+from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
TensorShape = Union[torch.Size, List[int], Tuple[int]]
@@ -260,7 +262,7 @@ def send_forward_recv_backward(output_tensor,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
- """Batched communication operation. Sends the input tensor to the
+ """Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the gradient tensor from the
next stage in pipeline as the input gradient tensor of this stage.
@@ -319,7 +321,7 @@ def send_forward_recv_forward(output_tensor,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
- """Batched communication operation. Sends the input tensor to the
+ """Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage.
From 915ed8bed10fef441713ccb7b2173a4d74fa9898 Mon Sep 17 00:00:00 2001
From: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Date: Tue, 18 Jul 2023 10:41:55 +0800
Subject: [PATCH 15/64] [NFC] polish
applications/Chat/inference/requirements.txt code style (#4265)
---
applications/Chat/inference/requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/applications/Chat/inference/requirements.txt b/applications/Chat/inference/requirements.txt
index 511fe1a4f1f3..cb6275361736 100644
--- a/applications/Chat/inference/requirements.txt
+++ b/applications/Chat/inference/requirements.txt
@@ -10,4 +10,4 @@ uvicorn
git+https://github.com/huggingface/transformers
accelerate
bitsandbytes
-jieba
\ No newline at end of file
+jieba
From 77c469e1ba948cdc6c4d6dd32ec151d653255ad9 Mon Sep 17 00:00:00 2001
From: Junming Wu
Date: Tue, 18 Jul 2023 10:43:52 +0800
Subject: [PATCH 16/64] [NFC] polish
applications/Chat/coati/models/base/actor.py code style (#4248)
---
applications/Chat/coati/models/base/actor.py | 17 +++++++----------
1 file changed, 7 insertions(+), 10 deletions(-)
diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py
index 2034d5cc81d4..6842f81d9b87 100644
--- a/applications/Chat/coati/models/base/actor.py
+++ b/applications/Chat/coati/models/base/actor.py
@@ -21,16 +21,13 @@ def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str =
self.model = model
self.convert_to_lora()
- def forward(self,
- input_ids: torch.LongTensor,
- attention_mask: Optional[torch.Tensor] = None,
- **model_kwargs, # HACK: `generate` method may pass more kwargs
- ) -> torch.Tensor:
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **model_kwargs, # HACK: `generate` method may pass more kwargs
+ ) -> torch.Tensor:
"""Returns model output.
"""
- output = self.model(
- input_ids,
- attention_mask=attention_mask,
- **model_kwargs
- )
+ output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
return output
From dee1c96344f7de1dd0bd924698d3c0866eb7b02e Mon Sep 17 00:00:00 2001
From: CZYCW
Date: Tue, 18 Jul 2023 10:53:08 +0800
Subject: [PATCH 17/64] [NFC] policy
applications/Chat/examples/ray/mmmt_prompt.py code style (#4250)
---
applications/Chat/examples/ray/mmmt_prompt.py | 14 ++++++--------
1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py
index 60f049bd5b70..76929c9d0144 100644
--- a/applications/Chat/examples/ray/mmmt_prompt.py
+++ b/applications/Chat/examples/ray/mmmt_prompt.py
@@ -87,8 +87,8 @@ def model_fn():
kl_coef=0.1,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
- # sync_models_from_trainers=True,
- # generation kwargs:
+ # sync_models_from_trainers=True,
+ # generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
@@ -161,12 +161,10 @@ def tokenize_fn(texts):
parser.add_argument('--prompt_path', type=str, default=None)
parser.add_argument('--num_makers', type=int, default=1)
parser.add_argument('--num_trainers', type=int, default=1)
- parser.add_argument('--trainer_strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
- 'colossalai_zero2_cpu'
- ],
- default='ddp')
+ parser.add_argument(
+ '--trainer_strategy',
+ choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'],
+ default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
From 85774f0c1fd1425149a61a463f726b25d49ec420 Mon Sep 17 00:00:00 2001
From: ocd_with_naming <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Tue, 18 Jul 2023 10:54:27 +0800
Subject: [PATCH 18/64] [NFC] polish colossalai/cli/benchmark/utils.py code
style (#4254)
---
colossalai/cli/benchmark/utils.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py
index 825b795f21f6..ee7d92d6ea6a 100644
--- a/colossalai/cli/benchmark/utils.py
+++ b/colossalai/cli/benchmark/utils.py
@@ -1,10 +1,11 @@
import math
import time
+from typing import Callable, Dict, List, Tuple
+
import torch
+from colossalai.context import Config, ParallelMode
from colossalai.utils import MultiTimer
-from colossalai.context import ParallelMode, Config
-from typing import List, Dict, Tuple, Callable
def get_time_stamp() -> int:
@@ -25,8 +26,8 @@ def get_memory_states() -> Tuple[float]:
Return the memory statistics.
Returns:
- max_allocated (float): the allocated CUDA memory
- max_cached (float): the cached CUDA memory
+ max_allocated (float): the allocated CUDA memory
+ max_cached (float): the cached CUDA memory
"""
max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
@@ -101,7 +102,7 @@ def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int,
profile_steps (int): the number of steps for profiling
data_func (Callable): a function to generate random data
timer (colossalai.utils.Multitimer): a timer instance for time recording
-
+
Returns:
fwd_time (float): the average forward time taken by forward pass in second
bwd_time (float): the average backward time taken by forward pass in second
From c614a99d286087b5768cc6422b7317dcff02db3e Mon Sep 17 00:00:00 2001
From: Yanjia0 <42895286+Yanjia0@users.noreply.github.com>
Date: Tue, 18 Jul 2023 10:54:55 +0800
Subject: [PATCH 19/64] [NFC] polish
colossalai/auto_parallel/offload/amp_optimizer.py code style (#4255)
---
colossalai/auto_parallel/offload/amp_optimizer.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
index a79e5006e7d2..19d85b80dd3d 100644
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -1,24 +1,25 @@
-from typing import Dict, Tuple
from enum import Enum
+from typing import Dict, Tuple
+
import torch
from torch.optim import Optimizer
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
-from .region_manager import RegionManager
from .region import Region
+from .region_manager import RegionManager
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
-class AMPOptimizer(ColossalaiOptimizer):
+class AMPOptimizer(ColossalaiOptimizer):
"""
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
@@ -174,4 +175,4 @@ def __init__optimizer(self):
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
- self.optim.load_state_dict(self.optim.state_dict())
\ No newline at end of file
+ self.optim.load_state_dict(self.optim.state_dict())
From abe4f971e0e316e8558569bf30faca77772367b6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=A2=81=E7=88=BD?=
<100194095+supercooledith@users.noreply.github.com>
Date: Tue, 18 Jul 2023 10:58:43 +0800
Subject: [PATCH 20/64] [NFC] polish
colossalai/booster/plugin/low_level_zero_plugin.py code style (#4256)
Co-authored-by: supercooledith <893754954@qq.com>
---
colossalai/booster/plugin/low_level_zero_plugin.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 94d722080367..3ec0d34092a4 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -208,10 +208,7 @@ def configure(
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
- optimizer = LowLevelZeroOptimizer(model.unwrap(),
- optimizer,
- self.zero_optim_config,
- self.optim_kwargs,
+ optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler
From b2debdc09bc9dd7d87a409e7a4485d1eca74cb61 Mon Sep 17 00:00:00 2001
From: "Zheng Zangwei (Alex Zheng)"
Date: Tue, 18 Jul 2023 10:59:38 +0800
Subject: [PATCH 21/64] [NFC] polish
applications/Chat/coati/dataset/sft_dataset.py code style (#4259)
---
.../Chat/coati/dataset/sft_dataset.py | 20 +++++++++----------
1 file changed, 9 insertions(+), 11 deletions(-)
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 3702d00cc609..3038fbe071db 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -74,15 +74,10 @@ def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
-def _tokenize_fn(strings: Sequence[str],
- tokenizer: transformers.PreTrainedTokenizer,
- max_length: int
- ) -> Dict[str, torch.Tensor]:
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
+ max_length: int) -> Dict[str, torch.Tensor]:
"""Tokenize a list of strings."""
- tokenized_list = tokenizer(
- strings, return_tensors="pt", padding="longest",
- max_length=max_length, truncation=True
- )
+ tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
input_ids = labels = tokenized_list["input_ids"]
input_ids_lens = labels_lens = \
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
@@ -103,8 +98,7 @@ def preprocess(
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [
- _tokenize_fn(strings, tokenizer, max_length)
- for strings in (examples, sources)
+ _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
@@ -116,7 +110,11 @@ def preprocess(
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
- def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512):
+ def __init__(self,
+ data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ max_datasets_size: int = None,
+ max_length: int = 512):
super(SupervisedDataset, self).__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
From 798cb72907f5425dccc94c83fd1a900a6b8f67eb Mon Sep 17 00:00:00 2001
From: shenggan
Date: Tue, 18 Jul 2023 10:59:57 +0800
Subject: [PATCH 22/64] [NFC] polish applications/Chat/coati/trainer/base.py
code style (#4260)
---
applications/Chat/coati/trainer/base.py | 53 ++++++++++---------------
1 file changed, 22 insertions(+), 31 deletions(-)
diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py
index 13571cdcc23a..b4d168a563d9 100644
--- a/applications/Chat/coati/trainer/base.py
+++ b/applications/Chat/coati/trainer/base.py
@@ -25,12 +25,13 @@ class SLTrainer(ABC):
optim (Optimizer): the optimizer to use for training
"""
- def __init__(self,
- strategy: Strategy,
- max_epochs: int,
- model: nn.Module,
- optimizer: Optimizer,
- ) -> None:
+ def __init__(
+ self,
+ strategy: Strategy,
+ max_epochs: int,
+ model: nn.Module,
+ optimizer: Optimizer,
+ ) -> None:
super().__init__()
self.strategy = strategy
self.max_epochs = max_epochs
@@ -50,10 +51,7 @@ def _before_fit(self):
def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs)
- for epoch in tqdm.trange(self.max_epochs,
- desc="Epochs",
- disable=not is_rank_0() or self.no_epoch_bar
- ):
+ for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
self._train(epoch)
self._eval(epoch)
@@ -75,8 +73,7 @@ def __init__(self,
buffer: NaiveReplayBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
- callbacks: List[Callback] = []
- ) -> None:
+ callbacks: List[Callback] = []) -> None:
super().__init__()
self.strategy = strategy
self.buffer = buffer
@@ -138,7 +135,7 @@ def _make_experience(self, collect_step: int):
@abstractmethod
def _learn(self, update_step: int):
"""
- Implement this method to learn from experience, either
+ Implement this method to learn from experience, either
sample from buffer or transform buffer into dataloader.
"""
raise NotImplementedError()
@@ -154,13 +151,14 @@ def _update_phase(self, update_step: int):
self._learn(update_step)
self._on_learn_epoch_end(update_step)
- def fit(self,
- prompt_dataloader: DataLoader,
- pretrain_dataloader: DataLoader,
- num_episodes: int,
- num_collect_steps: int,
- num_update_steps: int,
- ):
+ def fit(
+ self,
+ prompt_dataloader: DataLoader,
+ pretrain_dataloader: DataLoader,
+ num_episodes: int,
+ num_collect_steps: int,
+ num_update_steps: int,
+ ):
"""
The main training loop of on-policy rl trainers.
@@ -175,23 +173,16 @@ def fit(self,
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
with self._fit_ctx():
- for episode in tqdm.trange(num_episodes,
- desc="Episodes",
- disable=not is_rank_0()):
+ for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
with self._episode_ctx(episode):
- for collect_step in tqdm.trange(num_collect_steps,
- desc="Collect steps",
- disable=not is_rank_0()):
+ for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
self._collect_phase(collect_step)
if not self.sample_buffer:
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
# I only call strategy.setup_dataloader() to setup dataloader.
- self.dataloader = self.strategy.setup_dataloader(self.buffer,
- self.dataloader_pin_memory)
- for update_step in tqdm.trange(num_update_steps,
- desc="Update steps",
- disable=not is_rank_0()):
+ self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
+ for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
self.buffer.clear()
From 3883db452c533127c1bc28cc12b9533d206d50cf Mon Sep 17 00:00:00 2001
From: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Date: Tue, 18 Jul 2023 11:53:47 +0800
Subject: [PATCH 23/64] [NFC] polish unary_elementwise_generator.py code style
(#4267)
Co-authored-by: aye42
---
.../node_handler/strategy/unary_elementwise_generator.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
index b867a30686eb..39799a67c5a0 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
@@ -1,7 +1,7 @@
import copy
from typing import List
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import FollowingStrategyGenerator
From fee553288b7b4f3116a8cc4d1a5052183302c709 Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Tue, 18 Jul 2023 11:54:09 +0800
Subject: [PATCH 24/64] [NFC] polish runtime_preparation_pass style (#4266)
---
colossalai/auto_parallel/passes/runtime_preparation_pass.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 9a2314826448..1a6dc7815176 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -55,7 +55,7 @@ def size_processing(size: Union[int, torch.Size],
def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
- strategies_constructor: StrategiesConstructor):
+ strategies_constructor: StrategiesConstructor):
"""
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
From a50d39a143c8ae83e1ae4960d2026bb53199d3ad Mon Sep 17 00:00:00 2001
From: dayellow <49357110+dayellow@users.noreply.github.com>
Date: Tue, 18 Jul 2023 13:51:37 +0800
Subject: [PATCH 25/64] [NFC] fix: format (#4270)
* [NFC] polish colossalai/fx/profiler/experimental/profiler_module/embedding.py code style
* [NFC] polish colossalai/communication/utils.py code style
---------
Co-authored-by: Minghao Huang
---
colossalai/communication/utils.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py
index ef9eceea847d..1516df356278 100644
--- a/colossalai/communication/utils.py
+++ b/colossalai/communication/utils.py
@@ -1,10 +1,11 @@
+from typing import List, Tuple, Union
+
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
-from typing import Union, List, Tuple
TensorShape = Union[torch.Size, List[int], Tuple[int]]
From 1ce997daaf7ffdf6c93a7bbd179d7be18928ba41 Mon Sep 17 00:00:00 2001
From: Xu Kai
Date: Tue, 18 Jul 2023 18:01:52 +0800
Subject: [PATCH 26/64] [NFC] polish
applications/Chat/examples/train_reward_model.py code style (#4271)
---
applications/Chat/examples/train_reward_model.py | 8 ++------
1 file changed, 2 insertions(+), 6 deletions(-)
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index 5b1b8d3d16b2..fb9802e38542 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -150,9 +150,7 @@ def train(args):
pin_memory=True)
lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
- strategy_dict = strategy.prepare(
- dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
- )
+ strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
model = strategy_dict['model']
optim = strategy_dict['optimizer']
lr_scheduler = strategy_dict['lr_scheduler']
@@ -163,9 +161,7 @@ def train(args):
loss_fn=loss_fn,
max_epochs=args.max_epochs)
- trainer.fit(train_dataloader=train_dataloader,
- valid_dataloader=valid_dataloader,
- eval_dataloader=eval_dataloader)
+ trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
# save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
From caa4433072aab219418205bf4d338fd1aa355d42 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?=
Date: Tue, 18 Jul 2023 18:02:35 +0800
Subject: [PATCH 27/64] [NFC] fix format of
application/Chat/coati/trainer/utils.py (#4273)
---
applications/Chat/coati/trainer/utils.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py
index c9fc8d0fe19f..4d45061bab09 100644
--- a/applications/Chat/coati/trainer/utils.py
+++ b/applications/Chat/coati/trainer/utils.py
@@ -14,9 +14,10 @@ class CycledDataLoader:
NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior.
"""
- def __init__(self,
- dataloader: DataLoader,
- ) -> None:
+ def __init__(
+ self,
+ dataloader: DataLoader,
+ ) -> None:
self.dataloader = dataloader
self.count = 0
From dc1b6127f9554fda0051cb5f06e6532599896592 Mon Sep 17 00:00:00 2001
From: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Date: Tue, 18 Jul 2023 18:03:08 +0800
Subject: [PATCH 28/64] [NFC] polish applications/Chat/inference/server.py code
style (#4274)
Co-authored-by: Yuanchen Xu
---
applications/Chat/inference/server.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py
index b4627299397e..e23f0fceb2fa 100644
--- a/applications/Chat/inference/server.py
+++ b/applications/Chat/inference/server.py
@@ -14,7 +14,7 @@
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
-from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json
+from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
MAX_LEN = 512
@@ -145,7 +145,9 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
- parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.')
+ parser.add_argument('--profanity_file',
+ default=None,
+ help='Path to profanity words list. It should be a JSON file containing a list of words.')
args = parser.parse_args()
if args.quant == '4bit':
From 709e121cd5a98e01177e47dd4fe8a0833ff0af8a Mon Sep 17 00:00:00 2001
From: RichardoLuo <50363844+RichardoLuo@users.noreply.github.com>
Date: Tue, 18 Jul 2023 18:04:02 +0800
Subject: [PATCH 29/64] [NFC] polish
applications/Chat/coati/models/generation.py code style (#4275)
---
applications/Chat/coati/models/generation.py | 13 ++++++-------
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py
index 0156e2284e52..d96ad78a89ce 100644
--- a/applications/Chat/coati/models/generation.py
+++ b/applications/Chat/coati/models/generation.py
@@ -5,7 +5,6 @@
import torch.nn as nn
import torch.nn.functional as F
-
try:
from transformers.generation_logits_process import (
LogitsProcessorList,
@@ -148,12 +147,12 @@ def generate(model: nn.Module,
@torch.no_grad()
-def generate_with_actor(actor_model: nn.Module,
- input_ids: torch.Tensor,
- return_action_mask: bool = True,
- **kwargs
- ) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
- Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
+def generate_with_actor(
+ actor_model: nn.Module,
+ input_ids: torch.Tensor,
+ return_action_mask: bool = True,
+ **kwargs
+) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details.
"""
# generate sequences
From c972d653111dcfbd63cd22b26ddb7d3ee83a69ed Mon Sep 17 00:00:00 2001
From: Ziheng Qin <37519855+henryqin1997@users.noreply.github.com>
Date: Wed, 19 Jul 2023 09:38:49 +0800
Subject: [PATCH 30/64] applications/Chat/.gitignore (#4279)
Co-authored-by: henryqin1997
---
applications/Chat/.gitignore | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/applications/Chat/.gitignore b/applications/Chat/.gitignore
index 2b9b4f345d0f..5fa068105e26 100644
--- a/applications/Chat/.gitignore
+++ b/applications/Chat/.gitignore
@@ -145,4 +145,4 @@ docs/.build
# wandb log
example/wandb/
-examples/awesome-chatgpt-prompts/
\ No newline at end of file
+examples/awesome-chatgpt-prompts/
From 9e512938f6b0b79c2d61c12d4fdc3b4a0008362e Mon Sep 17 00:00:00 2001
From: Zirui Zhu
Date: Wed, 19 Jul 2023 22:18:08 +0800
Subject: [PATCH 31/64] [NFC] polish
applications/Chat/coati/trainer/strategies/base.py code style (#4278)
---
.../Chat/coati/trainer/strategies/base.py | 22 ++++---------------
1 file changed, 4 insertions(+), 18 deletions(-)
diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py
index 80bc3272872e..3d1dfaf784cf 100644
--- a/applications/Chat/coati/trainer/strategies/base.py
+++ b/applications/Chat/coati/trainer/strategies/base.py
@@ -79,8 +79,7 @@ def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _Boo
model, optimizer = arg
except ValueError:
raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
- model, optimizer, *_ = self.booster.boost(model=model,
- optimizer=optimizer)
+ model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
rets.append((model, optimizer))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
@@ -90,10 +89,7 @@ def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _Boo
dataloader=dataloader,
lr_scheduler=lr_scheduler)
# remove None values
- boost_result = {
- key: value
- for key, value in boost_result.items() if value is not None
- }
+ boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result)
else:
raise RuntimeError(f'Type {type(arg)} is not supported')
@@ -112,23 +108,13 @@ def unwrap_model(model: nn.Module) -> nn.Module:
"""
return model
- def save_model(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- **kwargs
- ) -> None:
+ def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
self.booster.load_model(model, path, strict)
- def save_optimizer(self,
- optimizer: Optimizer,
- path: str,
- only_rank0: bool = False,
- **kwargs
- ) -> None:
+ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)
def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
From 09914053619fe78232ec39931c39947b35717f9e Mon Sep 17 00:00:00 2001
From: yuxuan-lou <83441848+yuxuan-lou@users.noreply.github.com>
Date: Wed, 19 Jul 2023 22:18:30 +0800
Subject: [PATCH 32/64] [NFC] polish applications/Chat/coati/models/utils.py
codestyle (#4277)
* [NFC] polish colossalai/context/random/__init__.py code style
* [NFC] polish applications/Chat/coati/models/utils.py code style
---
applications/Chat/coati/models/utils.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py
index b9f15f894a1f..772bfc32982a 100644
--- a/applications/Chat/coati/models/utils.py
+++ b/applications/Chat/coati/models/utils.py
@@ -46,10 +46,7 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
return log_probs_labels.squeeze(-1)
-def calc_action_log_probs(output: torch.Tensor,
- sequences: torch.LongTensor,
- num_actions: int
- ) -> torch.Tensor:
+def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
"""Calculate action log probs.
Args:
From ef4b99ebcda823c32e221383ac9365a82697cd5d Mon Sep 17 00:00:00 2001
From: binmakeswell
Date: Sat, 22 Jul 2023 15:08:37 +0800
Subject: [PATCH 33/64] add llama example CI
---
examples/language/llama/test_ci.sh | 0
1 file changed, 0 insertions(+), 0 deletions(-)
create mode 100755 examples/language/llama/test_ci.sh
diff --git a/examples/language/llama/test_ci.sh b/examples/language/llama/test_ci.sh
new file mode 100755
index 000000000000..e69de29bb2d1
From 5187c96b7c04ac6c794a58044533c874fa24e206 Mon Sep 17 00:00:00 2001
From: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Date: Fri, 28 Jul 2023 11:29:55 +0800
Subject: [PATCH 34/64] support session-based training (#4313)
Co-authored-by: Yuanchen Xu
---
.../Chat/coati/dataset/conversation.py | 87 +++++++++++++++++
.../Chat/coati/dataset/sft_dataset.py | 97 +++++++++++++++++--
applications/Chat/examples/README.md | 44 +++++++++
.../examples/generate_conversation_dataset.py | 79 +++++++++++++++
applications/Chat/examples/train_sft.py | 8 +-
5 files changed, 300 insertions(+), 15 deletions(-)
create mode 100644 applications/Chat/coati/dataset/conversation.py
create mode 100644 applications/Chat/examples/generate_conversation_dataset.py
diff --git a/applications/Chat/coati/dataset/conversation.py b/applications/Chat/coati/dataset/conversation.py
new file mode 100644
index 000000000000..465fa867c7ab
--- /dev/null
+++ b/applications/Chat/coati/dataset/conversation.py
@@ -0,0 +1,87 @@
+# Copyright 2023 lm-sys@FastChat
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from enum import Enum, auto
+from typing import List
+
+
+class SeparatorStyle(Enum):
+ ADD_EOS_TOKEN = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
+ sep: str = ""
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN:
+ ret = self.system
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ": "
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep)
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep
+ }
+
+
+conv = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ roles=("Human", "Assistant"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.ADD_EOS_TOKEN,
+ sep="",
+)
+
+default_conversation = conv
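A short usage sketch for the `Conversation` helper added above; the import path assumes the `coati` package under `applications/Chat` is installed, so adjust it to your layout.

```python
from coati.dataset.conversation import default_conversation

conv = default_conversation.copy()               # copy() turns the message tuple into a mutable list
conv.append_message(conv.roles[0], "Give three tips for staying healthy.")
conv.append_message(conv.roles[1], "Eat well, exercise regularly, and sleep enough.")
conv.append_message(conv.roles[0], "Which tip matters most?")
conv.append_message(conv.roles[1], None)         # empty slot: the assistant continues from here

# get_prompt() renders "<role>: <message><sep>" turns after the system text;
# with sep="" and a trailing None message it ends with an open "Assistant: " prefix.
print(conv.get_prompt())
```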
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 3038fbe071db..0b04cf79ee54 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -15,7 +15,7 @@
import copy
import random
from dataclasses import dataclass, field
-from typing import Callable, Dict, Sequence
+from typing import Callable, Dict, List, Sequence, Tuple
import torch
import torch.distributed as dist
@@ -25,11 +25,21 @@
from colossalai.logging import get_dist_logger
+from .conversation import default_conversation
from .utils import is_rank_0, jload
+# The following is a template prompt for a 4-round conversation.
+"""
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+Human: xxxAssistant: xxxHuman: xxxAssistant: xxxHuman: xxxAssistant: xxxHuman: xxxAssistant: xxx
+"""
+# Please note that we only calculate loss on the assistant's answer tokens.
+
logger = get_dist_logger()
IGNORE_INDEX = -100
+DEFAULT_EOS_TOKEN = ""
PROMPT_DICT = {
"prompt_input":
("Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -107,6 +117,61 @@ def preprocess(
return dict(input_ids=input_ids, labels=labels)
+def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
+ max_length: int) -> Dict:
+ """Preprocess the conversation data by tokenizing."""
+ conversations = []
+ intermediates = []
+ for source in sources:
+ header = f"{default_conversation.system}"
+ conversation, intermediate = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+ intermediates.append(intermediate)
+
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
+ input_ids = conversations_tokenized["input_ids"]
+ targets = copy.deepcopy(input_ids)
+
+ assert len(targets) == len(intermediates)
+ for target, inters in zip(targets, intermediates):
+ mask = torch.zeros_like(target, dtype=torch.bool)
+ for inter in inters:
+ tokenized = _tokenize_fn(inter, tokenizer, max_length)
+
+ start_idx = tokenized["input_ids"][0].size(0) - 1
+ end_idx = tokenized["input_ids"][1].size(0)
+
+ mask[start_idx:end_idx] = True
+ target[~mask] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def _add_speaker_and_signal(header: str,
+ source: List[Dict],
+ get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
+ END_SIGNAL = DEFAULT_EOS_TOKEN
+ conversation = header
+ intermediate = []
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = default_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+
+ value = from_str + ": " + sentence["value"] + END_SIGNAL
+ if sentence["from"].lower() == "gpt":
+ start = conversation + from_str + ": "
+ end = conversation + value
+ intermediate.append([start, end])
+ if get_conversation:
+ conversation += value
+ return conversation, intermediate
+
+
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
@@ -125,15 +190,27 @@ def __init__(self,
list_data_dict = list_data_dict[:max_datasets_size]
logger.info("Formatting inputs...")
- prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
- sources = [
- prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
- for example in list_data_dict
- ]
- targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
-
- logger.info("Tokenizing inputs... This may take some time...")
- data_dict = preprocess(sources, targets, tokenizer, max_length)
+ if "conversations" not in list_data_dict[0]:
+ prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+ sources = [
+ prompt_input.format_map(example)
+ if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
+ ]
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+ if is_rank_0():
+ logger.info("Tokenizing inputs... This may take some time...")
+
+ data_dict = preprocess(sources, targets, tokenizer, max_length)
+ else:
+ if is_rank_0():
+ logger.info("Tokenizing inputs... This may take some time...")
+
+ sources = [conv["conversations"] for conv in list_data_dict]
+ data_dict = preprocess_conversation(sources, tokenizer, max_length)
+
+ if is_rank_0():
+ logger.info("Tokenizing finish.")
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
diff --git a/applications/Chat/examples/README.md b/applications/Chat/examples/README.md
index 56e4cc992c17..f0cdfeff5b61 100644
--- a/applications/Chat/examples/README.md
+++ b/applications/Chat/examples/README.md
@@ -6,6 +6,7 @@
- [Table of Contents](#table-of-contents)
- [Install requirements](#install-requirements)
- [Supervised datasets collection](#supervised-datasets-collection)
+ - [Conversation dataset generation](#conversation-dataset-generation)
- [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
- [Arg List](#arg-list)
- [Stage2 - Training reward model](#stage2---training-reward-model)
@@ -45,6 +46,49 @@ The following pic shows how we collected the data.
+### Conversation dataset generation
+
+To further improve the model's ability to handle multi-turn conversations, the dataset needs to include multi-turn samples. However, the InstructWild and Alpaca datasets currently contain only single-turn conversations, and their organization is not suitable for storing multi-turn conversations. Besides converting these datasets, we also need to include multi-turn conversation datasets such as ShareGPT and transform them into the training format supported by ColossalChat.
+
+A sample of a conversation dataset should have the following fields:
+
+* `type` (str, optional): The type of the data sample.
+* `language` (str, optional): The language of the data sample.
+* `dataset` (str, optional): The dataset the data sample originates from.
+* `conversations` (list, compulsory): Conversation content of the data sample.
+* `id` (int, optional): The ID of the data sample.
+
+A simple example:
+```json
+{
+ "type": "instruction",
+ "language": "English",
+ "dataset": "Alpaca",
+ "conversations": [
+ {
+ "from": "human",
+ "value": "Give three tips for staying healthy."
+ },
+ {
+ "from": "gpt",
+ "value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
+ }
+ ],
+ "id": 1
+}
+```
+
+> **NOTE:** Only the key `conversations` is compulsory for training; the other keys serve as metadata. The length of `conversations` varies.
+
+You can run `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat.
+
+You can use the following command to generate a conversation dataset.
+```bash
+python generate_conversation_dataset.py \
+ --dataset "All" \
+ --save_path "/path/to/dataset"
+```
+
## Stage1 - Supervised instructs tuning
Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.
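As a companion to the format documented in the README hunk above, here is a hedged sketch (not part of the patch's own scripts) that assembles and sanity-checks one sample before dumping it to the JSON file ColossalChat consumes; all field values other than `conversations` are illustrative metadata.

```python
import json


def make_sample(question: str, answer: str, idx: int) -> dict:
    return {
        "type": "instruction",      # optional metadata
        "language": "English",      # optional metadata
        "dataset": "Alpaca",        # optional metadata
        "conversations": [          # the only compulsory key
            {"from": "human", "value": question},
            {"from": "gpt", "value": answer},
        ],
        "id": idx,                  # optional metadata
    }


def validate(sample: dict) -> None:
    assert "conversations" in sample, "`conversations` is compulsory"
    for turn in sample["conversations"]:
        assert turn["from"] in ("human", "gpt")
        assert isinstance(turn["value"], str)


samples = [make_sample("Give three tips for staying healthy.",
                       "Eat well, exercise regularly, and get enough sleep.", 1)]
for s in samples:
    validate(s)

with open("dataset.json", "w") as f:
    json.dump(samples, f, indent=4, ensure_ascii=False)
```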
diff --git a/applications/Chat/examples/generate_conversation_dataset.py b/applications/Chat/examples/generate_conversation_dataset.py
new file mode 100644
index 000000000000..8d2fbba955b8
--- /dev/null
+++ b/applications/Chat/examples/generate_conversation_dataset.py
@@ -0,0 +1,79 @@
+import argparse
+import json
+
+from datasets import load_dataset
+
+
+def generate_alpaca():
+ # We can convert dataset with the same format("instruction", "input", "output") as Alpaca into a one-round conversation.
+ conversation_dataset = []
+ dataset = load_dataset("tatsu-lab/alpaca", split="train")
+
+ instructions = dataset["instruction"]
+ inputs = dataset["input"]
+ outputs = dataset["output"]
+
+ assert len(instructions) == len(inputs) == len(outputs)
+
+ for idx in range(len(instructions)):
+ human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
+ human = {"from": "human", "value": human_utterance}
+
+ gpt_utterance = outputs[idx]
+ gpt = {"from": "gpt", "value": gpt_utterance}
+
+ conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
+ conversation_dataset.append(conversation)
+
+ return conversation_dataset
+
+
+def generate_sharegpt():
+ # ShareGPT data requires less processing.
+ conversation_dataset = []
+ dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered",
+ data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
+ split="train")
+
+ conversations = dataset["conversations"]
+
+ for idx in range(len(conversations)):
+ for conv in conversations[idx]:
+ # We don't need markdown and text value.
+ del conv["markdown"]
+ del conv["text"]
+
+ conversation = dict(type="conversation",
+ language="Multilingual",
+ dataset="ShareGPT",
+ conversations=conversations[idx])
+ conversation_dataset.append(conversation)
+
+ return conversation_dataset
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset',
+ type=str,
+ default="All",
+ choices=["Alpaca", "ShareGPT", "All"],
+ help="which dataset to convert, All will combine Alpaca and ShareGPT")
+ parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset")
+ args = parser.parse_args()
+
+ conversation_dataset = []
+
+ if args.dataset == "Alpaca":
+ conversation_dataset.extend(generate_alpaca())
+ elif args.dataset == "ShareGPT":
+ conversation_dataset.extend(generate_sharegpt())
+ else:
+ conversation_dataset.extend(generate_alpaca())
+ conversation_dataset.extend(generate_sharegpt())
+
+ for idx, sample in enumerate(conversation_dataset):
+ sample["id"] = idx + 1
+
+ with open(args.save_path, mode='w') as f:
+ json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index cb3eb649d76c..4676d47dd331 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -74,8 +74,8 @@ def train(args):
padding_side="right",
use_fast=False,
)
- tokenizer.eos_token = '<\s>'
- tokenizer.pad_token = tokenizer.unk_token
+ tokenizer.eos_token = ''
+ tokenizer.pad_token = tokenizer.eos_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -153,9 +153,7 @@ def train(args):
optim,
num_warmup_steps=math.ceil(max_steps * 0.03),
num_training_steps=max_steps)
- strategy_dict = strategy.prepare(
- dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
- )
+ strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
model = strategy_dict['model']
optim = strategy_dict['optimizer']
lr_scheduler = strategy_dict['lr_scheduler']
From c6ab96983ab522e1457bdbba94071b7684caca3e Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Fri, 30 Jun 2023 15:30:50 +0800
Subject: [PATCH 35/64] [zero] refactor low level zero for shard evenly (#4030)
* refactor low level zero
* fix zero2 and support cpu offload
* avg gradient and modify unit test
* refactor grad store, support layer drop
* refactor bucket store, support grad accumulation
* fix and update unit test of zero and ddp
* compatible with tp, ga and unit test
* fix memory leak and polish
* add zero layer drop unittest
* polish code
* fix import err in unit test
* support diffenert comm dtype, modify docstring style
* polish code
* test padding and fix
* fix unit test of low level zero
* fix pad recording in bucket store
* support some models
* polish
---
colossalai/zero/low_level/_utils.py | 2 +-
.../low_level/bookkeeping/bucket_store.py | 122 ++++-
.../low_level/bookkeeping/gradient_store.py | 118 ++---
.../low_level/bookkeeping/parameter_store.py | 99 ++--
colossalai/zero/low_level/low_level_optim.py | 438 +++++++-----------
.../test_plugin/test_low_level_zero_plugin.py | 11 +-
.../test_zero/test_low_level/test_grad_acc.py | 48 +-
.../test_zero/test_low_level/test_zero1_2.py | 52 ++-
8 files changed, 422 insertions(+), 468 deletions(-)
diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py
index 218f7603bc54..a9e552ebdabc 100644
--- a/colossalai/zero/low_level/_utils.py
+++ b/colossalai/zero/low_level/_utils.py
@@ -253,7 +253,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
return total_norm
-def sync_param(flat_tensor, tensor_list):
+def sync_tensor(flat_tensor, tensor_list):
"""
Synchronize the flattened tensor and unflattened tensor list. When
a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`,
diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index ec322a78bf81..98f1b78d0049 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -1,3 +1,8 @@
+from typing import Dict
+
+import torch
+from torch import Tensor
+from torch._utils import _flatten_dense_tensors
from torch.distributed import ProcessGroup
from .base_store import BaseStore
@@ -7,35 +12,102 @@ class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
- self._params = dict()
- self._num_elements_in_bucket = dict()
+
+ # init and reset
+ self.current_group_id = 0
+ # mapping between gradient slices and parameters
+ self.grad_to_param_mapping = dict()
+
+ self._param_list = []
+ self._padding_size = []
self.reset()
- def num_elements_in_bucket(self, reduce_rank: int = None):
- return self._num_elements_in_bucket[reduce_rank]
+ def num_elements_in_bucket(self) -> int:
+ """Return the total number of elements in bucket
+
+ Returns:
+ int: the total number of elements in bucket
+ """
+
+ return self._num_elements_in_bucket
+
+ def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
+ """Add a param to bucket and record the padding size of a param for gradient padding
+
+ Args:
+ group_id (int): The index of a parameter group
+ param (Tensor): The parameter
+ padding_size (int): The padding size of the parameter
+ """
+
+ self._param_list.append(param)
+ self._padding_size.append(padding_size)
+ self._num_elements_in_bucket += (param.numel() + padding_size)
+ self.current_group_id = group_id
+
+ def build_grad_in_bucket(self):
+ """Orgnize parameters' gradient(padding and split), follows the paramters' splitting method
+
+ Data structure of self._grad_in_bucket:
+ {
+ rank0: [grad0_rank0, grad1_rank0, ...]
+ rank1: [grad0_rank1, grad1_rank1, ...]
+ }
+ """
+
+ for param, padding_size in zip(self._param_list, self._padding_size):
+ with torch.no_grad():
+ grad = param.grad.detach().flatten()
+ if padding_size > 0:
+ grad = torch.nn.functional.pad(grad, [0, padding_size])
+ grad_list = grad.split(grad.numel() // self._world_size)
+ for rank in range(self._world_size):
+ grad_current_rank = grad_list[rank].detach()
+ self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
+ self._grad_in_bucket[rank].append(grad_current_rank)
+ param.grad = None
+
+ def get_grad(self) -> Dict:
+ """Return the dictionary of gradients slices, of which the keys are ranks
+
+ Returns:
+ Dict: The dictionary of gradients slices
+ """
+
+ return self._grad_in_bucket
+
+ def get_flatten_grad(self) -> Tensor:
+ """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
+ [grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....]
+
+ Returns:
+ Tensor: the flattened gradients slices in the bucket
+ """
+
+ flat_grad = []
+ for grad_list in self._grad_in_bucket.values():
+ flat_grad.append(_flatten_dense_tensors(grad_list))
+ flat_grad = _flatten_dense_tensors(flat_grad)
+ return flat_grad
+
+ def get_param_id_of_grad(self, grad: Tensor) -> int:
+ """Return the id of a parameter which the gradient slice belongs to
+
+ Args:
+ grad (Tensor): the gradient slice
- def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
- self._num_elements_in_bucket[reduce_rank] += num_elements
+ Returns:
+ int: the id of a parameter which the gradient slice belongs to
+ """
- def add_param(self, tensor, reduce_rank: int = None):
- self._params[reduce_rank].append(tensor)
+ return self.grad_to_param_mapping[id(grad)]
def reset(self):
- keys = [None] + list(range(self._world_size))
- self._params = {rank: [] for rank in keys}
- self._num_elements_in_bucket = {rank: 0 for rank in keys}
-
- def reset_by_rank(self, reduce_rank=None):
- self._params[reduce_rank] = []
- self._num_elements_in_bucket[reduce_rank] = 0
-
- def get_grad(self, reduce_rank: int = None):
- param_list = self.get_param(reduce_rank)
- for param in param_list:
- # the param must have grad for reduction
- assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
- return [param.grad for param in param_list]
-
- def get_param(self, reduce_rank: int = None):
- return self._params[reduce_rank]
+ self.grad_to_param_mapping = dict()
+ self._num_elements_in_bucket = 0
+ self._param_list = []
+ self._padding_size = []
+ self._grad_in_bucket = dict()
+ for rank in range(self._world_size):
+ self._grad_in_bucket[rank] = []
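The pad-and-split bookkeeping documented in `build_grad_in_bucket` above boils down to the following standalone sketch (plain tensors, a hand-picked `world_size`, no process group): each flattened gradient is right-padded until it divides evenly, then cut into one equal slice per rank. Note that in `BucketStore` the padding size is supplied by the optimizer via `add_param_grad`; the sketch recomputes it inline to stay self-contained.

```python
import torch


def split_grad_evenly(grad: torch.Tensor, world_size: int):
    """Pad a flattened gradient and cut it into world_size equal slices."""
    flat = grad.detach().flatten()
    padding_size = (-flat.numel()) % world_size           # 0 when already divisible
    if padding_size > 0:
        flat = torch.nn.functional.pad(flat, [0, padding_size])
    return list(flat.split(flat.numel() // world_size)), padding_size


world_size = 4
grad = torch.randn(10)                                     # 10 elements, not divisible by 4
slices, pad = split_grad_evenly(grad, world_size)
assert pad == 2 and len(slices) == world_size
assert all(s.numel() == 3 for s in slices)                 # (10 + 2) / 4 = 3 per rank
```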
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index 942d7186e55f..0b86ec8ca89e 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -1,88 +1,92 @@
from typing import List
from torch import Tensor
+from torch._utils import _flatten_dense_tensors
from .base_store import BaseStore
class GradientStore(BaseStore):
- def __init__(self, *args):
+ def __init__(self, *args, partition_grad: bool = False):
super().__init__(*args)
- # bookkeeping data structures
- self._averaged_gradients = dict()
-
- # for backward reduction hooks
- self._grad_acc_objs = []
-
- def append_accumulate_grad_object(self, obj):
"""
- Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
- be attached successfully.
-
- :param obj: An object of :class:`AccumulateGrad` class
- :type obj: :class:`AccumulateGrad`
+ self._grads_of_params maps a parameter to its gradient slices
+ data structure:
+ {
+ group_id:{
+ param_id: [grad_rank0, grad_rank1, ...]
+ }
+ }
"""
+ self._grads_of_params = dict()
+ # for zero2, it's `param_id: [grad_local_rank]`
+ self._working_index = 0 if partition_grad else self._local_rank
- self._grad_acc_objs.append(obj)
+ def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
+ """Return list of gradient slices of a specific parameter
- def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
- """
- Return average gradients of a parameter group
+ Args:
+ group_id (int): The index of a parameter group
+ param_id (int): The id of a parameter
- :param group_id: The index of parameter group
- :type group_id: int
-
- :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
- :rtype: List[torch.Tensor]
+ Returns:
+ List: the list of gradient slices of a parameter.
"""
- if group_id not in self._averaged_gradients:
- self._averaged_gradients[group_id] = []
-
- return self._averaged_gradients[group_id]
- def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
- """
- Append an average gradient to the list of averaged gradients of a parameter group
+ if group_id in self._grads_of_params:
+ if param_id in self._grads_of_params[group_id]:
+ return self._grads_of_params[group_id][param_id]
+ # the param has no grad, for instance, in layer drop
+ return []
- :param group_id: The index of a parameter group
- :param tensor: A :class:`torch.Tensor` object
- :type group_id: int
- :type tensor: torch.Tensor
+ def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int):
+ """Append a gradient slice to the parameter's gradient slice list
+ Args:
+ grad (Tensor): The gradient slice to append to list
+ group_id (int): The index of a parameter group
+ param_id (int): The id of a parameter
"""
- if group_id in self._averaged_gradients:
- self._averaged_gradients[group_id].append(tensor)
+ if group_id not in self._grads_of_params:
+ self._grads_of_params[group_id] = dict()
+ if param_id not in self._grads_of_params[group_id]:
+ self._grads_of_params[group_id][param_id] = [grad]
else:
- self._averaged_gradients[group_id] = [tensor]
+ self._grads_of_params[group_id][param_id].append(grad)
- def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
+ def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
+ """For old gradient accumulation, not in use now.
+ Add a gradient slice on an existing slice of the parameter's gradient
+
+ Args:
+ grad (Tensor): The split gradient to append to list
+ grad_idx (int): The index of the existing slice
+ group_id (int): The index of a parameter group
+ param_id (int): The id of a parameter
"""
- Add an average gradient to the list of averaged gradients of a parameter group
- :param group_id: The index of a parameter group
- :param tensor_idx: The index of a tensor in the list of averaged gradients
- :param tensor: A :class:`torch.Tensor` object
- :type group_id: int
- :type tensor_idx: int
- :type tensor: torch.Tensor
+ self._grads_of_params[group_id][param_id][grad_idx].add_(grad)
- """
- self._averaged_gradients[group_id][tensor_idx].add_(tensor)
+ def get_working_grads_by_group_id(self, group_id: int) -> List:
+ """Return list of working gradient slices in the group
- def reset_average_gradients_by_group(self, group_id: int) -> None:
- """
- Reset the bookkeeping data structure for averaged gradients to an empty list
+ Args:
+ group_id (int): The index of a parameter group
- :param group_id: The index of a parameter group
- :type group_id: int
+ Returns:
+ List: the list of working gradient slices in the group
"""
- self._averaged_gradients[group_id] = []
+ grad_list = []
+ for param_grads in self._grads_of_params[group_id].values():
+ grad_list.append(param_grads[self._working_index])
- def reset_all_average_gradients(self) -> None:
- """
- Reset the bookkeeping data structure for averaged gradients to an empty list
- """
- self._averaged_gradients = dict()
+ return grad_list
+
+ def reset_grads_by_group_id(self, group_id: int):
+ self._grads_of_params[group_id] = dict()
+
+ def reset_all_gradients(self):
+ self._grads_of_params = dict()
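To make the `{group_id: {param_id: [grad_rank0, grad_rank1, ...]}}` layout described above concrete, here is a small single-process sketch (plain dicts instead of the real store, which needs a process group):

import torch

# sketch of the GradientStore bookkeeping: append per-rank gradient slices for a param
# and read them back; a missing param_id simply yields an empty list (e.g. layer drop)
grads_of_params = {}     # {group_id: {param_id: [grad_rank0, grad_rank1, ...]}}

def append_gradients_by_param_id(grad, group_id, param_id):
    grads_of_params.setdefault(group_id, {}).setdefault(param_id, []).append(grad)

def get_partitioned_gradients_by_param_id(group_id, param_id):
    return grads_of_params.get(group_id, {}).get(param_id, [])

param = torch.randn(6)
for rank_slice in param.detach().split(3):                # pretend: one slice per rank
    append_gradients_by_param_id(rank_slice, group_id=0, param_id=id(param))

assert len(get_partitioned_gradients_by_param_id(0, id(param))) == 2
assert get_partitioned_gradients_by_param_id(0, id(torch.randn(1))) == []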
diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py
index 1f3ba7cbc3bc..63f7c5506069 100644
--- a/colossalai/zero/low_level/bookkeeping/parameter_store.py
+++ b/colossalai/zero/low_level/bookkeeping/parameter_store.py
@@ -1,5 +1,3 @@
-from typing import List
-
from torch import Tensor
from torch.distributed import ProcessGroup
@@ -10,88 +8,43 @@ class ParameterStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
- # param partitioning data structures
- self._param_to_rank = dict()
- self._rank_group_id_to_param_list = dict()
- self._rank_group_id_to_flat_param = dict()
- # param reduction data structures
- self._is_param_reduced = dict()
- self._reduced_param = []
+ # record the padding size of each param
+ self._padding_map = dict()
- def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
- """
- Set the mapping between parameter to rank, each parameter should be owned by a rank.
+ # mapping between working params and master params
+ self.master_to_working_param = dict()
+ self.working_to_master_param = dict()
- :param tensor: A :class:`torch.Tensor` object
- :type tensor: torch.Tensor
- :param rank: The rank of which the process is responsible for updating the parameter
- :type rank: int
- """
+ def record_param_padding_size(self, param: Tensor, padding_size: int):
+ """Record the padding size of a param
- self._param_to_rank[tensor] = rank
-
- def get_param_rank(self, tensor: Tensor) -> int:
+ Args:
+ param (Tensor): The parameter
+ padding_size (int): The padding size of the parameter
"""
- Gives the rank which the parameter belongs to
- :param tensor: A :class:`torch.Tensor` object
- :type tensor: torch.Tensor
- """
- return self._param_to_rank[tensor]
+ self._padding_map[id(param)] = padding_size
- def belongs_to_current_rank(self, tensor) -> bool:
- """
- Check whether a parameter is supposed to be updated by the process of the current rank
+ def get_param_padding_size(self, param: Tensor) -> int:
+ """Return the padding size of the parameter
- :param tensor: A :class:`torch.Tensor` object
- :type tensor: torch.Tensor
+ Args:
+ param (Tensor): The parameter
- :return: True if the parameter should be updated by the current rank. Otherwise false.
- :rtype: bool
+ Returns:
+ int: the padding size of the parameter
"""
- tensor_rank = self._param_to_rank[tensor]
- return tensor_rank == self._local_rank
-
- def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
- if rank not in self._rank_group_id_to_param_list:
- self._rank_group_id_to_param_list[rank] = dict()
-
- if group_id not in self._rank_group_id_to_param_list[rank]:
- self._rank_group_id_to_param_list[rank][group_id] = []
-
- self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list)
+ return self._padding_map[id(param)]
- def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
- return self._rank_group_id_to_param_list[rank][group_id]
+ def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor):
+ """Mapping master parameter and working parameter
- def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None:
- if rank not in self._rank_group_id_to_flat_param:
- self._rank_group_id_to_flat_param[rank] = dict()
-
- self._rank_group_id_to_flat_param[rank][group_id] = tensor
-
- def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor:
- return self._rank_group_id_to_flat_param[rank][group_id]
-
- def is_param_reduced(self, tensor):
- return self._is_param_reduced[tensor]
-
- def set_param_reduction_state(self, tensor, state):
- self._is_param_reduced[tensor] = state
-
- def get_param_reduction_states(self):
- return self._is_param_reduced
-
- def reset_previous_reduced_params(self):
- self._reduced_param = []
-
- def add_previous_reduced_param(self, tensor):
- self._reduced_param.append(tensor)
+ Args:
+ master_param (Tensor): The parameter copy in optimizer
+ working_param (Tensor): The parameter of the model
+ """
- def clear_grads_of_previous_reduced_params(self):
- if len(self._reduced_param) > 0:
- for param in self._reduced_param:
- param.grad = None
- self.reset_previous_reduced_params()
+ self.master_to_working_param[id(master_param)] = working_param
+ self.working_to_master_param[id(working_param)] = master_param
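A minimal sketch of the id()-keyed master/working mapping above, assuming a single process with `world_size = 2` and `local_rank = 0` (values chosen for illustration only):

import torch

# sketch: create the fp32 master shard of a working param and link the two via id()
world_size, local_rank = 2, 0
master_to_working_param, working_to_master_param = {}, {}

working_param = torch.randn(5, dtype=torch.float16)
flat = working_param.detach().flatten()
padding = (world_size - flat.numel() % world_size) % world_size
flat = torch.nn.functional.pad(flat, [0, padding])
master_shard = flat.split(flat.numel() // world_size)[local_rank].float()

master_to_working_param[id(master_shard)] = working_param
working_to_master_param[id(working_param)] = master_shard
assert master_to_working_param[id(working_to_master_param[id(working_param)])] is working_param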
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index ee03c0f0ae15..8743cab3313f 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -1,4 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
+from contextlib import contextmanager
from functools import partial
from typing import Optional
@@ -16,6 +17,7 @@
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoParameter, ProcessGroup
+from colossalai.utils import conditional_context
from colossalai.utils.cuda import get_current_device
from ._utils import (
@@ -23,12 +25,10 @@
compute_norm,
flatten,
has_inf_or_nan,
- reduce_tensor_dp_group,
release_param_grad,
- split_by_dtype,
- sync_param,
+ sync_tensor,
)
-from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
+from .bookkeeping import BucketStore, GradientStore, ParameterStore
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -50,7 +50,7 @@ def __init__(self,
def check_local_overflow(self) -> bool:
for group_id in range(self.num_working_param_groups):
- for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
+ for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
return True
return False
@@ -77,14 +77,11 @@ def __init__(
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
+ grad_accumulate_interval: int = 1,
forced_dtype: Optional[torch.dtype] = None):
- # TODO: add support for
- # 1. fp16 master weights
- # 2. contiguous gradients
- # 3. cpu offload
- # 4. support when some parameters requires_grad = False
- # 5. support layer drop
+ assert not (partition_grad and grad_accumulate_interval > 1), \
+ "gradient accumulation is not compatible with ZeRO-2"
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
@@ -95,6 +92,11 @@ def __init__(
self._cpu_offload = cpu_offload
+ # grad accumulation
+ self.require_grad_sync = True
+ self._accumulate_intervel = grad_accumulate_interval
+ self._accumulate_step = 0
+
colo_pg = self._search_colo_process_group()
if isinstance(colo_pg, ProcessGroup):
self._local_rank = colo_pg.dp_local_rank()
@@ -122,7 +124,7 @@ def __init__(
# working and master params for mixed precision training
self._working_param_groups = dict()
- self._master_flat_param_groups_of_current_rank = dict()
+ self._master_param_groups_of_current_rank = dict()
# communication params
self._overlap_communication = overlap_communication
@@ -145,7 +147,7 @@ def __init__(
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(self._dp_torch_group)
- self._grad_store = GradientStore(self._dp_torch_group)
+ self._grad_store = GradientStore(self._dp_torch_group, partition_grad=partition_grad)
self._bucket_store = BucketStore(self._dp_torch_group)
# iterate over the param group in the optimizer
@@ -160,55 +162,17 @@ def __init__(
# add the working params to working_param_groups for bookkeeping
self._working_param_groups[group_id] = group_params
- # assign parameters to ranks
- # the params in the list are sorted
- params_per_rank = self._partition_param_list(group_params)
-
- # store the mapping between param to rank
- # each param should belong to only one rank
- for rank, params in enumerate(params_per_rank):
- self._param_store.add_param_list_by_rank_group(rank, group_id, params)
- for param in params:
- self._param_store.set_param_to_rank(param, rank)
+ master_param_current_rank = self._create_master_param_current_rank(group_params)
- # move to cpu to make room to create the flat tensor
- # move_tensor(params, device='cpu')
- for param in group_params:
- param.data = param.data.cpu()
-
- # flatten the reordered tensors
- for rank in range(self._world_size):
- tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
- with torch.no_grad():
- flat_tensor = flatten(tensor_list)
- flat_tensor = flat_tensor.data.cuda()
- self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor)
-
- # sync parameters
- for rank in range(self._world_size):
- flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id)
- tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
- sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
-
- # create a copy of fp32 master weights of the parameters for which this rank is responsible
- working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id)
- master_flat_current_rank = working_flat_current_rank.float()
- device = 'cpu' if self._cpu_offload else get_current_device()
- master_flat_current_rank = master_flat_current_rank.to(device)
- master_flat_current_rank.requires_grad = True
- self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank
+ self._master_param_groups_of_current_rank[group_id] = master_param_current_rank
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
- param_group['params'] = [master_flat_current_rank]
-
- # set reduction state
- for param in self._working_param_groups[group_id]:
- self._param_store.set_param_reduction_state(param, False)
+ param_group['params'] = master_param_current_rank
- # initialize communication stream for
- # communication-computation overlapping
+ # initialize communication stream for
+ # communication-computation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
@@ -265,29 +229,36 @@ def _search_colo_process_group(self):
raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
return colo_pg
- def _partition_param_list(self, param_list):
- params_per_rank = [[] for _ in range(self._world_size)]
- numel_per_rank = [0 for _ in range(self._world_size)]
+ def _create_master_param_current_rank(self, param_list):
+ # split each param evenly by world size
+ params_current_rank = []
+ device = 'cpu' if self._cpu_offload else get_current_device()
+
+ for param in reversed(param_list):
+ padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
+ self._param_store.record_param_padding_size(param, padding_size)
+
+ with torch.no_grad():
+ if padding_size > 0:
+ padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+ else:
+ padding_param = param.data.view(-1)
+ splited_params = padding_param.split(padding_param.numel() // self._world_size)
- # partition the parameters in a greedy fashion
- sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
- for param in sorted_params:
- # allocate this parameter to the rank with
- # the smallest numel for load balancing purpose
- rank_to_go = numel_per_rank.index(min(numel_per_rank))
- params_per_rank[rank_to_go].append(param)
- numel_per_rank[rank_to_go] += param.numel()
+ splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+ params_current_rank.append(splited_param_current_rank)
+ self._param_store.link_master_and_working_param(splited_param_current_rank, param)
- if self._verbose:
- self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
- return params_per_rank
+ return params_current_rank
###########################
# Backward Reduction Hook #
###########################
- def _grad_handler(self, param, grad, reduce_rank):
- self._add_to_reduction_bucket(param, reduce_rank)
+ def _grad_handler(self, param, group_id, grad):
+ # when running inside the no_sync context, do not sync grads during backward
+ if self.require_grad_sync:
+ self._add_to_bucket(param, group_id)
return grad
def _attach_reduction_hook(self):
@@ -297,149 +268,96 @@ def _attach_reduction_hook(self):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
- # determines the reduction destination rank
- # this is only valid for stage 2
- # dst_rank = None means using all-reduce
- # else using reduce
- if self._partition_grads:
- reduce_rank = self._param_store.get_param_rank(param)
- else:
- reduce_rank = None
-
- param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
-
- def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
- if self._overlap_communication:
- torch.cuda.synchronize()
- self._param_store.clear_grads_of_previous_reduced_params()
- stream = self._comm_stream
- else:
- stream = torch.cuda.current_stream()
-
- with torch.cuda.stream(stream):
- flat = bucket.flatten()
- reduce_global_rank = None
- if reduce_rank is not None:
- reduce_global_rank = self._dp_global_ranks[reduce_rank]
- reduced_flat = reduce_tensor_dp_group(tensor=flat,
- dtype=self._communication_dtype,
- dst_local_rank=reduce_rank,
- dst_global_rank=reduce_global_rank,
- group=self._dp_torch_group)
-
- # update the reduced tensor
- if reduce_rank is None or reduce_rank == self._local_rank:
- bucket.unflatten_and_copy(reduced_flat)
-
- def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
- param_bucket = TensorBucket(size=bucket_size)
-
- for tensor in tensor_list:
- param_bucket.add_to_bucket(tensor, allow_oversize=True)
-
- if param_bucket.is_full_or_oversized():
- self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
- param_bucket.empty()
-
- if not param_bucket.is_empty():
- self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
-
- def _reduce_grads(self, reduce_rank, grads, bucket_size):
- grad_buckets_by_dtype = split_by_dtype(grads)
-
- for tensor_list in grad_buckets_by_dtype:
- self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
- bucket_size=bucket_size,
- reduce_rank=reduce_rank)
+ param.register_hook(partial(self._grad_handler, param, group_id))
#######################
# Reduction Functions #
#######################
- def _run_reduction(self, reduce_rank=None):
- # reduce grads
- self._reduce_grads(reduce_rank=reduce_rank,
- grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
- bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
+ def _run_reduction(self):
+ if self._bucket_store.num_elements_in_bucket() > 0:
+ self._bucket_store.build_grad_in_bucket()
+ flat_grads = self._bucket_store.get_flatten_grad()
+ flat_grads /= self._world_size
+ if self._overlap_communication:
+ stream = self._comm_stream
+ else:
+ stream = torch.cuda.current_stream()
+
+ with torch.cuda.stream(stream):
+ group_id = self._bucket_store.current_group_id
+
+ grad_dtype = flat_grads.dtype
+ if self._communication_dtype is not None:
+ flat_grads = flat_grads.to(self._communication_dtype)
+
+ if not self._partition_grads:
+ dist.all_reduce(flat_grads, group=self._dp_torch_group)
+ if flat_grads.dtype != grad_dtype:
+ flat_grads = flat_grads.to(grad_dtype)
+
+ flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
+ grad_in_bucket = self._bucket_store.get_grad()
+
+ for rank, grad_list in grad_in_bucket.items():
+ sync_tensor(flat_grads_per_rank[rank], grad_list)
+ for grad in grad_list:
+ param_id = self._bucket_store.get_param_id_of_grad(grad)
+ self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
- # use communication stream if overlapping
- # communication with computation
- if self._overlap_communication:
- stream = self._comm_stream
- else:
- stream = torch.cuda.current_stream()
-
- with torch.cuda.stream(stream):
- params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
-
- for param in params_in_bucket:
- # the is_param_reduced flag should be False showing that
- # this param is not reduced before calling self._reduce_grads_by_rank
- is_param_reduced = self._param_store.is_param_reduced(param)
-
- if is_param_reduced:
- msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
- 'duplicate reduction will lead to arithmetic incorrectness'
- raise RuntimeError(msg)
-
- # update the flag
- self._param_store.set_param_reduction_state(param, True)
-
- # if partition grads = True
- # we do not keep the gradient after reduction
- if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
- if self._overlap_communication:
- # we need to keep this gradient for now as reduction may
- # be completed yet since it is using a different cuda stream
- self._param_store.add_previous_reduced_param(param)
- else:
- param.grad = None
+ else:
+ flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
+ recieved_grad = torch.zeros_like(flat_grads_list[0])
+ dist.reduce_scatter(recieved_grad, flat_grads_list, group=self._dp_torch_group)
+
+ if recieved_grad.dtype != grad_dtype:
+ recieved_grad = recieved_grad.to(grad_dtype)
+
+ grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+ sync_tensor(recieved_grad, grad_in_bucket_current_rank)
+ for grad in grad_in_bucket_current_rank:
+ param_id = self._bucket_store.get_param_id_of_grad(grad)
+ self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
- self._bucket_store.reset_by_rank(reduce_rank)
+ self._bucket_store.reset()
- def _add_to_reduction_bucket(self, param, reduce_rank=None):
+ def _add_to_bucket(self, param, group_id):
param_size = param.numel()
# check if the bucket is full
# if full, will reduce the grads already in the bucket
+ # or if the incoming grad belongs to a different param group
# after reduction, the bucket will be empty
- if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
- self._run_reduction(reduce_rank)
+ if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
+ group_id != self._bucket_store.current_group_id:
+ self._run_reduction()
- # the param must not be reduced to ensure correctness
- is_param_reduced = self._param_store.is_param_reduced(param)
- if is_param_reduced:
- msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
- + 'duplicate reduction will lead to arithmetic incorrectness'
- raise RuntimeError(msg)
-
- self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
- self._bucket_store.add_param(param, reduce_rank)
+ padding_size = self._param_store.get_param_padding_size(param)
+ self._bucket_store.add_param_grad(group_id, param, padding_size)
################################
# torch.optim.Optimizer methods
################################
- def backward(self, loss, retain_graph=False, sync_grad=True):
+ def backward(self, loss, retain_graph=False):
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
- loss.backward(retain_graph=retain_graph)
- # finish gradient reduction
- if not self._partition_grads:
- self._reduce_grad_stage1()
- else:
- # TODO: support async comm in reduce
- self._reduce_grad_stage2()
+ self._accumulate_step += 1
+ no_sync = self._accumulate_step < self._accumulate_intervel
+ with conditional_context(self.no_sync(), enable=no_sync):
+ loss.backward(retain_graph=retain_graph)
+
+ if no_sync:
+ return
+
+ self._reduce_grad(self._partition_grads)
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
- self._param_store.clear_grads_of_previous_reduced_params()
- # gradient synchronization
- if sync_grad:
- self._sync_grad()
+ self.zero_grad()
def zero_grad(self, set_to_none=True):
"""
@@ -467,68 +385,86 @@ def zero_grad(self, set_to_none=True):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
+ if not self._accumulate_step == self._accumulate_intervel:
+ return
+
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
- self._grad_store.reset_all_average_gradients()
+ self._grad_store.reset_all_gradients()
if self._verbose:
self._logger.info(f'Found overflow. Skip step')
self.zero_grad()
+ self._accumulate_step -= 1
return
- # copy the grad of working param to master param
- single_grad_partition_groups = []
+ # record all grads for unscale and clip
+ grad_partition_groups = []
norm_groups = []
+ # sometimes not all params are 'really' working
+ # for instance, with layer drop, the dropped layer has no grad
+ # and should not be updated
+ real_working_params = dict()
+ real_master_params = dict()
+
+ grad_index = 0 if self._partition_grads else self._local_rank
+
for group_id in range(self.num_param_groups):
+ master_params = self._master_param_groups_of_current_rank[group_id]
+ real_working_params[group_id] = []
+ real_master_params[group_id] = []
+ for splited_param in master_params:
+ working_param = self._param_store.master_to_working_param[id(splited_param)]
+ # if a working param requires grad but has no grad,
+ # it is not 'really' working, e.g. a dropped layer;
+ # otherwise the split grad should be attached to the split param
+ grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
+ if len(grads) > 0:
+ real_working_params[group_id].append(working_param)
+ grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+ splited_param.grad = grad
+ grad_partition_groups.append(grad)
+ real_master_params[group_id].append(splited_param)
+
# compute norm
- norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
- params=self._param_store.get_params_by_rank_group(group_id=group_id,
- rank=self._local_rank),
+ working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
+ norm_group = compute_norm(gradients=working_grads,
+ params=real_working_params[group_id],
dp_group=self._dp_torch_group,
mp_group=self._mp_torch_group)
norm_groups.append(norm_group)
- # create flat gradient for the flat fp32 master params
- working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
- flat_working_avg_grads = flatten(working_avg_grads)
+ self._grad_store.reset_grads_by_group_id(group_id)
- dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype
- flat_master_avg_grads = flat_working_avg_grads.to(dtype)
-
- param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape
- assert param_shape == flat_master_avg_grads.shape, \
- f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}'
-
- single_grad_partition_groups.append(flat_master_avg_grads)
- device = self._master_flat_param_groups_of_current_rank[group_id].device
- self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device)
- self._grad_store.reset_average_gradients_by_group(group_id)
+ # update the params in the optimizer
+ self.optim.param_groups[group_id]['params'] = real_master_params[group_id]
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
- self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
+ self._unscale_and_clip_grads(grad_partition_groups, global_norm)
# update the parameters
self.optim.step()
- # release the master grad
- release_param_grad(self._master_flat_param_groups_of_current_rank.values())
- # update working partition updated by the current rank
- for group_id in range(len(self._working_param_groups)):
- working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id)
- master_param = self._master_flat_param_groups_of_current_rank[group_id]
- working_param.data.copy_(master_param)
+ # release the grad
+ grad_partition_groups = []
+ for group_id in range(self.num_param_groups):
+ release_param_grad(self._master_param_groups_of_current_rank[group_id])
- # broadcast the updated model weights
- handles = []
+ # update working partition updated by the current rank
for group_id in range(self.num_param_groups):
- for index in range(self._world_size):
- rank = self._dp_global_ranks[index]
- working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id)
- handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True)
- handles.append(handle)
+ master_working_param = self.optim.param_groups[group_id]['params']
+
+ for idx, splited_param in enumerate(master_working_param):
+ full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
+ dist.all_gather(full_master_param, splited_param.cuda(), group=self._dp_torch_group)
+ working_param = real_working_params[group_id][idx]
+ full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
+ working_param.data.copy_(full_master_param)
+
+ self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
- for handle in handles:
- handle.wait()
+ # reset accumulate step
+ self._accumulate_step = 0
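The all_gather-based weight update a few lines above can be checked on a single process: gather the fp32 shards (simulated with a plain list here), flatten, drop the padding, and reshape back to the working parameter. A hedged sketch, not part of the patch:

import torch

# sketch: rebuild a working param from its per-rank master shards
world_size = 2
working_param = torch.randn(3, 3)                         # 9 elements -> padded to 10
flat = working_param.flatten()
padding = (world_size - flat.numel() % world_size) % world_size
shards = torch.nn.functional.pad(flat, [0, padding]).split((flat.numel() + padding) // world_size)

gathered = list(shards)                                    # real code: dist.all_gather(...)
rebuilt = torch.cat(gathered)[:working_param.numel()].reshape_as(working_param)
assert torch.equal(rebuilt, working_param)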
#############################
# Mixed Precision Utilities #
@@ -553,49 +489,25 @@ def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# Gradient Synchronization #
############################
- def _sync_grad(self):
- # update param already reduced flag
- reduction_states = self._param_store.get_param_reduction_states()
- for tensor, _ in reduction_states.items():
- reduction_states[tensor] = False
-
- # accumulate gradient
- for group_id in range(self.num_param_groups):
- param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id)
-
- avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
-
- param_idx = 0
- for param in param_group:
- if param.grad is not None:
- if len(avg_gradients_group) == param_idx:
- self._grad_store.append_average_gradient_by_group(group_id, param.grad)
- else:
- self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
- param_idx += 1
-
- # the gradients needed are stored in the avg_gradients buffer
- # thus, can clear this
- self.zero_grad()
-
- def _reduce_grad_stage1(self):
- # if not overlapping communication (no reduction hook is attached)
+ def _reduce_grad(self, partition_grad):
+ # for zero1, if communication is not overlapped (no reduction hook is attached)
# we need to manually reduce these gradients
- if not self._overlap_communication:
+ if not partition_grad and not self._overlap_communication:
for group_id in range(len(self._working_param_groups)):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.grad is not None:
- self._add_to_reduction_bucket(param)
+ self._add_to_bucket(param, group_id)
- # we need to reduce the gradients
- # left in the communication bucket
+ # run reduction
self._run_reduction()
- def _reduce_grad_stage2(self):
- # when partition_grads is True, reduction hooks
- # are attached in the __init__ function, so we
- # only need to reduce the gradients
- # left in the communication bucket
- for reduce_rank in range(self._world_size):
- self._run_reduction(reduce_rank)
+ # this context manager is adapted from pytorch DDP
+ @contextmanager
+ def no_sync(self):
+ old_require_grad_sync = self.require_grad_sync
+ self.require_grad_sync = False
+ try:
+ yield
+ finally:
+ self.require_grad_sync = old_require_grad_sync
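The `no_sync` context above only toggles `require_grad_sync`; whether a backward pass reduces gradients or merely accumulates them follows from that flag. A toy, self-contained sketch of the handshake (the class and its prints are stand-ins, not the real optimizer):

from contextlib import contextmanager

# toy stand-in mirroring the no_sync / require_grad_sync handshake
class ToyZeroOptimizer:
    def __init__(self):
        self.require_grad_sync = True

    @contextmanager
    def no_sync(self):
        old = self.require_grad_sync
        self.require_grad_sync = False
        try:
            yield
        finally:
            self.require_grad_sync = old

    def backward(self, loss=None):
        # the real backward calls loss.backward() and reduces only when the flag is set
        print("reduce grads" if self.require_grad_sync else "accumulate locally")

opt = ToyZeroOptimizer()
accumulate_interval = 2
for step in range(4):
    if (step + 1) % accumulate_interval != 0:              # accumulation micro-step
        with opt.no_sync():
            opt.backward()
    else:                                                   # sync micro-step: reduce, then step()
        opt.backward()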
diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
index eedd8c59a3a8..79f98a4c95d0 100644
--- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
+++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -11,14 +11,9 @@
from tests.kit.model_zoo import model_zoo
# These models are not compatible with AMP
-_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
+_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch']
# These models have no parameters
-_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
-# These models will get stuck
-_STUCK_MODELS = [
- 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
- 'transformers_bert_for_pretraining', 'transformers_gpt_double_heads'
-]
+_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
@@ -58,7 +53,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
"""
passed_models = []
failed_info = {} # (model_name, error) pair
- ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
+ ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS
skipped_models = []
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py
index c264a8077d2a..ac1f677f9a0d 100644
--- a/tests/test_zero/test_low_level/test_grad_acc.py
+++ b/tests/test_zero/test_low_level/test_grad_acc.py
@@ -39,37 +39,37 @@ def exam_zero_1_2_grad_acc():
overlap_communication=True,
initial_scale=32,
clip_grad_norm=1.0,
+ grad_accumulate_interval=2,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=32,
- clip_grad_norm=1.0)
+ clip_grad_norm=1.0,
+ grad_accumulate_interval=2)
# create data
seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
- def fwd_bwd_func(number, cur_data):
+ def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward
zero1_output = zero1_model(cur_data)
zero2_output = zero2_model(cur_data)
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
- zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
- zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
+ zero1_optimizer.backward(zero1_output.sum().float())
+ zero2_optimizer.backward(zero2_output.sum().float())
- for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
- if z2p.grad is not None:
- # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
- assert torch.equal(z1p.grad, z2p.grad)
-
- zero1_optimizer._sync_grad()
- zero2_optimizer._sync_grad()
+ if check_flag:
+ for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
+ if z2p.grad is not None:
+ # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
+ assert torch.equal(z1p.grad, z2p.grad)
- fwd_bwd_func(0, input_data1)
- fwd_bwd_func(1, input_data2)
+ fwd_bwd_func(0, input_data1, True)
+ fwd_bwd_func(1, input_data2, False)
# step
zero1_optimizer.step()
@@ -101,7 +101,8 @@ def exam_zero_1_grad_acc():
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=False,
reduce_bucket_size=262144,
- clip_grad_norm=1.0)
+ clip_grad_norm=1.0,
+ grad_accumulate_interval=2)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@@ -115,13 +116,19 @@ def fwd_bwd_func(number, cur_data, check_flag):
zero_output = zero_model(cur_data)
# torch-ddp forward
- torch_output = torch_model(cur_data)
- assert torch.equal(zero_output, torch_output)
# zero-dp backward
- zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
+ zero_optimizer.backward(zero_output.sum().float())
# torch-ddp backward
- torch_output.sum().backward()
+ if number < 1:
+ with torch_model.no_sync():
+ torch_output = torch_model(cur_data)
+ assert torch.equal(zero_output, torch_output)
+ torch_output.sum().backward()
+ else:
+ torch_output = torch_model(cur_data)
+ assert torch.equal(zero_output, torch_output)
+ torch_output.sum().backward()
if check_flag:
# check grad
@@ -129,8 +136,6 @@ def fwd_bwd_func(number, cur_data, check_flag):
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
assert torch.equal(p.grad, z1p.grad)
- zero_optimizer._sync_grad()
-
fwd_bwd_func(0, input_data1, True)
fwd_bwd_func(1, input_data2, False)
@@ -148,7 +153,8 @@ def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_grad_acc()
- exam_zero_1_2_grad_acc()
+ # gradient accumulation is not compatible with ZeRO-2
+ # exam_zero_1_2_grad_acc()
@pytest.mark.dist
diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py
index 8e2206fe6c8d..5a0609bff192 100644
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -2,6 +2,7 @@
import pytest
import torch
+import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
@@ -16,8 +17,9 @@ class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
- self.linear1 = nn.Linear(128, 256)
- self.linear2 = nn.Linear(256, 512)
+ self.linear1 = nn.Linear(123, 253)
+ self.linear_drop = nn.Linear(253, 253)
+ self.linear2 = nn.Linear(253, 512)
def forward(self, x):
x = self.linear1(x)
@@ -41,6 +43,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
assert_close(a, b, rtol=rtol, atol=atol)
+def split_ddp_grad(grad, world_size):
+ with torch.no_grad():
+ grad = grad.clone().detach().flatten()
+ padding_size = (world_size - grad.numel() % world_size) % world_size
+ if padding_size > 0:
+ grad = torch.nn.functional.pad(grad, [0, padding_size])
+ splited_grad = grad.split(grad.numel() // world_size)
+ return splited_grad
+
+
def exam_zero_1_2():
"""
In this test, we want to test whether zero stage 1 and 2
@@ -72,23 +84,21 @@ def exam_zero_1_2():
initial_scale=128)
# create data
seed_all(2001 + local_rank)
- input_data = torch.randn(32, 128).cuda()
+ input_data = torch.randn(32, 123).cuda()
zero1_output = zero1_model(input_data)
zero2_output = zero2_model(input_data)
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
- zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
- zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
+ zero1_optimizer.backward(zero1_output.mean().float())
+ zero2_optimizer.backward(zero2_output.mean().float())
- for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
- if z2p.grad is not None:
- # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
- assert torch.equal(z1p.grad, z2p.grad)
-
- zero1_optimizer._sync_grad()
- zero2_optimizer._sync_grad()
+ # check grad
+ z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
+ z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
+ for z1g, z2g in zip(z1g_list, z2g_list):
+ assert torch.equal(z1g, z2g)
# step
zero1_optimizer.step()
@@ -100,7 +110,7 @@ def exam_zero_1_2():
@parameterize('dtype', [torch.float16, torch.bfloat16])
-def exam_zero_1_torch_ddp(dtype: torch.dtype):
+def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
"""
In this test, two pairs of model and optimizers are created.
1. zero: use sharded optimizer and fp16 parameters
@@ -116,7 +126,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
torch_model = MlpModel().cuda()
zero_model = copy.deepcopy(torch_model).to(dtype)
- torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()
+ torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
# create optimizer
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
@@ -133,7 +143,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
seed_all(1453 + local_rank)
# create
- input_data = torch.rand(32, 128).cuda()
+ input_data = torch.rand(32, 123).cuda()
# zero-dp forward
zero_output = zero_model(input_data.to(dtype))
@@ -143,17 +153,20 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
loose_close(zero_output, torch_output, dtype=dtype)
# zero-dp backward
- zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
+ zero_optimizer.backward(zero_output.mean().float())
# torch-ddp backward
torch_output.mean().backward()
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
- loose_close(p.grad, z1p.grad, dtype=dtype)
+ if p.grad is not None:
+ zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
+ torch_grad_list = split_ddp_grad(p.grad, world_size)
+ for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
+ loose_close(zero_grad, torch_grad, dtype=dtype)
# zero-dp step
- zero_optimizer._sync_grad()
zero_optimizer.step()
# torch ddp step
@@ -161,14 +174,13 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype):
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
- # print(n, torch.max(torch.abs(p.data - z1p.data)))
loose_close(p.data, z1p.data, dtype=dtype)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
- exam_zero_1_torch_ddp()
+ exam_zero_1_torch_ddp(world_size=world_size)
exam_zero_1_2()
From 79cf1b5f3378e7db5b5c9e44eb27c5e7686054e7 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Tue, 4 Jul 2023 12:00:33 +0800
Subject: [PATCH 36/64] [zero]support no_sync method for zero1 plugin (#4138)
* support no sync for zero1 plugin
* polish
* polish
---
colossalai/booster/booster.py | 12 ++++---
colossalai/booster/plugin/gemini_plugin.py | 2 +-
.../booster/plugin/low_level_zero_plugin.py | 10 ++++--
colossalai/booster/plugin/plugin_base.py | 2 +-
colossalai/booster/plugin/torch_ddp_plugin.py | 2 +-
.../booster/plugin/torch_fsdp_plugin.py | 2 +-
colossalai/zero/low_level/low_level_optim.py | 29 +++++++--------
.../test_zero/test_low_level/test_grad_acc.py | 35 ++++++++-----------
8 files changed, 45 insertions(+), 49 deletions(-)
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index cee547b33b0c..ec3dc7fc143f 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -9,7 +9,7 @@
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import GeneralCheckpointIO
-from colossalai.interface import ModelWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -153,18 +153,20 @@ def execute_pipeline(self,
# return loss or outputs if needed
pass
- def no_sync(self, model: nn.Module) -> contextmanager:
+ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
"""Context manager to disable gradient synchronization across DP process groups.
+ Supports torch DDP and Low Level ZeRO-1 for now.
Args:
- model (nn.Module): The model to be disabled gradient synchronization.
+ model (nn.Module): The model whose gradient synchronization is to be disabled, used for DDP
+ optimizer (OptimizerWrapper): The optimizer whose gradient synchronization is to be disabled, used for ZeRO-1
Returns:
contextmanager: Context to disable gradient synchronization.
"""
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
- assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
- return self.plugin.no_sync(model)
+ assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+ return self.plugin.no_sync(model, optimizer)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
"""Load model from checkpoint.
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 7b6e17337d36..0f5ba6e9a6da 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -408,5 +408,5 @@ def control_checkpoint_io(self) -> bool:
def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()
- def no_sync(self, model: nn.Module) -> Iterator[None]:
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 3ec0d34092a4..0a3221b231bc 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -179,8 +179,11 @@ def __init__(
norm_type=norm_type)
self.verbose = verbose
+ # set class name with stage, for better error message
+ setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
+
def support_no_sync(self) -> bool:
- return False
+ return self.stage == 1
def control_precision(self) -> bool:
return True
@@ -219,5 +222,6 @@ def control_checkpoint_io(self) -> bool:
def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO()
- def no_sync(self, model: nn.Module) -> Iterator[None]:
- raise NotImplementedError
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+ assert isinstance(optimizer, LowLevelZeroOptimizer)
+ return optimizer.optim.no_sync()
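A hedged usage sketch of the new optimizer-based `no_sync` through the booster (assumes a distributed launch and that `model`, `optimizer`, and `train_dataloader` already exist; the boost call and loss computation below are placeholders and not part of this patch):

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

booster = Booster(plugin=LowLevelZeroPlugin(stage=1))       # no_sync is supported for stage 1 only
model, optimizer, *_ = booster.boost(model, optimizer)      # placeholders: assumed to be defined

for micro_step, batch in enumerate(train_dataloader):       # train_dataloader is a placeholder
    loss = model(batch).sum()                               # placeholder loss
    if (micro_step + 1) % 2 == 0:                           # sync step: reduce grads and update
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
    else:                                                    # accumulation step: skip reduction
        with booster.no_sync(optimizer=optimizer):
            booster.backward(loss, optimizer)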
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index aa78f6827003..fb21e57f41f7 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -61,7 +61,7 @@ def get_checkpoint_io(self) -> CheckpointIO:
pass
@abstractmethod
- def no_sync(self, model: nn.Module) -> Iterator[None]:
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
"""
Context manager to disable gradient synchronization.
"""
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index 71b435155503..f3f779c88e42 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -168,6 +168,6 @@ def control_checkpoint_io(self) -> bool:
def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()
- def no_sync(self, model: nn.Module) -> Iterator[None]:
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
return model.module.no_sync()
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index abfffa9b099e..fb7b5baadd0c 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -177,7 +177,7 @@ def __init__(
def support_no_sync(self) -> bool:
False
- def no_sync(self, model: nn.Module) -> Iterator[None]:
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
def control_precision(self) -> bool:
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 8743cab3313f..615c870971b1 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -14,10 +14,10 @@
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils import conditional_context
from colossalai.utils.cuda import get_current_device
from ._utils import (
@@ -56,7 +56,7 @@ def check_local_overflow(self) -> bool:
return False
-class LowLevelZeroOptimizer(ColossalaiOptimizer):
+class LowLevelZeroOptimizer(OptimizerWrapper):
"""Optimizer used for ZeRO-1 and ZeRO-2.
"""
@@ -77,11 +77,12 @@ def __init__(
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
- grad_accumulate_interval: int = 1,
forced_dtype: Optional[torch.dtype] = None):
- assert not (partition_grad and grad_accumulate_interval > 1), \
- "gradient accumulation is not compatible with ZeRO-2"
+ # TODO:
+ # 1. process group api
+ # 2. checkpoint IO
+
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
@@ -94,8 +95,6 @@ def __init__(
# grad accumulation
self.require_grad_sync = True
- self._accumulate_intervel = grad_accumulate_interval
- self._accumulate_step = 0
colo_pg = self._search_colo_process_group()
if isinstance(colo_pg, ProcessGroup):
@@ -340,15 +339,15 @@ def _add_to_bucket(self, param, group_id):
################################
def backward(self, loss, retain_graph=False):
+ assert not(self._partition_grads and not self.require_grad_sync), \
+ "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
- self._accumulate_step += 1
- no_sync = self._accumulate_step < self._accumulate_intervel
- with conditional_context(self.no_sync(), enable=no_sync):
- loss.backward(retain_graph=retain_graph)
+ loss.backward(retain_graph=retain_graph)
- if no_sync:
+ if not self.require_grad_sync:
return
self._reduce_grad(self._partition_grads)
@@ -385,7 +384,7 @@ def zero_grad(self, set_to_none=True):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
- if not self._accumulate_step == self._accumulate_intervel:
+ if not self.require_grad_sync:
return
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
@@ -393,7 +392,6 @@ def step(self, closure=None):
if self._verbose:
self._logger.info(f'Found overflow. Skip step')
self.zero_grad()
- self._accumulate_step -= 1
return
# record all grads for unscale and clip
@@ -463,9 +461,6 @@ def step(self, closure=None):
self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
- # reset accumulate step
- self._accumulate_step = 0
-
#############################
# Mixed Precision Utilities #
#############################
diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py
index ac1f677f9a0d..a1d14f1d5a9d 100644
--- a/tests/test_zero/test_low_level/test_grad_acc.py
+++ b/tests/test_zero/test_low_level/test_grad_acc.py
@@ -9,6 +9,7 @@
import colossalai
from colossalai.testing import spawn
from colossalai.testing.random import seed_all
+from colossalai.utils import conditional_context
from colossalai.zero import LowLevelZeroOptimizer
@@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc():
overlap_communication=True,
initial_scale=32,
clip_grad_norm=1.0,
- grad_accumulate_interval=2,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=32,
- clip_grad_norm=1.0,
- grad_accumulate_interval=2)
+ clip_grad_norm=1.0)
# create data
seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
@@ -59,8 +58,11 @@ def fwd_bwd_func(number, cur_data, check_flag):
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
- zero1_optimizer.backward(zero1_output.sum().float())
- zero2_optimizer.backward(zero2_output.sum().float())
+ no_sync = number == 0
+ with conditional_context(zero1_optimizer.no_sync(), no_sync):
+ zero1_optimizer.backward(zero1_output.sum().float())
+ with conditional_context(zero2_optimizer.no_sync(), no_sync):
+ zero2_optimizer.backward(zero2_output.sum().float())
if check_flag:
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
@@ -101,8 +103,7 @@ def exam_zero_1_grad_acc():
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=False,
reduce_bucket_size=262144,
- clip_grad_norm=1.0,
- grad_accumulate_interval=2)
+ clip_grad_norm=1.0)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@@ -112,20 +113,15 @@ def exam_zero_1_grad_acc():
input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(number, cur_data, check_flag):
- # zero-dp forward
- zero_output = zero_model(cur_data)
- # torch-ddp forward
+ no_sync = number == 0
+ # zero1 fwd and bwd
+ with conditional_context(zero_optimizer.no_sync(), no_sync):
+ zero_output = zero_model(cur_data)
+ zero_optimizer.backward(zero_output.sum().float())
- # zero-dp backward
- zero_optimizer.backward(zero_output.sum().float())
- # torch-ddp backward
- if number < 1:
- with torch_model.no_sync():
- torch_output = torch_model(cur_data)
- assert torch.equal(zero_output, torch_output)
- torch_output.sum().backward()
- else:
+ # torch-ddp fwd and bwd
+ with conditional_context(torch_model.no_sync(), no_sync):
torch_output = torch_model(cur_data)
assert torch.equal(zero_output, torch_output)
torch_output.sum().backward()
@@ -133,7 +129,6 @@ def fwd_bwd_func(number, cur_data, check_flag):
if check_flag:
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
- # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
assert torch.equal(p.grad, z1p.grad)
fwd_bwd_func(0, input_data1, True)
From c668801d363eb0142d8cb0fc789b1cf7d55f8077 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Tue, 4 Jul 2023 17:41:28 +0800
Subject: [PATCH 37/64] [zero] allow passing process group to zero12 (#4153)
* allow passing process group to zero12
* union tp-zero and normal-zero
* polish code
---
colossalai/zero/low_level/_utils.py | 48 +++++-------
colossalai/zero/low_level/low_level_optim.py | 74 +++++--------------
.../test_low_level/test_zero_init.py | 5 +-
.../test_zero/test_low_level/test_zero_tp.py | 4 +-
4 files changed, 41 insertions(+), 90 deletions(-)
diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py
index a9e552ebdabc..4205a9891534 100644
--- a/colossalai/zero/low_level/_utils.py
+++ b/colossalai/zero/low_level/_utils.py
@@ -3,8 +3,9 @@
import torch
import torch.distributed as dist
-from torch import inf
+from torch import Tensor, inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.utils import is_model_parallel_parameter
@@ -194,26 +195,21 @@ def calculate_global_norm_from_list(norm_list):
return math.sqrt(total_norm)
-def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
+def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int:
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
- added functionality to handle model parallel parameters. Note that
- the gradients are modified in place.
- Arguments:
- parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
- single Tensor that will have gradients normalized
- max_norm (float or int): max norm of the gradients
- norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
- infinity norm.
+ added functionality to handle model parallel parameters.
+
+ Args:
+ gradients (Tensor): The gradients whose norm is to be computed
+ dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
+ tp_group (ProcessGroup): The process group of Tensor Parallelism
+ norm_type (int, optional): type of the used p-norm. Can be ``'inf'`` for infinity norm. Defaults to 2.
+
Returns:
- Total norm of the parameters (viewed as a single vector).
+ int: The total norm of the given gradients
"""
- if mp_group is None:
- mp_rank = 0
- else:
- mp_rank = dist.get_rank(mp_group)
-
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
@@ -221,29 +217,21 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
# Take max across all GPUs.
- if mp_group is not None:
+ if tp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
- # if dist.get_rank() == 0:
- # logger.info(f"Total Norm beginning {total_norm}")
-
- for g, p in zip(gradients, params):
- # Pipeline parallelism may replicate parameters. Avoid multi-counting.
- tp_param_flag = False
- if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
- tp_param_flag = True
- if tp_param_flag or mp_rank == 0:
- param_norm = g.data.double().norm(2)
- total_norm += param_norm.item()**2
+ for g in gradients:
+ param_norm = g.data.double().norm(2)
+ total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
- if mp_group is not None:
- dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
+ if tp_group is not None:
+ dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
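For norm_type=2 the refactored `compute_norm` boils down to: each data-parallel rank sums the squared 2-norms of its local gradients, the partial sums are added across the dp group (and the tp group) with `all_reduce`, and the global norm is the root of that sum. A single-process sketch of that arithmetic (the per-rank gradient tensors are made up for illustration; the real code uses `dist.all_reduce` instead of the Python `sum`):

```
import torch

# gradients held by each (hypothetical) dp rank
local_grads_per_rank = [
    [torch.ones(4), torch.full((2,), 2.0)],   # rank 0
    [torch.full((3,), 3.0)],                  # rank 1
]

# per-rank partial sums of squared 2-norms, as in the loop above
partial_sums = [
    sum(g.data.double().norm(2).item() ** 2 for g in grads)
    for grads in local_grads_per_rank
]

# what the SUM all_reduce over dp_group would produce, followed by the root
total_norm = sum(partial_sums) ** (1.0 / 2)
print(total_norm)  # sqrt(1*4 + 4*2 + 9*3) = sqrt(39) ≈ 6.245
```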
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 615c870971b1..27ac06ec9dc5 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -5,6 +5,7 @@
import torch
import torch.distributed as dist
+from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import (
@@ -12,12 +13,9 @@
FP16MixedPrecisionMixin,
MixedPrecisionMixin,
)
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor import ColoParameter, ProcessGroup
+# from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
from ._utils import (
@@ -77,11 +75,12 @@ def __init__(
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
+ dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
+ tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
# TODO:
- # 1. process group api
- # 2. checkpoint IO
+ # 1. state_dict for checkpoint IO
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
@@ -96,30 +95,12 @@ def __init__(
# grad accumulation
self.require_grad_sync = True
- colo_pg = self._search_colo_process_group()
- if isinstance(colo_pg, ProcessGroup):
- self._local_rank = colo_pg.dp_local_rank()
- self._world_size = colo_pg.dp_world_size()
- self._dp_global_ranks = colo_pg.get_ranks_in_dp()
- self._dp_torch_group = colo_pg.dp_process_group()
- self._mp_torch_group = None
- if colo_pg.tp_world_size() > 1:
- self._mp_torch_group = colo_pg.tp_process_group()
- elif colo_pg is None:
- dp_parallel_mode = ParallelMode.DATA
- mp_parallel_mode = ParallelMode.MODEL
-
- self._dp_parallel_mode = dp_parallel_mode
- self._mp_parallel_mode = mp_parallel_mode
- self._local_rank = gpc.get_local_rank(dp_parallel_mode)
- self._world_size = gpc.get_world_size(dp_parallel_mode)
- self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
- self._dp_torch_group = gpc.get_group(dp_parallel_mode)
- self._mp_torch_group = None
- if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
- self._mp_torch_group = gpc.get_group(mp_parallel_mode)
- else:
- raise NotImplementedError
+        # if the process group is None, the default (global) group is used
+ self.dp_pg = dp_process_group
+ self._local_rank = dist.get_rank(group=self.dp_pg)
+ self._world_size = dist.get_world_size(group=self.dp_pg)
+
+ self.tp_pg = tp_process_group
# working and master params for mixed precision training
self._working_param_groups = dict()
@@ -145,9 +126,9 @@ def __init__(
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
- self._param_store = ParameterStore(self._dp_torch_group)
- self._grad_store = GradientStore(self._dp_torch_group, partition_grad=partition_grad)
- self._bucket_store = BucketStore(self._dp_torch_group)
+ self._param_store = ParameterStore(self.dp_pg)
+ self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
+ self._bucket_store = BucketStore(self.dp_pg)
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -212,22 +193,6 @@ def _sanity_checks(self):
assert param.dtype == self._dtype, \
f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
- def _search_colo_process_group(self):
- colo_flag = False
- colo_pg = None
- for param_group in self.optim.param_groups:
- group_params = param_group['params']
- for param in group_params:
- if isinstance(param, ColoParameter):
- colo_flag = True
- if colo_pg is None:
- colo_pg = param.get_process_group()
- else:
- assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
- elif colo_flag:
- raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
- return colo_pg
-
def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
params_current_rank = []
@@ -291,7 +256,7 @@ def _run_reduction(self):
flat_grads = flat_grads.to(self._communication_dtype)
if not self._partition_grads:
- dist.all_reduce(flat_grads, group=self._dp_torch_group)
+ dist.all_reduce(flat_grads, group=self.dp_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
@@ -307,7 +272,7 @@ def _run_reduction(self):
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
- dist.reduce_scatter(recieved_grad, flat_grads_list, group=self._dp_torch_group)
+ dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
@@ -425,10 +390,7 @@ def step(self, closure=None):
# compute norm
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
- norm_group = compute_norm(gradients=working_grads,
- params=real_working_params[group_id],
- dp_group=self._dp_torch_group,
- mp_group=self._mp_torch_group)
+ norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg)
norm_groups.append(norm_group)
self._grad_store.reset_grads_by_group_id(group_id)
@@ -454,7 +416,7 @@ def step(self, closure=None):
for idx, splited_param in enumerate(master_working_param):
full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
- dist.all_gather(full_master_param, splited_param.cuda(), group=self._dp_torch_group)
+ dist.all_gather(full_master_param, splited_param.cuda(), group=self.dp_pg)
working_param = real_working_params[group_id][idx]
full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
working_param.data.copy_(full_master_param)
diff --git a/tests/test_zero/test_low_level/test_zero_init.py b/tests/test_zero/test_low_level/test_zero_init.py
index aeeaff5b5cb9..368ef976ef6e 100644
--- a/tests/test_zero/test_low_level/test_zero_init.py
+++ b/tests/test_zero/test_low_level/test_zero_init.py
@@ -33,10 +33,9 @@ def exam_zero_init():
assert optimizer1._local_rank == optimizer2._local_rank
assert optimizer1._world_size == optimizer2._world_size
- assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks
- mp_group1 = optimizer1._mp_torch_group
- mp_group2 = optimizer2._mp_torch_group
+ mp_group1 = optimizer1.tp_pg
+ mp_group2 = optimizer2.tp_pg
assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py
index f0804f4bb5ba..238de3334c80 100644
--- a/tests/test_zero/test_low_level/test_zero_tp.py
+++ b/tests/test_zero/test_low_level/test_zero_tp.py
@@ -57,7 +57,9 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
initial_scale=2,
clip_grad_norm=1.0,
overlap_communication=overlap_flag,
- partition_grad=partition_flag)
+ partition_grad=partition_flag,
+ dp_process_group=tp_pg.dp_process_group(),
+ tp_process_group=tp_pg.tp_process_group())
dp_local_rank = tp_pg.dp_local_rank()
set_seed(255 + dp_local_rank)
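With the new `dp_process_group` / `tp_process_group` arguments, the optimizer no longer inspects `ColoParameter`s or the global context to find its groups; the caller passes them in, as the test above does via `tp_pg.dp_process_group()` and `tp_pg.tp_process_group()`. A hedged sketch of the plain `torch.distributed` equivalent (the 4-process group layout is invented for illustration, and the distributed environment is assumed to be initialized already, e.g. by `colossalai.launch`):

```
import torch
import torch.distributed as dist
from colossalai.zero import LowLevelZeroOptimizer

# illustrative layout: ranks {0, 1} and {2, 3} form two dp groups.
# dist.new_group must be called by every rank for every group, in the same order.
group_a = dist.new_group(ranks=[0, 1])
group_b = dist.new_group(ranks=[2, 3])
dp_group = group_a if dist.get_rank() < 2 else group_b

model = torch.nn.Linear(8, 8).cuda()
optimizer = LowLevelZeroOptimizer(
    torch.optim.Adam(model.parameters(), lr=1e-3),
    overlap_communication=True,
    partition_grad=False,          # stage 1
    dp_process_group=dp_group,     # defaults to the global group when None
    tp_process_group=None,         # only needed when tensor parallelism is involved
)
```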
From dd7cc5829998c26f3207bee578f2d8470e7f61f2 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Thu, 6 Jul 2023 17:20:04 +0800
Subject: [PATCH 38/64] [zero] add state dict for low level zero (#4179)
* add state dict for zero
* fix unit test
* polish
---
colossalai/zero/low_level/low_level_optim.py | 68 +++++++++-
.../test_low_level/test_zero_ckpt.py | 121 ++++++++++++++++++
2 files changed, 188 insertions(+), 1 deletion(-)
create mode 100644 tests/test_zero/test_low_level/test_zero_ckpt.py
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 27ac06ec9dc5..72bec8b0c070 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -1,4 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
+import copy
from contextlib import contextmanager
from functools import partial
from typing import Optional
@@ -198,7 +199,7 @@ def _create_master_param_current_rank(self, param_list):
params_current_rank = []
device = 'cpu' if self._cpu_offload else get_current_device()
- for param in reversed(param_list):
+ for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
self._param_store.record_param_padding_size(param, padding_size)
@@ -468,3 +469,68 @@ def no_sync(self):
yield
finally:
self.require_grad_sync = old_require_grad_sync
+
+ ##############
+ # State Dict #
+ ##############
+ def _pack_state(self, state: dict) -> dict:
+ # comes from pytorch optimizer.state_dict()
+ param_mappings = {}
+ start_index = 0
+
+ def pack_group(group):
+ nonlocal start_index
+ packed = {k: v for k, v in group.items() if k != 'params'}
+ param_mappings.update(
+ {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
+ packed['params'] = [param_mappings[id(p)] for p in group['params']]
+ start_index += len(packed['params'])
+ return packed
+
+ param_groups = [pack_group(g) for g in self.param_groups]
+ # Remap state to use order indices as keys
+ packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}
+
+ return {'state': packed_state, 'param_groups': param_groups}
+
+ def state_dict(self) -> dict:
+        """Return a state_dict in the same format as DDP
+
+ Returns:
+            dict: the state_dict in standard PyTorch format
+ """
+ zero_state = dict()
+ for param, state in self.optim.state.items():
+ zero_state[param] = copy.deepcopy(state)
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+ working_param = self._param_store.master_to_working_param[id(param)]
+ gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
+ dist.all_gather(gather_tensor, v, group=self.dp_pg)
+ param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+ zero_state[param][k] = param_state
+
+ states_dict = self._pack_state(zero_state)
+
+ return states_dict
+
+ def load_state_dict(self, state_dict: dict):
+        """Load a state_dict; the given state_dict must be in standard PyTorch format
+
+ Args:
+            state_dict (dict): A state_dict in standard PyTorch format
+ """
+ zero_state_dict = copy.deepcopy(state_dict)
+ for param_idx, state in zero_state_dict['state'].items():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+ padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ v_list = v.split(v.numel() // self._world_size)
+ zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
+
+ self.optim.load_state_dict(zero_state_dict)
+ zero_state_dict = dict()
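`state_dict()` gathers each sharded state tensor and reshapes it to the working param's shape, so the result looks like a plain DDP optimizer checkpoint; `load_state_dict()` does the inverse. A single-process sketch of that inverse step (world size and rank are hard-coded for illustration):

```
import torch

world_size, local_rank = 2, 0
v = torch.arange(5, dtype=torch.float32)   # a full state tensor from a torch checkpoint, e.g. `exp_avg`

# pad so the flattened tensor divides evenly, then keep only the local shard
padding_size = (world_size - v.numel() % world_size) % world_size
v = torch.nn.functional.pad(v.flatten(), [0, padding_size])   # 5 -> 6 elements
v_list = v.split(v.numel() // world_size)
local_shard = v_list[local_rank].detach()   # rank 0 keeps tensor([0., 1., 2.])
```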
diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py
new file mode 100644
index 000000000000..23356fe718a6
--- /dev/null
+++ b/tests/test_zero/test_low_level/test_zero_ckpt.py
@@ -0,0 +1,121 @@
+import copy
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from colossalai.zero import LowLevelZeroOptimizer
+
+
+class MlpModel(nn.Module):
+
+ def __init__(self):
+ super(MlpModel, self).__init__()
+ self.linear1 = nn.Linear(12, 24)
+ self.linear2 = nn.Linear(24, 12)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.linear2(x)
+ return x
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
+ rtol = None
+ atol = None
+ if dtype is torch.float16:
+ rtol = 5e-2
+ atol = 5e-4
+ elif dtype is torch.bfloat16:
+ rtol = 4e-3
+ atol = 4e-3
+
+ a = a.detach().to(dtype)
+ b = b.detach().to(dtype)
+
+ assert_close(a, b, rtol=rtol, atol=atol)
+
+
+def exam_zero_1_torch_ddp_ckpt():
+ """
+ We examine the state_dict of zero and DDP.
+    Moreover, we examine zero's ability to load a checkpoint saved by a plain torch optimizer.
+ """
+ local_rank = torch.distributed.get_rank()
+ seed_all(1453)
+
+ # create models
+ torch_model = MlpModel().cuda()
+ zero_model = copy.deepcopy(torch_model)
+
+ torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
+
+ # create optimizer
+ zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
+
+ # we only test stage 1 here
+ # the state dicts of stage 1 and stage 2 are the same
+ zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
+ overlap_communication=True,
+ initial_scale=1,
+ reduce_bucket_size=262144)
+
+ torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
+
+ seed_all(1453 + local_rank)
+    # create input data
+ input_data = torch.rand(4, 12).cuda()
+
+ # forward
+ zero_output = zero_model(input_data)
+ torch_output = torch_model(input_data)
+
+ # backward
+ zero_optimizer.backward(zero_output.mean().float())
+ torch_output.mean().backward()
+
+ # step
+ zero_optimizer.step()
+ torch_optimizer.step()
+
+ torch_state_dict = torch_optimizer.state_dict()
+ zero_state_dict = zero_optimizer.state_dict()
+
+ # examine the original state dict
+ for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
+ for t_v, z_v in zip(torch_state.values(), zero_state.values()):
+ loose_close(t_v, z_v)
+
+    # empty the optimizer state
+ zero_optimizer.optim.state = []
+
+ # zero load a torch checkpoint
+ zero_optimizer.load_state_dict(copy.deepcopy(torch_state_dict))
+ zero_state_dict = zero_optimizer.state_dict()
+
+ # examine the loaded state dict
+ for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
+ for t_v, z_v in zip(torch_state.values(), zero_state.values()):
+ loose_close(t_v, z_v)
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+
+ exam_zero_1_torch_ddp_ckpt()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_zero_ckpt():
+ spawn(run_dist, 2)
+
+
+if __name__ == '__main__':
+ test_zero_ckpt()
From 1a49a5ea009b3e4033599bba452403be1f778ac1 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Tue, 11 Jul 2023 18:03:13 +0800
Subject: [PATCH 39/64] [zero] support shard optimizer state dict of zero
(#4194)
* support shard optimizer of zero
* polish code
* support sync grad manually
---
.../booster/plugin/low_level_zero_plugin.py | 159 ++++++++++++------
colossalai/zero/low_level/low_level_optim.py | 77 +++++++--
colossalai/zero/low_level/readme.md | 54 ++++++
.../test_low_level_zero_checkpoint_io.py | 15 +-
4 files changed, 238 insertions(+), 67 deletions(-)
create mode 100644 colossalai/zero/low_level/readme.md
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 0a3221b231bc..616b218b2070 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,5 +1,8 @@
+import logging
+import os
import warnings
from functools import partial
+from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch
@@ -10,10 +13,16 @@
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
+from colossalai.checkpoint_io.utils import (
+ get_optimizer_base_filenames,
+ get_shard_filename,
+ save_param_groups,
+ save_state_dict,
+)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
-from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -32,21 +41,104 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
- def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
- """
- Save optimizer to checkpoint but only on master process.
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+ """Save optimizer to checkpoint but only on master process.
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save state_dict
+ checkpoint (str): Path to save checkpoint
+ gather_dtensor (bool): Whether to gather_dtensor, not used
"""
- # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
- warnings.warn(
- 'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
- checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
- GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
- def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
- warnings.warn(
- 'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
- checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
- super().load_optimizer(optimizer, checkpoint)
+        # the `state_dict` of LowLevelZeroOptimizer involves collective communication;
+        # if only the master rank collected and saved the state_dict, the collectives
+        # on the other ranks would not match
+ state_dict = optimizer.state_dict()
+ if self.coordinator.is_master():
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+ def save_sharded_optimizer(self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = False,
+ prefix: str = None,
+ size_per_shard: int = 1024):
+ """
+ Save sharded Zero-optimizer checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+ - A group file (pytorch_optim_group.bin) recording information of param_groups
+ - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+ checkpoint (str): Path to save optimizer state_dict
+ gather_dtensor (bool): Whether to gather_dtensor, not used
+            prefix (str): Prefix of the files to save
+ size_per_shard (int): Max file size of each file that store state tensors
+ """
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # this state_dict only provides 'param_groups'
+ state_dict = optimizer.optim.state_dict()
+ # state shard would be handled by the low-level zero optimizer
+ sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+
+ # Preparing file paths and index file.
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+
+ # Store the information of param groups to param_group_file.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(state_dict, group_file_path)
+
+ # Save shards of optimizer states.
+ total_size = 0
+ for idx, shard_pair in enumerate(sharded_state):
+ shard, current_size = shard_pair
+ shard_file = get_shard_filename(states_name, idx)
+ total_size = total_size + current_size
+ for param_id in shard.keys():
+ index_file.append_weight_map(str(param_id), shard_file)
+
+ checkpoint_file_path = os.path.join(checkpoint, shard_file)
+ if self.coordinator.is_master():
+ save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+ # Wrap up index file.
+ index_file.append_meta_data("total_size", total_size)
+ if self.coordinator.is_master():
+ index_file.write_index_file(save_index_file)
+                logging.info(f"The optimizer is going to be split into checkpoint shards. "
+                             f"You can find where each parameter has been saved in the "
+                             f"index located at {save_index_file}.")
+
+ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
+ """Load sharded optimizer with the given path to index file.
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to load state_dict
+ index_file_path (str): Path to the index file
+ prefix (str): Not used.
+ """
+ super().load_sharded_optimizer(optimizer, index_file_path, prefix)
+ current_rank_state_dict = optimizer.optim.state_dict()['state']
+ for param_idx, state in current_rank_state_dict.items():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+ padding_size = (self.coordinator.world_size -
+ v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ v_list = v.split(v.numel() // self.coordinator.world_size)
+ current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
class LowLevelZeroModel(ModelWrapper):
@@ -74,36 +166,6 @@ def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
-class LowLevelZeroOptimizer(OptimizerWrapper):
-
- def __init__(self,
- module: nn.Module,
- optimizer: Optimizer,
- zero_optim_config: dict,
- optim_kwargs: dict,
- verbose: bool = False) -> None:
- optimizer = zero_optim_wrapper(module,
- optimizer,
- optim_config=zero_optim_config,
- **optim_kwargs,
- verbose=verbose)
- super().__init__(optimizer)
-
- def backward(self, loss: Tensor, *args, **kwargs):
- self.optim.backward(loss)
-
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> Tensor:
- warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')
-
- def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
- raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
-
-
class LowLevelZeroPlugin(DPPluginBase):
"""
Plugin for low level zero.
@@ -211,8 +273,11 @@ def configure(
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
- optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
- self.verbose)
+ optimizer = zero_optim_wrapper(model.unwrap(),
+ optimizer,
+ optim_config=self.zero_optim_config,
+ **self.optim_kwargs,
+ verbose=self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler
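End to end, the plugin change means a user can now save and load the ZeRO optimizer in sharded form through the booster, as the updated checkpoint-IO test below exercises. A hedged sketch of that flow (the exact `LowLevelZeroPlugin` and `Booster` constructor arguments are assumptions based on the surrounding code; the save/load calls mirror the test; `colossalai.launch` is assumed to have been called on every process):

```
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(stage=1)      # ZeRO stage 1; stage 2 works the same way
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... run forward/backward and optimizer.step() ...

# shard=True goes through save_sharded_optimizer(): an index file, a
# param-group file and several pytorch_optim-000XX.bin shards are written
booster.save_optimizer(optimizer, "./ckpt/optimizer", shard=True)
booster.load_optimizer(optimizer, "./ckpt/optimizer")
```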
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 72bec8b0c070..023db122fd33 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -2,7 +2,7 @@
import copy
from contextlib import contextmanager
from functools import partial
-from typing import Optional
+from typing import Dict, Iterator, Optional, Tuple
import torch
import torch.distributed as dist
@@ -447,18 +447,23 @@ def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# Gradient Synchronization #
############################
+ # this method is used to sync gradient manually
+ def sync_grad(self):
+ for group_id in range(self.num_param_groups):
+ param_group = self._working_param_groups[group_id]
+ for param in param_group:
+ if param.requires_grad and param.grad is not None:
+ self._add_to_bucket(param, group_id)
+
+ self._run_reduction()
+
def _reduce_grad(self, partition_grad):
# if not overlapping communication (no reduction hook is attached) when zero1
# we need to manually reduce these gradients
if not partition_grad and not self._overlap_communication:
- for group_id in range(len(self._working_param_groups)):
- param_group = self._working_param_groups[group_id]
- for param in param_group:
- if param.grad is not None:
- self._add_to_bucket(param, group_id)
-
- # run reduction
- self._run_reduction()
+ self.sync_grad()
+ else:
+ self._run_reduction()
# this context comes from pytorch DDP
@contextmanager
@@ -473,7 +478,8 @@ def no_sync(self):
##############
# State Dict #
##############
- def _pack_state(self, state: dict) -> dict:
+
+ def _pack_state(self, state: Dict) -> Dict:
# comes from pytorch optimizer.state_dict()
param_mappings = {}
start_index = 0
@@ -487,17 +493,17 @@ def pack_group(group):
start_index += len(packed['params'])
return packed
- param_groups = [pack_group(g) for g in self.param_groups]
+ param_groups = [pack_group(g) for g in self.optim.param_groups]
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}
return {'state': packed_state, 'param_groups': param_groups}
- def state_dict(self) -> dict:
+ def state_dict(self) -> Dict:
"""Return a state_dict same with DDP
Returns:
-            dict: the state_dict in standard PyTorch format
+            Dict: the state_dict in standard PyTorch format
"""
zero_state = dict()
for param, state in self.optim.state.items():
@@ -514,7 +520,7 @@ def state_dict(self) -> dict:
return states_dict
- def load_state_dict(self, state_dict: dict):
+ def load_state_dict(self, state_dict: Dict):
"""Load state dict, requires the state_dict be the pytorch form
Args:
@@ -534,3 +540,46 @@ def load_state_dict(self, state_dict: dict):
self.optim.load_state_dict(zero_state_dict)
zero_state_dict = dict()
+
+ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
+        """Returns dictionaries containing the whole optimizer state one shard at a time. The max size of each dictionary shard is specified by ``max_shard_size``.
+           Only the 'state' part of the state_dict is included.
+
+ Args:
+ max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024.
+
+ Yields:
+            Tuple[Dict, int]: A shard of the optimizer state dict and its size
+ """
+ ret_block = dict()
+ ret_block_size = 0
+
+ local_states = self.optim.state_dict()['state']
+ for param_idx, states in local_states.items():
+ current_block_size = 0
+ current_block = copy.deepcopy(states)
+
+ # find the working param of current param_id
+ for group_id, pg in self._master_param_groups_of_current_rank.items():
+ if (group_id + 1) * len(pg) < param_idx:
+ continue
+ master_param = pg[param_idx - (group_id) * len(pg)]
+ working_param = self._param_store.master_to_working_param[id(master_param)]
+
+ for k, v in states.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+ state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
+ dist.all_gather(state_tensor, v, group=self.dp_pg)
+ state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+ current_block_size += state_tensor.numel()
+ current_block[k] = state_tensor
+
+ if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+ yield ret_block, ret_block_size
+ ret_block = dict()
+ ret_block_size = 0
+
+ ret_block[param_idx] = current_block
+ ret_block_size += current_block_size
+
+ yield ret_block, ret_block_size
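`state_dict_shard` is a generator: it yields `(shard, shard_size)` pairs, starting a new shard whenever adding the next parameter's states would exceed `max_shard_size`. A minimal sketch of a consumer, essentially a stripped-down version of the `save_sharded_optimizer` logic above without the index or param-group files (the file name pattern is illustrative):

```
import os
import torch

def save_optimizer_state_shards(optimizer, checkpoint_dir: str, max_shard_size: int = 1024) -> int:
    """Persist each yielded shard to its own file and return the accumulated size."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    total_size = 0
    for idx, (shard, shard_size) in enumerate(optimizer.state_dict_shard(max_shard_size=max_shard_size)):
        torch.save(shard, os.path.join(checkpoint_dir, f"optim_shard_{idx:05d}.bin"))
        total_size += shard_size
    return total_size
```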
diff --git a/colossalai/zero/low_level/readme.md b/colossalai/zero/low_level/readme.md
new file mode 100644
index 000000000000..aa92159d8022
--- /dev/null
+++ b/colossalai/zero/low_level/readme.md
@@ -0,0 +1,54 @@
+# Low Level ZeRO
+> Low Level ZeRO == ZeRO-DP stage 1 and 2; we denote it as ZeRO below.
+
+## Design:
+### Notion
+- `p32` denotes the param copy in the optimizer
+- `p` denotes the model param
+- `g` denotes the gradient
+
+### INIT
+In low level zero (stage 1 and 2), `p32` is split. Unlike the previous implementation, we split each `p32` evenly by world_size. Thus, rank 0 gets a param list like `[p-00, p-10]`, rank 1 gets `[p-01, p-11]`, and so on.
+
+
+For the detailed implementation, we first pad `p` so that it can be split evenly by world_size if needed. Then we view it as shape `[world_size, -1]`, and each rank keeps its own `p32` shard by cloning, as sketched below.
+
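+A minimal single-process sketch of this padding-and-splitting step (names are illustrative):
+```
+import torch
+
+world_size = 2
+p = torch.randn(5)                                  # param whose numel is not divisible by world_size
+padding = (world_size - p.numel() % world_size) % world_size
+padded = torch.nn.functional.pad(p.flatten(), [0, padding])
+p32_shards = padded.view(world_size, -1)            # shape [world_size, numel // world_size]
+p32_of_rank0 = p32_shards[0].clone().float()        # each rank keeps its own fp32 copy
+```
+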
+### BWD
+To make communication efficient, a gradient is first added to a bucket. When the bucket is full, each `g` in it is reshaped as `[world_size, -1]`, and the parts belonging to the same rank are grouped together.
+The data structure looks like this:
+```
+{
+0: [g-00, g-10],
+1: [g-01, g-11],
+2: [g-02, g-12]
+}
+```
+After that, the gradients are flattened by rank, and the data structure looks like this:
+```
+# g-0 means flatten([g-00, g-10])
+{
+0: [g-0],
+1: [g-1],
+2: [g-2]
+}
+```
+For zero1, we iterate over the dictionary and do `all_reduce`. For zero2, we can simply do `reduce-scatter`.
+
+### Optim
+Since each rank holds its own `p32` shard and the corresponding `g`, it is straightforward to do `optim.step()`.
+
+However, we have to consider the situation where a layer is dropped, i.e. defined but never used in `forward`, so its params receive no gradient, for instance:
+```
+class MlpModel(nn.Module):
+ def __init__(self):
+ super(MlpModel, self).__init__()
+ self.linear1 = nn.Linear(128, 256)
+ self.drop_linear = nn.Linear(256, 256)
+ self.linear2 = nn.Linear(256, 512)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.linear2(x)
+ return x
+```
+The solution is to build a mapping between `p32`, `p`, and `g`. Before `optim.step()`, we collect every `p` with `requires_grad=True` and `p.grad is not None` as the real working params, and then select the corresponding `p32` and `g`; see the sketch below.
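+
+A sketch of collecting the real working params before the step:
+```
+def collect_working_params(param_group):
+    # keep only params whose layers actually produced a gradient this iteration
+    return [p for p in param_group if p.requires_grad and p.grad is not None]
+```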
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
index c51b54c82f57..a94e8d42c78e 100644
--- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -38,9 +38,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
optimizer_ckpt_path = f"{tempdir}/optimizer"
# lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
booster.save_model(model, model_ckpt_path, shard=shard)
- if not shard:
- # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
- booster.save_optimizer(optimizer, optimizer_ckpt_path)
+ booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+
dist.barrier()
new_model = resnet18()
@@ -49,9 +48,9 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
- if not shard:
- booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
- check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+
+ booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+ check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
def run_dist(rank, world_size, port):
@@ -62,3 +61,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_low_level_zero_checkpointIO():
spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+ test_low_level_zero_checkpointIO()
From 45b08f08cb8581986e513ef9162d93a8c07fd250 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Tue, 18 Jul 2023 14:44:13 +0800
Subject: [PATCH 40/64] [zero] optimize the optimizer step time (#4221)
* optimize the optimizer step time
* fix corner case
* polish
* replace all-reduce with all-gather
* set comm device to cuda
---
colossalai/zero/low_level/low_level_optim.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 023db122fd33..2b3f50ed4fd4 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -412,15 +412,16 @@ def step(self, closure=None):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
# update working partition updated by the current rank
+ dtype = real_working_params[0][0].dtype
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]['params']
-
for idx, splited_param in enumerate(master_working_param):
- full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
- dist.all_gather(full_master_param, splited_param.cuda(), group=self.dp_pg)
working_param = real_working_params[group_id][idx]
- full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
- working_param.data.copy_(full_master_param)
+ all_splited_param = [
+ torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
+ ]
+ dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
+ working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
From 03654c0ce2de82c86367d16f594f692f0655108f Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Tue, 1 Aug 2023 10:14:00 +0800
Subject: [PATCH 41/64] fix localhost measurement (#4320)
---
colossalai/cli/launcher/hostinfo.py | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py
index d1b88b229fb8..2a6a111e4d72 100644
--- a/colossalai/cli/launcher/hostinfo.py
+++ b/colossalai/cli/launcher/hostinfo.py
@@ -46,11 +46,8 @@ def is_host_localhost(hostname: str, port: str = None) -> None:
localhost = socket.gethostname()
localaddrs = socket.getaddrinfo(localhost, port)
targetaddrs = socket.getaddrinfo(hostname, port)
- for (family, socktype, proto, canonname, sockaddr) in localaddrs:
- for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs:
- if rsockaddr[0] == sockaddr[0]:
- return True
- return False
+
+ return localaddrs == targetaddrs
def __str__(self):
return f'hostname: {self.hostname}, port: {self.port}'
From 75c53890378d3e72b4700a264f52524d4185168a Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Tue, 1 Aug 2023 10:21:45 +0800
Subject: [PATCH 42/64] [chat] fix compute_approx_kl (#4338)
---
applications/Chat/coati/models/utils.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py
index 772bfc32982a..8769fb7a8c43 100644
--- a/applications/Chat/coati/models/utils.py
+++ b/applications/Chat/coati/models/utils.py
@@ -19,7 +19,7 @@ def compute_approx_kl(log_probs: torch.Tensor,
action_mask: Mask for actions.
"""
- log_ratio = log_probs - log_probs_base
+ log_ratio = log_probs_base - log_probs
approx_kl = (log_ratio.exp() - 1) - log_ratio
if action_mask is not None:
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
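The sign flip makes the estimator match the commonly used non-negative approximation of KL(pi || pi_base) for actions sampled from pi: with r = pi_base(a|s) / pi(a|s), the per-token estimate is (r - 1) - log r, which is >= 0 for every sample and has expectation equal to the KL divergence. A minimal numeric sketch with made-up log-probabilities and no action mask:

```
import torch

log_probs = torch.tensor([-1.0, -2.0, -0.5])        # log pi(a|s), current policy
log_probs_base = torch.tensor([-1.2, -1.8, -0.7])   # log pi_base(a|s), reference policy

log_ratio = log_probs_base - log_probs              # log r with r = pi_base / pi
approx_kl = (log_ratio.exp() - 1) - log_ratio       # (r - 1) - log r, elementwise >= 0
print(approx_kl)
```

With the previous order (`log_probs - log_probs_base`), the same formula no longer estimates KL(pi || pi_base) under samples drawn from pi, which is what this fix addresses.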
From 806477121d960a11c45d37c48247249201f97e97 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Tue, 1 Aug 2023 15:01:19 +0800
Subject: [PATCH 43/64] [release] update version (#4332)
* [release] update version
* [devops] hotfix cuda extension building
* [devops] pytest ignore useless folders
---
.github/workflows/compatiblity_test_on_dispatch.yml | 2 +-
.github/workflows/compatiblity_test_on_pr.yml | 2 +-
.github/workflows/compatiblity_test_on_schedule.yml | 12 ++++++++++++
.github/workflows/cuda_ext_check_before_merge.yml | 12 ++++++++++++
pytest.ini | 1 +
version.txt | 2 +-
6 files changed, 28 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml
index 3dcc4dfd182a..1778d64ee287 100644
--- a/.github/workflows/compatiblity_test_on_dispatch.yml
+++ b/.github/workflows/compatiblity_test_on_dispatch.yml
@@ -72,7 +72,7 @@ jobs:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
- CUDA_VERSION=$(cat $CUDA_HOME/version.txt | grep "CUDA Version" | awk '{print $NF}' | cut -d. -f1,2)
+ CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
# check if it is CUDA 10.2
# download cub
diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml
index 5098b8e364d0..c0f45c65a7fc 100644
--- a/.github/workflows/compatiblity_test_on_pr.yml
+++ b/.github/workflows/compatiblity_test_on_pr.yml
@@ -66,7 +66,7 @@ jobs:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
- name: Download cub for CUDA 10.2
run: |
- CUDA_VERSION=$(cat $CUDA_HOME/version.txt | grep "CUDA Version" | awk '{print $NF}' | cut -d. -f1,2)
+ CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
# check if it is CUDA 10.2
# download cub
diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml
index 9802795fad24..15ac4f1a92bb 100644
--- a/.github/workflows/compatiblity_test_on_schedule.yml
+++ b/.github/workflows/compatiblity_test_on_schedule.yml
@@ -61,6 +61,18 @@ jobs:
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
+ - name: Download cub for CUDA 10.2
+ run: |
+ CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
+
+ # check if it is CUDA 10.2
+ # download cub
+ if [ "$CUDA_VERSION" = "10.2" ]; then
+ wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
+ unzip 1.8.0.zip
+ cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
+ fi
+
- name: Install Colossal-AI
run: |
pip install -v --no-cache-dir .
diff --git a/.github/workflows/cuda_ext_check_before_merge.yml b/.github/workflows/cuda_ext_check_before_merge.yml
index eba5bb98ec07..686f0f395c73 100644
--- a/.github/workflows/cuda_ext_check_before_merge.yml
+++ b/.github/workflows/cuda_ext_check_before_merge.yml
@@ -37,6 +37,18 @@ jobs:
- name: Install PyTorch
run: eval ${{ matrix.build.torch_command }}
+ - name: Download cub for CUDA 10.2
+ run: |
+ CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
+
+ # check if it is CUDA 10.2
+ # download cub
+ if [ "$CUDA_VERSION" = "10.2" ]; then
+ wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
+ unzip 1.8.0.zip
+ cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
+ fi
+
- name: Build
run: |
CUDA_EXT=1 pip install -v .
diff --git a/pytest.ini b/pytest.ini
index 01e5cd217c5d..e99fe3f086c6 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,3 +4,4 @@ markers =
gpu: tests which requires a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features
+addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk
diff --git a/version.txt b/version.txt
index 0d91a54c7d43..9e11b32fcaa9 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.3.0
+0.3.1
From 16c0acc01b34c988eb3f452d21a1cd466e86dc73 Mon Sep 17 00:00:00 2001
From: caption <101684156+chncaption@users.noreply.github.com>
Date: Tue, 1 Aug 2023 16:25:25 +0800
Subject: [PATCH 44/64] [hotfix] update gradio 3.11 to 3.34.0 (#4329)
---
examples/images/diffusion/requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/images/diffusion/requirements.txt b/examples/images/diffusion/requirements.txt
index 59d027fcf60f..0d9ce55a8079 100644
--- a/examples/images/diffusion/requirements.txt
+++ b/examples/images/diffusion/requirements.txt
@@ -12,7 +12,7 @@ einops==0.3.0
transformers
webdataset==0.2.5
open-clip-torch==2.7.0
-gradio==3.11
+gradio==3.34.0
lightning==1.9.0
datasets
colossalai
From 16bf4c022150fea303d23437b0190c46204e722c Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Tue, 1 Aug 2023 18:52:14 +0800
Subject: [PATCH 45/64] [test] remove useless tests (#4359)
* [test] remove legacy zero test
* [test] remove lazy distribute test
* [test] remove outdated checkpoint io
---
colossalai/utils/checkpoint_io/__init__.py | 2 -
colossalai/utils/checkpoint_io/backend.py | 74 ------
colossalai/utils/checkpoint_io/constant.py | 9 -
colossalai/utils/checkpoint_io/convertor.py | 227 ------------------
colossalai/utils/checkpoint_io/distributed.py | 127 ----------
colossalai/utils/checkpoint_io/io.py | 170 -------------
colossalai/utils/checkpoint_io/meta.py | 81 -------
colossalai/utils/checkpoint_io/reader.py | 131 ----------
colossalai/utils/checkpoint_io/utils.py | 223 -----------------
colossalai/utils/checkpoint_io/writer.py | 98 --------
tests/test_lazy/test_distribute.py | 102 --------
.../test_build_checkpoints.py | 120 ---------
.../test_checkpoint_io/test_load.py | 186 --------------
.../test_checkpoint_io/test_merge.py | 126 ----------
.../test_checkpoint_io/test_merge_param.py | 101 --------
.../test_checkpoint_io/test_redist.py | 152 ------------
.../test_checkpoint_io/test_save.py | 149 ------------
.../test_checkpoint_io/test_unmerge_param.py | 137 -----------
tests/test_zero/test_legacy/common.py | 140 -----------
tests/test_zero/test_legacy/test_found_inf.py | 67 ------
.../test_legacy/test_gemini_manager.py | 75 ------
.../test_legacy/test_init_context.py | 73 ------
tests/test_zero/test_legacy/test_param_op.py | 82 -------
.../test_legacy/test_shard_model_v2.py | 64 -----
.../test_zero/test_legacy/test_shard_param.py | 91 -------
.../test_sharded_optim_state_dict.py | 89 -------
.../test_legacy/test_sharded_optim_v2.py | 110 ---------
.../test_sharded_optim_with_sync_bn.py | 87 -------
.../test_zero/test_legacy/test_state_dict.py | 55 -----
.../test_legacy/test_tensor_utils.py | 94 --------
.../test_zero/test_legacy/test_zero_engine.py | 113 ---------
31 files changed, 3355 deletions(-)
delete mode 100644 colossalai/utils/checkpoint_io/__init__.py
delete mode 100644 colossalai/utils/checkpoint_io/backend.py
delete mode 100644 colossalai/utils/checkpoint_io/constant.py
delete mode 100644 colossalai/utils/checkpoint_io/convertor.py
delete mode 100644 colossalai/utils/checkpoint_io/distributed.py
delete mode 100644 colossalai/utils/checkpoint_io/io.py
delete mode 100644 colossalai/utils/checkpoint_io/meta.py
delete mode 100644 colossalai/utils/checkpoint_io/reader.py
delete mode 100644 colossalai/utils/checkpoint_io/utils.py
delete mode 100644 colossalai/utils/checkpoint_io/writer.py
delete mode 100644 tests/test_lazy/test_distribute.py
delete mode 100644 tests/test_utils/test_checkpoint_io/test_build_checkpoints.py
delete mode 100644 tests/test_utils/test_checkpoint_io/test_load.py
delete mode 100644 tests/test_utils/test_checkpoint_io/test_merge.py
delete mode 100644 tests/test_utils/test_checkpoint_io/test_merge_param.py
delete mode 100644 tests/test_utils/test_checkpoint_io/test_redist.py
delete mode 100644 tests/test_utils/test_checkpoint_io/test_save.py
delete mode 100644 tests/test_utils/test_checkpoint_io/test_unmerge_param.py
delete mode 100644 tests/test_zero/test_legacy/common.py
delete mode 100644 tests/test_zero/test_legacy/test_found_inf.py
delete mode 100644 tests/test_zero/test_legacy/test_gemini_manager.py
delete mode 100644 tests/test_zero/test_legacy/test_init_context.py
delete mode 100644 tests/test_zero/test_legacy/test_param_op.py
delete mode 100644 tests/test_zero/test_legacy/test_shard_model_v2.py
delete mode 100644 tests/test_zero/test_legacy/test_shard_param.py
delete mode 100644 tests/test_zero/test_legacy/test_sharded_optim_state_dict.py
delete mode 100644 tests/test_zero/test_legacy/test_sharded_optim_v2.py
delete mode 100644 tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py
delete mode 100644 tests/test_zero/test_legacy/test_state_dict.py
delete mode 100644 tests/test_zero/test_legacy/test_tensor_utils.py
delete mode 100644 tests/test_zero/test_legacy/test_zero_engine.py
diff --git a/colossalai/utils/checkpoint_io/__init__.py b/colossalai/utils/checkpoint_io/__init__.py
deleted file mode 100644
index fe030866894f..000000000000
--- a/colossalai/utils/checkpoint_io/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .io import load, merge, redist, save
-from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta)
diff --git a/colossalai/utils/checkpoint_io/backend.py b/colossalai/utils/checkpoint_io/backend.py
deleted file mode 100644
index 140192c05f12..000000000000
--- a/colossalai/utils/checkpoint_io/backend.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import shutil
-import tempfile
-from abc import ABC, abstractmethod
-from typing import Dict, List, Type
-
-from .reader import CheckpointReader, DiskCheckpointReader
-from .writer import CheckpointWriter, DiskCheckpointWriter
-
-_backends: Dict[str, Type['CheckpointIOBackend']] = {}
-
-
-def register(name: str):
- assert name not in _backends, f'"{name}" is registered'
-
- def wrapper(cls):
- _backends[name] = cls
- return cls
-
- return wrapper
-
-
-def get_backend(name: str) -> 'CheckpointIOBackend':
- assert name in _backends, f'Unsupported backend "{name}"'
- return _backends[name]()
-
-
-class CheckpointIOBackend(ABC):
-
- def __init__(self) -> None:
- super().__init__()
- self.temps: List[str] = []
-
- @abstractmethod
- def get_writer(self,
- base_name: str,
- overwrite: bool = False,
- rank: int = 0,
- world_size: int = 1) -> CheckpointWriter:
- pass
-
- @abstractmethod
- def get_reader(self, base_name: str) -> CheckpointReader:
- pass
-
- @abstractmethod
- def get_temp(self, base_name: str) -> str:
- pass
-
- @abstractmethod
- def clean_temp(self) -> None:
- pass
-
-
-@register('disk')
-class CheckpointDiskIO(CheckpointIOBackend):
-
- def get_writer(self,
- base_name: str,
- overwrite: bool = False,
- rank: int = 0,
- world_size: int = 1) -> CheckpointWriter:
- return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)
-
- def get_reader(self, base_name: str) -> CheckpointReader:
- return DiskCheckpointReader(base_name)
-
- def get_temp(self, base_name: str) -> str:
- temp_dir_name = tempfile.mkdtemp(dir=base_name)
- self.temps.append(temp_dir_name)
- return temp_dir_name
-
- def clean_temp(self) -> None:
- for temp_dir_name in self.temps:
- shutil.rmtree(temp_dir_name)
diff --git a/colossalai/utils/checkpoint_io/constant.py b/colossalai/utils/checkpoint_io/constant.py
deleted file mode 100644
index 2199484741bf..000000000000
--- a/colossalai/utils/checkpoint_io/constant.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import re
-
-GLOBAL_META_FILE_NAME = 'global_meta.bin'
-MODEL_CKPT_FILE_NAME = 'model.bin'
-OPTIM_CKPT_FILE_NAME = 'optim.bin'
-META_CKPT_FILE_NAME = 'meta.bin'
-OTHER_CKPT_FILE_NAME = 'other.bin'
-
-CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')
diff --git a/colossalai/utils/checkpoint_io/convertor.py b/colossalai/utils/checkpoint_io/convertor.py
deleted file mode 100644
index 529ceb86829b..000000000000
--- a/colossalai/utils/checkpoint_io/convertor.py
+++ /dev/null
@@ -1,227 +0,0 @@
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional
-
-from torch import Tensor
-
-from .distributed import merge_param, unmerge_param
-from .meta import ParamDistMeta, RedistMeta
-from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none)
-
-
-class CheckpointConvertor(ABC):
-
- @abstractmethod
- def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
- pass
-
- @abstractmethod
- def complete(self) -> None:
- pass
-
-
-class ModelCheckpointConvertor(CheckpointConvertor):
-
- def __init__(self, param_count: Dict[str, int]) -> None:
- super().__init__()
- self.param_count = param_count
- self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)
-
- @abstractmethod
- def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
- pass
-
- def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
- for rank, state_dict in shard_dict.items():
- for k, tensor in state_dict.items():
- self.buffer[k][rank] = tensor
- converted_keys = set()
- for k, rank_dict in self.buffer.items():
- if len(rank_dict) == self.param_count[k]:
- tensors = []
- dist_metas = []
- for rank, tensor in rank_dict.items():
- tensors.append(tensor)
- if dist_meta_list[rank] is not None:
- dist_metas.append(dist_meta_list[rank][k])
- self.convert_tensors(k, tensors, dist_metas)
- converted_keys.add(k)
- for k in converted_keys:
- del self.buffer[k]
-
- def complete(self) -> None:
- assert len(self.buffer) == 0
-
-
-class ModelCheckpointMerger(ModelCheckpointConvertor):
-
- def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None:
- super().__init__(param_count)
- self.sharder = ModelCheckpointSharder(max_shard_size)
- self.save_fn = save_fn
-
- def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
- assert len(dist_metas) == len(tensors)
- tensor = merge_param(tensors, dist_metas)
- shard = self.sharder.append(key, tensor)
- run_if_not_none(self.save_fn, shard)
-
- def complete(self) -> None:
- super().complete()
- run_if_not_none(self.save_fn, self.sharder.complete())
-
-
-class ModelCheckpointRedistor(ModelCheckpointConvertor):
-
- def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
- redist_meta: RedistMeta) -> None:
- super().__init__(param_count)
- self.save_fns = save_fns
- self.redist_meta = redist_meta
- nprocs = len(save_fns)
- self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)]
- self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
- for k, rank_meta in redist_meta.rank_meta.items():
- for rank, rank_info in rank_meta.items():
- self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)
-
- def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
- if len(dist_metas) == 0:
- # already global
- tensor = tensors[0]
- else:
- assert len(dist_metas) == len(tensors)
- tensor = merge_param(tensors, dist_metas)
- for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])):
- for dp_rank, t in enumerate(tensor_list):
- for rank in self.rank_map[key][tp_rank][dp_rank]:
- shard = self.sharders[rank].append(key, t)
- run_if_not_none(self.save_fns[rank], shard)
-
- def complete(self) -> None:
- super().complete()
- for rank, save_fn in enumerate(self.save_fns):
- run_if_not_none(save_fn, self.sharders[rank].complete())
-
-
-class OptimizerCheckpointConvertor(CheckpointConvertor):
-
- def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]],
- paired_os: Optional[Dict[int, dict]]) -> None:
- super().__init__()
- self.param_count = param_count
- self.param_to_os = param_to_os
- self.paired_os = paired_os
- self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict)
- self.os_to_param = {v: k for k, v in param_to_os.items()}
-
- @abstractmethod
- def setup(self, param_groups: dict) -> None:
- pass
-
- @abstractmethod
- def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
- pass
-
- def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
- for rank, state_dict in shard_dict.items():
- self.setup(state_dict['param_groups'])
- for idx, state in state_dict['state'].items():
- self.buffer[idx][rank] = state
- converted_indices = set()
- for idx, rank_dict in self.buffer.items():
- if len(rank_dict) == self.param_count[self.os_to_param[idx]]:
- states = []
- dist_metas = []
- for rank, state in rank_dict.items():
- states.append(state)
- if dist_meta_list[rank] is not None:
- dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]])
- self.convert_states(idx, states, dist_metas)
- converted_indices.add(idx)
- for idx in converted_indices:
- del self.buffer[idx]
-
- def complete(self) -> None:
- assert len(self.buffer) == 0
-
-
-class OptimizerCheckpointMerger(OptimizerCheckpointConvertor):
-
- def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int],
- param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None:
- super().__init__(param_count, param_to_os, paired_os)
- self.max_shard_size = max_shard_size
- self.save_fn = save_fn
- self.sharder = None
-
- def setup(self, param_groups: dict) -> None:
- if self.sharder is None:
- self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups)
-
- def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
- assert len(dist_metas) == len(states)
- new_state = {}
- for state_key, state_tensor in states[0].items():
- if self.paired_os[idx][state_key]:
- new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas)
- else:
- new_state[state_key] = state_tensor
- shard = self.sharder.append(idx, new_state)
- run_if_not_none(self.save_fn, shard)
-
- def complete(self) -> None:
- super().complete()
- run_if_not_none(self.save_fn, self.sharder.complete())
-
-
-class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor):
-
- def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
- param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]],
- redist_meta: RedistMeta) -> None:
- super().__init__(param_count, param_to_os, paired_os)
- self.max_shard_size = max_shard_size
- self.save_fns = save_fns
- self.redist_meta = redist_meta
- self.sharders: List[OptimizerCheckpointSharder] = []
- self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
- for k, rank_meta in redist_meta.rank_meta.items():
- for rank, rank_info in rank_meta.items():
- self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)
-
- def setup(self, param_groups: dict) -> None:
- if len(self.sharders) == 0:
- nprocs = len(self.save_fns)
- for _ in range(nprocs):
- self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups))
-
- def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
- need_merge: bool = True
- if len(dist_metas) == 0:
- need_merge = False
- else:
- assert len(dist_metas) == len(states)
- new_states = [{} for _ in range(len(self.save_fns))]
- for state_key, state_tensor in states[0].items():
- if self.paired_os[idx][state_key]:
- if need_merge:
- tensor = merge_param([state[state_key] for state in states], dist_metas)
- else:
- tensor = state_tensor
- for tp_rank, tensor_list in enumerate(
- unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])):
- for dp_rank, t in enumerate(tensor_list):
- for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]:
- new_states[rank][state_key] = t
- else:
- for new_state in new_states:
- new_state[state_key] = state_tensor
- for rank, new_state in enumerate(new_states):
- shard = self.sharders[rank].append(idx, new_state)
- run_if_not_none(self.save_fns[rank], shard)
-
- def complete(self) -> None:
- super().complete()
- for rank, save_fn in enumerate(self.save_fns):
- run_if_not_none(save_fn, self.sharders[rank].complete())
diff --git a/colossalai/utils/checkpoint_io/distributed.py b/colossalai/utils/checkpoint_io/distributed.py
deleted file mode 100644
index bf720437c41a..000000000000
--- a/colossalai/utils/checkpoint_io/distributed.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import torch
-from numpy import prod
-from torch import Tensor
-from typing import List, Optional, Tuple
-from collections import defaultdict
-from .meta import ParamDistMeta, ParamRedistMeta
-
-
-def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
- assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
- for dist_meta in dist_metas[1:]:
- assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.'
- if not dist_metas[0].used_zero:
- # tensors are replicated
- return tensors[0]
- numel = dist_metas[0].zero_numel
- orig_shape = dist_metas[0].zero_orig_shape
- tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)]
- assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.'
- return torch.cat(tensors).reshape(orig_shape)
-
-
-def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
- assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
- for dist_meta in dist_metas[1:]:
- assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.'
- for t in tensors[1:]:
- assert t.shape == tensors[0].shape, 'Expect all params have the same shape.'
- if not dist_metas[0].used_tp:
- # tensors are replicated
- return tensors[0]
- total_parts = prod(dist_metas[0].tp_num_parts)
- assert dist_metas[0].tp_world_size == total_parts, \
- f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_metas[0].tp_world_size}.'
- shard_info = sorted(zip(dist_metas[0].tp_shard_dims, dist_metas[0].tp_num_parts), key=lambda t: t[0], reverse=True)
- for dim, num_parts in shard_info:
- buffer = []
- for start in range(0, len(tensors), num_parts):
- buffer.append(torch.cat(tensors[start:start + num_parts], dim))
- tensors = buffer
- assert len(tensors) == 1
- return tensors[0]
-
-
-def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None:
- assert len(dist_metas) > 0
- # check world size
- for dist_meta in dist_metas[1:]:
- assert dist_meta.dp_world_size == dist_metas[0].dp_world_size, \
- 'Expect all dist metas to have the same dp_world_size.'
- assert dist_meta.tp_world_size == dist_metas[0].tp_world_size, \
- 'Expect all dist metas to have the same tp_world_size.'
-
-
-def deduplicate_params(tensors: List[Tensor],
- dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]:
- unique_dist_meta = []
- unique_idx = []
- for i, dist_meta in enumerate(dist_metas):
- if dist_meta not in unique_dist_meta:
- unique_dist_meta.append(dist_meta)
- unique_idx.append(i)
- return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx]
-
-
-def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
- assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
- # validate parallel info
- validate_parallel_info(dist_metas)
- tensors, dist_metas = deduplicate_params(tensors, dist_metas)
- unflattened_tensors = []
- # group zero params by tp rank
- tensor_dict = defaultdict(list)
- dist_meta_dict = defaultdict(list)
- for t, dist_meta in zip(tensors, dist_metas):
- tensor_dict[dist_meta.tp_rank].append(t)
- dist_meta_dict[dist_meta.tp_rank].append(dist_meta)
- assert len(tensor_dict) == dist_metas[0].tp_world_size, \
- f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}'
- for tp_rank in tensor_dict.keys():
- unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank]))
- return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()])
-
-
-def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
- if not redist_meta.used_tp:
- assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.'
- return [tensor]
- total_parts = prod(redist_meta.tp_num_parts)
- assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.'
- shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0])
- tensors = [tensor]
- for dim, num_parts in shard_info:
- buffer = []
- for t in tensors:
- assert t.size(dim) % num_parts == 0, \
- f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.'
- chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)]
- buffer.extend(chunks)
- tensors = buffer
- assert len(tensors) == redist_meta.tp_world_size
- return tensors
-
-
-def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
- if not redist_meta.used_zero:
- return [tensor] * redist_meta.dp_world_size
- tensors: List[Optional[Tensor]] = [
- torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank)
- ]
- offsets = redist_meta.zero_offsets + [tensor.numel()]
- for i, offset in enumerate(offsets[:-1]):
- end = offsets[i + 1]
- tensors.append(tensor.view(-1)[offset:end])
- if len(tensors) < redist_meta.dp_world_size:
- tensors.extend([
- torch.empty(0, dtype=tensor.dtype, device=tensor.device)
- for _ in range(redist_meta.dp_world_size - len(tensors))
- ])
- assert len(tensors) == redist_meta.dp_world_size
- return tensors
-
-
-def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]:
- tensors = split_tp_param(tensor, redist_meta)
- tensors = [flatten_zero_param(t, redist_meta) for t in tensors]
- return tensors
diff --git a/colossalai/utils/checkpoint_io/io.py b/colossalai/utils/checkpoint_io/io.py
deleted file mode 100644
index f00212cdf859..000000000000
--- a/colossalai/utils/checkpoint_io/io.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import warnings
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
-
-import torch.distributed as dist
-from torch.nn import Module
-from torch.optim import Optimizer
-
-from .backend import get_backend
-from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
- OptimizerCheckpointRedistor)
-from .meta import ParamDistMeta, RedistMeta
-from .utils import build_checkpoints, optimizer_load_state_dict
-
-
-def save(path: str,
- model: Module,
- optimizer: Optional[Optimizer] = None,
- param_to_os: Optional[Dict[str, int]] = None,
- dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
- max_shard_size_gb: float = 0.0,
- overwrite: bool = False,
- backend: str = 'disk',
- **kwargs: Any) -> None:
- io_backend = get_backend(backend)
- if dist.is_initialized():
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- else:
- rank = 0
- world_size = 1
- if world_size == 1:
- # global doesn't need dist_meta
- dist_meta = None
- else:
- assert dist_meta is not None
- max_shard_size = int(max_shard_size_gb * 1024**3)
- model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer,
- param_to_os, dist_meta)
- writer = io_backend.get_writer(path, overwrite, rank, world_size)
- writer.save_others(kwargs)
- for model_checkpoint in model_checkpoints:
- writer.save_model(model_checkpoint)
- for optimizer_checkpoint in optimizer_checkpoints:
- writer.save_optimizer(optimizer_checkpoint)
- writer.save_meta(meta_checkpoint)
-
-
-def merge(path: str,
- output_path: str,
- max_shard_size_gb: float = 0.0,
- overwrite: bool = False,
- backend: str = 'disk') -> bool:
- io_backend = get_backend(backend)
- if dist.is_initialized() and dist.get_rank() != 0:
- return False
- reader = io_backend.get_reader(path)
- if len(reader.meta_list) == 1:
- # already global
- warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
- return False
- dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
- writer = io_backend.get_writer(output_path, overwrite=overwrite)
- writer.save_others(reader.load_others())
- max_shard_size = int(max_shard_size_gb * 1024**3)
- _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(),
- dist_meta_list)
- _convert_shards(
- OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os),
- reader.load_optimizers(), dist_meta_list)
- meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
- if param_to_os is not None:
- meta_checkpoint['param_to_os'] = param_to_os
- meta_checkpoint['paired_os'] = paired_os
- writer.save_meta(meta_checkpoint)
- return True
-
-
-def redist(path: str,
- output_path: str,
- redist_meta: RedistMeta,
- dist_metas: List[Dict[str, ParamDistMeta]],
- max_shard_size_gb: float = 0.0,
- overwrite: bool = False,
- backend: str = 'disk') -> bool:
- io_backend = get_backend(backend)
- if dist.is_initialized() and dist.get_rank() != 0:
- return False
- nprocs = len(dist_metas)
- reader = io_backend.get_reader(path)
- dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
- do_redist: bool = False
- if len(dist_meta_list) == nprocs:
- for a, b in zip(dist_metas, dist_meta_list):
- if a != b:
- do_redist = True
- break
- else:
- do_redist = True
- if not do_redist:
- warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.')
- return False
-
- writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
- writers[0].save_others(reader.load_others())
- max_shard_size = int(max_shard_size_gb * 1024**3)
- _convert_shards(
- ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta),
- reader.load_models(), dist_meta_list)
- _convert_shards(
- OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count,
- param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list)
- for writer, dist_meta in zip(writers, dist_metas):
- meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
- if param_to_os is not None:
- meta_checkpoint['param_to_os'] = param_to_os
- meta_checkpoint['paired_os'] = paired_os
- writer.save_meta(meta_checkpoint)
- return True
-
-
-def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
- dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
- for shard_dict in shard_generator:
- convertor.append(shard_dict, dist_meta_list)
- convertor.complete()
-
-
-def load(path: str,
- model: Module,
- optimizer: Optional[Optimizer] = None,
- redist_meta: Optional[RedistMeta] = None,
- dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
- max_shard_size_gb: float = 0.0,
- backend: str = 'disk') -> dict:
- is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1
- rank: int = dist.get_rank() if dist.is_initialized() else 0
- is_main_process: bool = rank == 0
- # validate args
- if redist_meta is None or dist_metas is None:
- assert is_global
- io_backend = get_backend(backend)
- read_path: str = path
- if is_main_process:
- # pre-process checkpoints
- temp_path = io_backend.get_temp(path)
- if is_global:
- wrote = merge(path, temp_path, max_shard_size_gb, backend=backend)
- else:
- wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
- if wrote:
- read_path = temp_path
- if not is_global:
- bcast_list = [read_path] if is_main_process else [None]
- dist.broadcast_object_list(bcast_list)
- read_path = bcast_list[0]
- reader = io_backend.get_reader(read_path)
- # load model
- for shard in reader.load_model(rank):
- model.load_state_dict(shard, strict=False)
- if optimizer is not None:
- for shard in reader.load_optimizer(rank):
- # optimizer.load_state_dict(shard)
- optimizer_load_state_dict(optimizer, shard)
- others_dict = reader.load_others()
- if not is_global:
- dist.barrier()
- # clean up temp
- if is_main_process:
- io_backend.clean_temp()
- return others_dict
diff --git a/colossalai/utils/checkpoint_io/meta.py b/colossalai/utils/checkpoint_io/meta.py
deleted file mode 100644
index 994f08b4b5e4..000000000000
--- a/colossalai/utils/checkpoint_io/meta.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from dataclasses import dataclass
-from typing import List, Optional, Set, Dict
-
-
-@dataclass
-class ParamDistMeta:
- # parallel info
- dp_rank: int
- dp_world_size: int
- tp_rank: int
- tp_world_size: int
- # tp info
- tp_shard_dims: Optional[List[int]] = None
- tp_num_parts: Optional[List[int]] = None
- # zero info
- zero_numel: Optional[int] = None
- zero_orig_shape: Optional[List[int]] = None
-
- @property
- def used_tp(self) -> bool:
- return self.tp_shard_dims is not None and self.tp_num_parts is not None
-
- @property
- def used_zero(self) -> bool:
- return self.zero_numel is not None and self.zero_orig_shape is not None
-
- @property
- def parallel_meta(self) -> tuple:
- return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size
-
- @property
- def tp_meta(self) -> tuple:
- return self.tp_shard_dims, self.tp_num_parts
-
- @property
- def zero_meta(self) -> tuple:
- return self.zero_numel, self.zero_orig_shape
-
- @staticmethod
- def from_dict(d: dict) -> 'ParamDistMeta':
- return ParamDistMeta(**d)
-
-
-@dataclass
-class ParamRedistMeta:
- # parallel info
- dp_world_size: int
- tp_world_size: int
- # tp info
- tp_shard_dims: Optional[List[int]] = None
- tp_num_parts: Optional[List[int]] = None
- # zero info
- zero_start_dp_rank: Optional[int] = None
- zero_offsets: Optional[List[int]] = None
-
- @property
- def used_tp(self) -> bool:
- return self.tp_shard_dims is not None and self.tp_num_parts is not None
-
- @property
- def used_zero(self) -> bool:
- return self.zero_start_dp_rank is not None and self.zero_offsets is not None
-
-
-@dataclass
-class RankRedistMeta:
- dp_rank: int
- tp_rank: int
- pp_rank: int
-
-
-@dataclass
-class PipelineRedistMeta:
- params: Set[str]
-
-
-@dataclass
-class RedistMeta:
- rank_meta: Dict[str, Dict[int, RankRedistMeta]]
- pipeline_meta: List[PipelineRedistMeta]
- param_meta: Dict[str, ParamRedistMeta]
diff --git a/colossalai/utils/checkpoint_io/reader.py b/colossalai/utils/checkpoint_io/reader.py
deleted file mode 100644
index 3158c6481263..000000000000
--- a/colossalai/utils/checkpoint_io/reader.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import os
-from abc import ABC, abstractmethod
-from collections import Counter
-from typing import Dict, Generator, List, Optional, Tuple
-
-import torch
-
-from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
-from .meta import ParamDistMeta
-from .utils import is_duplicated_list
-
-
-class CheckpointReader(ABC):
-
- def __init__(self, base_name: str) -> None:
- super().__init__()
- self.base_name = base_name
- self.meta_list = []
-
- @abstractmethod
- def read(self, name: str) -> dict:
- pass
-
- @abstractmethod
- def load_meta(
- self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
- pass
-
- @abstractmethod
- def load_model(self, rank: int) -> Generator[dict, None, None]:
- pass
-
- @abstractmethod
- def load_models(self) -> Generator[Dict[int, dict], None, None]:
- pass
-
- @abstractmethod
- def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
- pass
-
- @abstractmethod
- def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
- pass
-
- @abstractmethod
- def load_others(self) -> dict:
- pass
-
-
-class DiskCheckpointReader(CheckpointReader):
-
- def __init__(self, base_name: str) -> None:
- super().__init__(base_name)
- assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
- global_meta = self.read(GLOBAL_META_FILE_NAME)
- for meta_file_name in global_meta['meta']:
- meta = self.read(meta_file_name)
- if meta.get('dist_meta', None) is None:
- # only global checkpoint can have empty dist_meta
- assert len(global_meta['meta']) == 1
- self.meta_list.append(meta)
-
- def read(self, name: str) -> dict:
- return torch.load(os.path.join(self.base_name, name))
-
- def load_meta(
- self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
- meta_infos = [(meta.get('dist_meta', None), meta['params'],
- meta.get('param_to_os', None), meta.get('paired_os', None))
- for meta in self.meta_list]
- dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
- # reduce param_count
- param_count = Counter(p for params in params_list for p in params)
- # validate param_to_os
- assert is_duplicated_list(param_to_os_list)
- assert is_duplicated_list(paired_os_list)
- return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]
-
- def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
- meta = self.meta_list[rank]
- checkpoint_names = meta.get(shard_type, [])
- for name in checkpoint_names:
- yield self.read(name)
-
- def load_model(self, rank: int) -> Generator[dict, None, None]:
- return self._load_shard('model', rank)
-
- def load_models(self) -> Generator[Dict[int, dict], None, None]:
- indices = [0] * len(self.meta_list)
- while True:
- shards = {}
- for i, meta in enumerate(self.meta_list):
- model_checkpoint_names = meta.get('model', [])
- if indices[i] < len(model_checkpoint_names):
- shards[i] = self.read(model_checkpoint_names[indices[i]])
- indices[i] += 1
- if len(shards) > 0:
- yield shards
- else:
- break
-
- def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
- param_groups = None
- for shard in self._load_shard('optimizer', rank):
- if param_groups is None:
- param_groups = shard['param_groups']
- else:
- shard['param_groups'] = param_groups
- yield shard
-
- def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
- indices = [0] * len(self.meta_list)
- param_groups = []
- while True:
- shards = {}
- for i, meta in enumerate(self.meta_list):
- optimizer_checkpoint_names = meta.get('optimizer', [])
- if indices[i] < len(optimizer_checkpoint_names):
- shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
- if indices[i] == 0:
- param_groups.append(shards[i]['param_groups'])
- else:
- shards[i]['param_groups'] = param_groups[i]
- indices[i] += 1
- if len(shards) > 0:
- yield shards
- else:
- break
-
- def load_others(self) -> dict:
- return self.read(OTHER_CKPT_FILE_NAME)
diff --git a/colossalai/utils/checkpoint_io/utils.py b/colossalai/utils/checkpoint_io/utils.py
deleted file mode 100644
index 135385f57379..000000000000
--- a/colossalai/utils/checkpoint_io/utils.py
+++ /dev/null
@@ -1,223 +0,0 @@
-import warnings
-from copy import deepcopy
-from itertools import chain
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-from torch import Tensor
-from torch.nn import Module
-from torch.nn.parameter import Parameter
-from torch.optim import Optimizer
-
-from .meta import ParamDistMeta
-
-
-def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
- if arg is not None:
- return fn(arg)
-
-
-def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
- # ensure all params in optimizer are in model state dict
- params_set = set(id(p) for p in model.parameters())
- for group in optimizer.param_groups:
- for p in group['params']:
- assert id(p) in params_set
- param_mappings = {}
- start_index = 0
-
- def get_group_mapping(group):
- nonlocal start_index
- param_mappings.update(
- {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
- start_index += len(group['params'])
-
- for g in optimizer.param_groups:
- get_group_mapping(g)
- return {k: param_mappings[id(p)] for k, p in model.named_parameters()}
-
-
-def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
- size = 0
- for v in state.values():
- if isinstance(v, Tensor):
- size += v.numel() * v.element_size()
- return size
-
-
-class ModelCheckpointSharder:
-
- def __init__(self, max_shard_size: int) -> None:
- self.max_shard_size = max_shard_size
- self.buffer: Dict[str, Tensor] = {}
- self.buffer_size: int = 0
-
- def append(self, key: str, tensor: Tensor) -> Optional[dict]:
- retval = None
- if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
- retval = self.buffer
- self.buffer = {}
- self.buffer_size = 0
- self.buffer[key] = tensor
- self.buffer_size += tensor.numel() * tensor.element_size()
- return retval
-
- def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
- shards = []
- for key, tensor in state_dict.items():
- shard = self.append(key, tensor)
- run_if_not_none(shards.append, shard)
- return shards
-
- def complete(self) -> Optional[dict]:
- return self.buffer if len(self.buffer) > 0 else None
-
-
-class OptimizerCheckpointSharder:
-
- def __init__(self, max_shard_size: int, param_groups: dict) -> None:
- self.max_shard_size = max_shard_size
- self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
- self.buffer_size: int = 0
- self.returned_first: bool = False
-
- def append(self, key: int, state: dict) -> Optional[dict]:
- retval = None
- if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
- retval = self.buffer
- self.buffer = {'state': {}}
- self.buffer_size = 0
- self.buffer['state'][key] = state
- self.buffer_size += compute_optimizer_state_size(state)
- return retval
-
- def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
- shards = []
- for key, state in state_dict['state'].items():
- shard = self.append(key, state)
- run_if_not_none(shards.append, shard)
- return shards
-
- def complete(self) -> Optional[dict]:
- return self.buffer if len(self.buffer['state']) > 0 else None
-
-
-def shard_checkpoint(max_shard_size: int,
- model_state_dict: Dict[str, Tensor],
- optimizer_state_dict: Optional[dict] = None,
- param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
- has_optimizer: bool = False
- if optimizer_state_dict is not None:
- assert param_to_os is not None
- os_to_param = {v: k for k, v in param_to_os.items()}
- for os_key in optimizer_state_dict['state'].keys():
- assert os_key in os_to_param
- assert os_to_param[os_key] in model_state_dict
- has_optimizer = True
- model_sharder = ModelCheckpointSharder(max_shard_size)
- model_shards = model_sharder.extend(model_state_dict)
- run_if_not_none(model_shards.append, model_sharder.complete())
- if not has_optimizer:
- return model_shards, []
- optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
- optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
- run_if_not_none(optimizer_shards.append, optimizer_sharder.complete())
- return model_shards, optimizer_shards
-
-
-def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict:
- os_to_param = {v: k for k, v in param_to_os.items()}
- paired_os = {}
- for idx, state in optimizer_state_dict['state'].items():
- paired_os[idx] = {}
- p = model_state_dict[os_to_param[idx]]
- for k, v in state.items():
- if isinstance(v, Tensor) and v.shape == p.shape:
- paired_os[idx][k] = True
- else:
- paired_os[idx][k] = False
- return paired_os
-
-
-def build_checkpoints(max_size: int,
- model: Module,
- optimizer: Optional[Optimizer] = None,
- param_to_os: Optional[Dict[str, int]] = None,
- dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
- eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
- save_global = dist_meta is None
- model_state_dict = model.state_dict()
- optimizer_state_dict = optimizer.state_dict() if optimizer else None
- meta = {'dist_meta': dist_meta}
- if optimizer:
- param_to_os = param_to_os or get_param_to_os(model, optimizer)
- paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
- meta['param_to_os'] = param_to_os
- meta['paired_os'] = paired_os
- if not save_global and eliminate_replica:
- # filter dp replicated params
- model_state_dict = {
- k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
- }
- if optimizer:
- optimizer_state_dict['state'] = {
- param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]]
- for k in model_state_dict.keys()
- if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
- }
- meta['params'] = list(model_state_dict.keys())
- if len(model_state_dict) == 0:
- warnings.warn('model state dict is empty, checkpoint is not saved')
- return [], [], meta
- model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
- param_to_os)
- return model_checkpoints, optimizer_checkpoints, meta
-
-
-def is_duplicated_list(list_: List[Any]) -> bool:
- if len(list_) == 0:
- return True
- elem = list_[0]
- for x in list_[1:]:
- if x != elem:
- return False
- return True
-
-
-def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
- for k, v in src_state.items():
- if k in dest_state:
- old_v = dest_state[k]
- if isinstance(old_v, Tensor):
- old_v.copy_(v)
- else:
- dest_state[k] = v
-
-
-def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
- assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
- state_dict = deepcopy(state_dict)
- groups = optimizer.param_groups
- saved_groups = state_dict['param_groups']
- idx_to_p: Dict[int, Parameter] = {
- old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
- )), chain.from_iterable((g['params'] for g in groups)))
- }
- missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
- unexpected_keys = []
- error_msgs = []
- for idx, state in state_dict['state'].items():
- if idx in idx_to_p:
- old_state = optimizer.state[idx_to_p[idx]]
- copy_optimizer_state(state, old_state)
- else:
- unexpected_keys.append(idx)
- if strict:
- if len(unexpected_keys) > 0:
- error_msgs.insert(
- 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
- if len(missing_keys) > 0:
- error_msgs.insert(
- 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
- if len(error_msgs) > 0:
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
- "\n\t".join(error_msgs)))
diff --git a/colossalai/utils/checkpoint_io/writer.py b/colossalai/utils/checkpoint_io/writer.py
deleted file mode 100644
index 4552accde470..000000000000
--- a/colossalai/utils/checkpoint_io/writer.py
+++ /dev/null
@@ -1,98 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Optional
-from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME
-import torch
-import os
-
-
-class CheckpointWriter(ABC):
-
- def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
- super().__init__()
- self.base_name = base_name
- self.overwrite = overwrite
- self.rank = rank
- self.world_size = world_size
- self.is_distributed = world_size > 1
- self.is_main_process = rank == 0
-
- @abstractmethod
- def write(self, name: str, state_dict: dict) -> None:
- pass
-
- @abstractmethod
- def save_model(self, model_checkpoint: dict) -> None:
- pass
-
- @abstractmethod
- def save_optimizer(self, optimizer_checkpoint: dict) -> None:
- pass
-
- @abstractmethod
- def save_meta(self, meta_checkpoint: dict) -> None:
- pass
-
- @abstractmethod
- def save_others(self, kwargs: dict) -> None:
- pass
-
-
-class DiskCheckpointWriter(CheckpointWriter):
-
- def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
- super().__init__(base_name, overwrite, rank, world_size)
- if not os.path.exists(base_name):
- os.makedirs(base_name)
- assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
- self.model_checkpoint_names = []
- self.optimizer_checkpoint_names = []
- self.is_meta_saved: bool = False
- self._save_global_meta()
-
- def write(self, name: str, state_dict: dict) -> None:
- path = os.path.join(self.base_name, name)
- if os.path.exists(path) and not self.overwrite:
- raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
- torch.save(state_dict, path)
-
- def _save_global_meta(self) -> None:
- if self.is_main_process:
- global_meta = {'meta': []}
- if self.is_distributed:
- for i in range(self.world_size):
- global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
- else:
- global_meta['meta'].append(META_CKPT_FILE_NAME)
- self.write(GLOBAL_META_FILE_NAME, global_meta)
-
- def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
- checkpoint_name = base_name
- if self.is_distributed:
- checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
- if shard_idx is not None:
- checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
- return checkpoint_name
-
- def save_model(self, model_checkpoint: dict) -> None:
- assert not self.is_meta_saved, 'Cannot save model after saving meta'
- name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
- self.write(name, model_checkpoint)
- self.model_checkpoint_names.append(name)
-
- def save_optimizer(self, optimizer_checkpoint: dict) -> None:
- assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
- name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
- self.write(name, optimizer_checkpoint)
- self.optimizer_checkpoint_names.append(name)
-
- def save_meta(self, meta_checkpoint: dict) -> None:
- if len(self.model_checkpoint_names) > 0:
- meta_checkpoint['model'] = self.model_checkpoint_names
- if len(self.optimizer_checkpoint_names) > 0:
- meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
- self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
- self.is_meta_saved = True
-
- def save_others(self, kwargs: dict) -> None:
- if self.is_main_process:
- self.write(OTHER_CKPT_FILE_NAME, kwargs)
diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py
deleted file mode 100644
index 622d9deb601d..000000000000
--- a/tests/test_lazy/test_distribute.py
+++ /dev/null
@@ -1,102 +0,0 @@
-from typing import Optional
-
-import pytest
-import torch
-import torch.nn as nn
-
-import colossalai
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.common import print_rank_0
-
-try:
- from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-except ImportError:
- pass
-from lazy_init_utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed
-
-from tests.kit.model_zoo import model_zoo
-
-
-def find_shard_dim(shape: torch.Size) -> Optional[int]:
- for dim, size in enumerate(shape):
- if size % 2 == 0:
- return dim
-
-
-def make_sharding_spec(original_tensor: torch.Tensor) -> ShardingSpec:
- shard_dim = find_shard_dim(original_tensor.shape)
- dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
- target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
- return target_sharding_spec
-
-
-def _get_current_name(prefix: str, name: str) -> str:
- return f'{prefix}.{name}'.lstrip('.')
-
-
-def generate_sharding_spec_dict(model: nn.Module) -> dict:
- sharding_spec_dict = {}
-
- @torch.no_grad()
- def generate_recursively(module: nn.Module, prefix: str = ''):
- # recursively initialize the module
- for name, mod in module.named_children():
- generate_recursively(mod, prefix=_get_current_name(prefix, name))
-
- # initialize tensors directly attached to the current module
- for name, param in module.named_parameters(recurse=False):
- if isinstance(param, LazyTensor):
- sharding_spec = make_sharding_spec(param)
- sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
-
- for name, buf in module.named_buffers(recurse=False):
- if isinstance(buf, LazyTensor):
- sharding_spec = make_sharding_spec(buf)
- sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
-
- generate_recursively(model)
-
- return sharding_spec_dict
-
-
-@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
-def run_dist_lazy_init(subset, seed: int = 42):
- sub_model_zoo = model_zoo.get_sub_registry(subset)
- device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
- _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
- LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
-
- for name, entry in sub_model_zoo.items():
- # TODO(ver217): lazy init does not support weight norm, skip these models
- if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
- continue
- print_rank_0(name)
- model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
- ctx = LazyInitContext(tensor_cls=_MyTensor)
- with ctx:
- model = model_fn()
- ctx = LazyInitContext()
- with ctx:
- deferred_model = model_fn()
- sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
- ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
- assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
-
-
-def run_dist(rank, world_size, port) -> None:
- colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port)
- run_dist_lazy_init()
-
-
-@pytest.mark.skipif(not SUPPORT_LAZY, reason='torch version should be >= 1.12.0')
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_dist_lazy_init():
- spawn(run_dist, 4)
-
-
-if __name__ == '__main__':
- test_dist_lazy_init()
diff --git a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py
deleted file mode 100644
index 6d89fb90c574..000000000000
--- a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import torch
-import torch.nn as nn
-from colossalai.utils.checkpoint_io.meta import ParamDistMeta
-from colossalai.utils.checkpoint_io.utils import build_checkpoints
-from torch.optim import Adam
-
-
-class DummyModel(nn.Module):
-
- def __init__(self) -> None:
- super().__init__()
- self.fc = nn.Linear(20, 1)
-
-
-def test_global_model():
- model = DummyModel()
- model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model)
- assert len(model_checkpoints) == 1
- assert len(optimizer_checkpoints) == 0
- assert meta['dist_meta'] is None
- orig_state_dict = model.state_dict()
- global_state_dict = model_checkpoints[0]
- assert set(orig_state_dict.keys()) == set(global_state_dict.keys())
- for k, v in orig_state_dict.items():
- assert torch.equal(v, global_state_dict[k])
-
-
-def test_global_model_shard():
- model = DummyModel()
- model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model)
- assert len(model_checkpoints) == 2
- assert len(optimizer_checkpoints) == 0
- assert meta['dist_meta'] is None
- orig_state_dict = model.state_dict()
- assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys())
- assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0
- for k, v in orig_state_dict.items():
- for state_dict in model_checkpoints:
- if k in state_dict:
- assert torch.equal(v, state_dict[k])
-
-
-def test_global_optimizer():
- model = DummyModel()
- for p in model.parameters():
- p.grad = torch.rand_like(p)
- optimizer = Adam(model.parameters(), lr=1e-3)
- optimizer.step()
- model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
- assert len(optimizer_checkpoints) == 1
- assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
- for state in meta['paired_os'].values():
- for k, is_paired in state.items():
- if k == 'step':
- assert not is_paired
- else:
- assert is_paired
- orig_state_dict = optimizer.state_dict()
- state_dict = optimizer_checkpoints[0]
- for k, orig_state in orig_state_dict['state'].items():
- state = state_dict['state'][k]
- for v1, v2 in zip(orig_state.values(), state.values()):
- if isinstance(v2, torch.Tensor):
- assert torch.equal(v1, v2)
- else:
- assert v1 == v2
- assert orig_state_dict['param_groups'] == state_dict['param_groups']
-
-
-def test_global_optimizer_shard():
- model = DummyModel()
- for p in model.parameters():
- p.grad = torch.rand_like(p)
- optimizer = Adam(model.parameters(), lr=1e-3)
- optimizer.step()
- model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer)
- assert len(optimizer_checkpoints) == 2
- assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1]
- orig_state_dict = optimizer.state_dict()
- assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set(
- optimizer_checkpoints[1]['state'].keys())
- assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0
- for k, orig_state in orig_state_dict['state'].items():
- state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0]['state'] \
- else optimizer_checkpoints[1]['state'][k]
- for v1, v2 in zip(orig_state.values(), state.values()):
- if isinstance(v2, torch.Tensor):
- assert torch.equal(v1, v2)
- else:
- assert v1 == v2
-
- assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups']
-
-
-def test_dist_model_optimizer():
- model = DummyModel()
- for p in model.parameters():
- p.grad = torch.rand_like(p)
- optimizer = Adam(model.parameters(), lr=1e-3)
- optimizer.step()
- dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
- model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
- assert dist_meta == meta['dist_meta']
- assert len(model_checkpoints) == 1
- assert len(optimizer_checkpoints) == 1
- assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0]
- assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state']
- dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
- model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
- assert dist_meta == meta['dist_meta']
- assert len(model_checkpoints) == 1
- assert len(optimizer_checkpoints) == 1
-
-
-if __name__ == '__main__':
- test_global_model()
- test_global_model_shard()
- test_global_optimizer()
- test_global_optimizer_shard()
- test_dist_model_optimizer()
diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py
deleted file mode 100644
index 2949c9f0752d..000000000000
--- a/tests/test_utils/test_checkpoint_io/test_load.py
+++ /dev/null
@@ -1,186 +0,0 @@
-from copy import deepcopy
-from functools import partial
-from tempfile import TemporaryDirectory
-from typing import Dict
-
-import pytest
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch import Tensor
-from torch.nn import Module
-from torch.optim import Adam, Optimizer
-
-import colossalai
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.checkpoint_io.io import load, save
-from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta
-
-
-def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
- assert set(a.keys()) == set(b.keys())
- for k, v in a.items():
- assert torch.equal(v, b[k])
-
-
-def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
- assert set(a['state'].keys()) == set(b['state'].keys())
- for k, state in a['state'].items():
- b_state = b['state'][k]
- for v1, v2 in zip(state.values(), b_state.values()):
- if isinstance(v1, Tensor):
- assert torch.equal(v1, v2)
- else:
- assert v1 == v2
- if not ignore_param_groups:
- assert a['param_groups'] == b['param_groups']
-
-
-class DummyModel(nn.Module):
-
- def __init__(self) -> None:
- super().__init__()
- self.fc = nn.Linear(20, 1)
-
-
-def prepare_model_optim(shard: bool = False, zero: bool = False):
- model = DummyModel()
- if shard:
- model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
- if zero:
- dp_rank = dist.get_rank() // 2
- model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
- if dp_rank != 0:
- model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
- for p in model.parameters():
- p.grad = torch.rand_like(p)
- optimizer = Adam(model.parameters(), lr=1e-3)
- optimizer.step()
- return model, optimizer
-
-
-def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0):
- with torch.no_grad():
- for p in model.parameters():
- p.fill_(scalar)
- for state in optimizer.state.values():
- for v in state.values():
- if isinstance(v, Tensor):
- v.fill_(scalar)
-
-
-def get_dist_metas(nprocs: int, zero: bool = False):
- dp_world_size = nprocs // 2
- dist_metas = []
- for rank in range(nprocs):
- if zero:
- dist_metas.append({
- 'fc.weight':
- ParamDistMeta(rank // 2,
- dp_world_size,
- rank % 2,
- 2,
- tp_shard_dims=[1],
- tp_num_parts=[2],
- zero_numel=10,
- zero_orig_shape=[1, 10]),
- 'fc.bias':
- ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
- })
- else:
- dist_metas.append({
- 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
- 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
- })
- return dist_metas
-
-
-def get_redist_meta(nprocs: int):
- dp_world_size = nprocs // 2
- rank_meta = {
- 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
- 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
- }
- param_meta = {
- 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
- 'fc.bias': ParamRedistMeta(dp_world_size, 1)
- }
- return RedistMeta(rank_meta, [], param_meta)
-
-
-@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0])
-def test_save_global_load_global(max_shard_size_gb: float):
- model, optimizer = prepare_model_optim()
- with TemporaryDirectory() as dir_name:
- save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb)
- new_model, new_optimizer = prepare_model_optim()
- load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb)
- check_model_state_dict(model.state_dict(), new_model.state_dict())
- check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict())
-
-
-def run_dist(rank, world_size, port, test_fn):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- test_fn()
-
-
-def launch_dist(fn, world_size: int):
- spawn(run_dist, world_size, test_fn=fn)
-
-
-def save_dist(dir_name: str, zero: bool):
- model, optimizer = prepare_model_optim(shard=True, zero=zero)
- reset_model_optim(model, optimizer)
- world_size = dist.get_world_size()
- rank = dist.get_rank()
- save(dir_name, model, optimizer, dist_meta=get_dist_metas(world_size, zero)[rank])
-
-
-def load_and_check_dist(dir_name: str):
- world_size = dist.get_world_size()
- model, optimizer = prepare_model_optim(shard=True)
- reset_model_optim(model, optimizer)
- model_state_dict = deepcopy(model.state_dict())
- optimizer_state_dict = deepcopy(optimizer.state_dict())
- reset_model_optim(model, optimizer, 1)
- load(dir_name, model, optimizer, get_redist_meta(world_size), get_dist_metas(world_size))
- check_model_state_dict(model_state_dict, model.state_dict())
- check_optim_state_dict(optimizer_state_dict, optimizer.state_dict())
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_save_global_load_dist():
- model, optimizer = prepare_model_optim()
- reset_model_optim(model, optimizer)
- with TemporaryDirectory() as dir_name:
- save(dir_name, model, optimizer)
- fn = partial(load_and_check_dist, dir_name)
- launch_dist(fn, 4)
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_save_dist_load_dist():
- with TemporaryDirectory() as dir_name:
- # save tp + dp
- fn = partial(save_dist, dir_name, False)
- launch_dist(fn, 2)
- # load tp + dp
- fn = partial(load_and_check_dist, dir_name)
- launch_dist(fn, 2)
- with TemporaryDirectory() as dir_name:
- # save tp + zero
- fn = partial(save_dist, dir_name, True)
- launch_dist(fn, 4)
- # load tp + dp
- fn = partial(load_and_check_dist, dir_name)
- launch_dist(fn, 2)
- launch_dist(fn, 4)
-
-
-if __name__ == '__main__':
- test_save_global_load_global(80 / 1024**3)
- test_save_global_load_global(0)
- test_save_global_load_dist()
- test_save_dist_load_dist()
diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py
deleted file mode 100644
index 07d4597f8391..000000000000
--- a/tests/test_utils/test_checkpoint_io/test_merge.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import os
-from functools import partial
-from tempfile import TemporaryDirectory
-
-import pytest
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch.optim import Adam
-
-import colossalai
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
-from colossalai.utils.checkpoint_io.io import merge, save
-from colossalai.utils.checkpoint_io.meta import ParamDistMeta
-
-
-class DummyModel(nn.Module):
-
- def __init__(self) -> None:
- super().__init__()
- self.fc = nn.Linear(20, 1)
-
-
-def prepare_model_optim(shard: bool = False, zero: bool = False):
- model = DummyModel()
- if shard:
- model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
- if zero:
- dp_rank = dist.get_rank() // 2
- model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
- if dp_rank != 0:
- model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
- for p in model.parameters():
- p.grad = torch.ones_like(p)
- optimizer = Adam(model.parameters(), lr=1e-3)
- optimizer.step()
- return model, optimizer
-
-
-def test_merge_global():
- model, optimizer = prepare_model_optim()
- with TemporaryDirectory() as dir_name:
- save(dir_name, model, optimizer)
- with TemporaryDirectory() as output_dir:
- merge(dir_name, output_dir)
- assert len(os.listdir(output_dir)) == 0
- with TemporaryDirectory() as dir_name:
- save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
- with TemporaryDirectory() as output_dir:
- merge(dir_name, output_dir)
- assert len(os.listdir(output_dir)) == 0
-
-
-def run_dist(rank, world_size, port, test_fn):
- colossalai.launch(config={'parallel': {
- 'tensor': {
- 'mode': '1d',
- 'size': 2
- }
- }},
- rank=rank,
- world_size=world_size,
- host='localhost',
- port=port,
- backend='nccl')
- test_fn()
-
-
-def run_save_dist(dir_name: str, zero: bool):
- model, optimizer = prepare_model_optim(shard=True, zero=zero)
- rank = dist.get_rank()
- dp_world_size = dist.get_world_size() // 2
- if not zero:
- dist_metas = {
- 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
- 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
- }
- else:
- dist_metas = {
- 'fc.weight':
- ParamDistMeta(rank // 2,
- dp_world_size,
- rank % 2,
- 2,
- tp_shard_dims=[1],
- tp_num_parts=[2],
- zero_numel=10,
- zero_orig_shape=[1, 10]),
- 'fc.bias':
- ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
- }
- save(dir_name, model, optimizer, dist_meta=dist_metas)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("zero", [False, True])
-@rerun_if_address_is_in_use()
-def test_merge_tp_dp(zero: bool):
- with TemporaryDirectory() as dir_name:
- fn = partial(run_save_dist, dir_name, zero)
- world_size = 4
- spawn(run_dist, world_size, test_fn=fn)
- with TemporaryDirectory() as output_dir:
- merge(dir_name, output_dir)
- assert len(os.listdir(output_dir)) == 5
- global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME))
- assert len(global_meta['meta']) == 1
- meta = torch.load(os.path.join(output_dir, global_meta['meta'][0]))
- assert meta['dist_meta'] is None
- assert len(meta['params']) == 2
- assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
- model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0]))
- assert len(model_state_dict) == 2
- assert model_state_dict['fc.weight'].size(1) == 20
- optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0]))
- assert len(optimizer_state_dict['state']) == 2
- assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
- assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20
- assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20
-
-
-if __name__ == '__main__':
- test_merge_global()
- test_merge_tp_dp(False)
- test_merge_tp_dp(True)
diff --git a/tests/test_utils/test_checkpoint_io/test_merge_param.py b/tests/test_utils/test_checkpoint_io/test_merge_param.py
deleted file mode 100644
index 5da2ae4fe1f8..000000000000
--- a/tests/test_utils/test_checkpoint_io/test_merge_param.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import torch
-from colossalai.utils.checkpoint_io.meta import ParamDistMeta
-from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param
-
-
-def test_unflatten_zero_param_even() -> None:
- dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)]
- orig_tensor = torch.rand(4, 4)
- tensors = list(orig_tensor.reshape(-1).chunk(4))
- unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, unflattened_tensor)
- merged_tensor = merge_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, merged_tensor)
-
-
-def test_unflatten_zero_param_uneven() -> None:
- dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)]
- orig_tensor = torch.rand(4, 4)
- tensors = list(orig_tensor.reshape(-1).split([13, 3]))
- unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, unflattened_tensor)
- merged_tensor = merge_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, merged_tensor)
-
-
-def test_gather_tp_param_1d_row() -> None:
- dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)]
- orig_tensor = torch.rand(4, 4)
- tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
- gathered_tensor = gather_tp_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, gathered_tensor)
- merged_tensor = merge_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, merged_tensor)
-
-
-def test_gather_tp_param_1d_col() -> None:
- dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)]
- orig_tensor = torch.rand(4, 4)
- tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
- gathered_tensor = gather_tp_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, gathered_tensor)
- merged_tensor = merge_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, merged_tensor)
-
-
-def test_gather_tp_param_2d() -> None:
- dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)]
- orig_tensor = torch.rand(4, 6)
- tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
- gathered_tensor = gather_tp_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, gathered_tensor)
- merged_tensor = merge_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, merged_tensor)
-
-
-def test_gather_tp_param_2d_reverse() -> None:
- dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)]
- orig_tensor = torch.rand(4, 6)
- tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
- gathered_tensor = gather_tp_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, gathered_tensor)
- merged_tensor = merge_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, merged_tensor)
-
-
-def test_merge_param_hybrid() -> None:
- dist_metas = [
- ParamDistMeta(i % 2,
- 2,
- i // 2,
- 6,
- tp_shard_dims=[1, 0],
- tp_num_parts=[3, 2],
- zero_numel=4,
- zero_orig_shape=[2, 2]) for i in range(12)
- ]
- orig_tensor = torch.rand(4, 6)
- tensors = [
- chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
- for chunk in t.contiguous().reshape(-1).split([1, 3])
- ]
- merged_tensor = merge_param(tensors, dist_metas)
- assert torch.equal(orig_tensor, merged_tensor)
-
-
-def test_merge_param_dummy() -> None:
- dist_metas = [ParamDistMeta(0, 1, 0, 1)]
- orig_tensor = torch.rand(4, 6)
- merged_tensor = merge_param([orig_tensor], dist_metas)
- assert torch.equal(orig_tensor, merged_tensor)
-
-
-if __name__ == '__main__':
- test_unflatten_zero_param_even()
- test_unflatten_zero_param_uneven()
- test_gather_tp_param_1d_row()
- test_gather_tp_param_1d_col()
- test_gather_tp_param_2d()
- test_gather_tp_param_2d_reverse()
- test_merge_param_hybrid()
- test_merge_param_dummy()
diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py
deleted file mode 100644
index fdc849a5ecc0..000000000000
--- a/tests/test_utils/test_checkpoint_io/test_redist.py
+++ /dev/null
@@ -1,152 +0,0 @@
-import os
-from functools import partial
-from tempfile import TemporaryDirectory
-
-import pytest
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch.optim import Adam
-
-import colossalai
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
-from colossalai.utils.checkpoint_io.io import redist, save
-from colossalai.utils.checkpoint_io.meta import (
- ParamDistMeta,
- ParamRedistMeta,
- PipelineRedistMeta,
- RankRedistMeta,
- RedistMeta,
-)
-
-
-class DummyModel(nn.Module):
-
- def __init__(self) -> None:
- super().__init__()
- self.fc = nn.Linear(20, 1)
-
-
-def prepare_model_optim(shard: bool = False, zero: bool = False):
- model = DummyModel()
- if shard:
- model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
- if zero:
- dp_rank = dist.get_rank() // 2
- model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
- if dp_rank != 0:
- model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
- for p in model.parameters():
- p.grad = torch.ones_like(p)
- optimizer = Adam(model.parameters(), lr=1e-3)
- optimizer.step()
- return model, optimizer
-
-
-def get_dist_metas(nprocs: int, zero: bool = False):
- dp_world_size = nprocs // 2
- dist_metas = []
- for rank in range(nprocs):
- if zero:
- dist_metas.append({
- 'fc.weight':
- ParamDistMeta(rank // 2,
- dp_world_size,
- rank % 2,
- 2,
- tp_shard_dims=[1],
- tp_num_parts=[2],
- zero_numel=10,
- zero_orig_shape=[1, 10]),
- 'fc.bias':
- ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
- })
- else:
- dist_metas.append({
- 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
- 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
- })
- return dist_metas
-
-
-def get_redist_meta(nprocs: int):
- dp_world_size = nprocs // 2
- rank_meta = {
- 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
- 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
- }
- param_meta = {
- 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
- 'fc.bias': ParamRedistMeta(dp_world_size, 1)
- }
- return RedistMeta(rank_meta, [], param_meta)
-
-
-def check_checkpoint_shape(dir_name: str):
- global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
- for meta_name in global_meta['meta']:
- meta = torch.load(os.path.join(dir_name, meta_name))
- assert meta['dist_meta'] is not None
- assert len(meta['params']) == 2
- assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
- model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
- assert len(model_state_dict) == 2
- assert model_state_dict['fc.weight'].size(1) == 10
- optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
- assert len(optimizer_state_dict['state']) == 2
- assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
- assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10
- assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10
-
-
-def test_global_to_dist():
- model, optimizer = prepare_model_optim()
- with TemporaryDirectory() as dir_name:
- save(dir_name, model, optimizer)
- with TemporaryDirectory() as output_dir:
- redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
- check_checkpoint_shape(output_dir)
-
-
-def run_dist(rank, world_size, port, test_fn):
- colossalai.launch(config={'parallel': {
- 'tensor': {
- 'mode': '1d',
- 'size': 2
- }
- }},
- rank=rank,
- world_size=world_size,
- host='localhost',
- port=port,
- backend='nccl')
- test_fn()
-
-
-def run_save_dist(dir_name: str, zero: bool):
- model, optimizer = prepare_model_optim(shard=True, zero=zero)
- rank = dist.get_rank()
- save(dir_name, model, optimizer, dist_meta=get_dist_metas(4, zero)[rank])
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("zero", [False, True])
-@rerun_if_address_is_in_use()
-def test_dist_to_dist(zero: bool):
- with TemporaryDirectory() as dir_name:
- fn = partial(run_save_dist, dir_name, zero)
- world_size = 4
- spawn(run_dist, world_size, test_fn=fn)
- with TemporaryDirectory() as output_dir:
- redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
- if not zero:
- assert len(os.listdir(output_dir)) == 0
- else:
- check_checkpoint_shape(output_dir)
-
-
-if __name__ == '__main__':
- test_global_to_dist()
- test_dist_to_dist(False)
- test_dist_to_dist(True)
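The deleted test_redist.py checks that redist() rewrites a globally saved checkpoint into the per-rank layout described by RedistMeta: the (1, 20) fc.weight is chunked in half along dim 1 for 2-way tensor parallelism, and each half is flattened and cut at offset 3 for 2-way ZeRO. A plain-PyTorch sketch of that target layout (an illustration of the expected shapes, not the removed redist implementation):

# Conceptual sketch: re-shard a global (1, 20) weight for the 2-way TP x 2-way
# ZeRO layout used by the deleted test (tp_rank = rank % 2, dp_rank = rank // 2).
import torch

full_weight = torch.rand(1, 20)
world_size = 4

local_shards = []
for rank in range(world_size):
    tp_rank, dp_rank = rank % 2, rank // 2
    tp_shard = full_weight.chunk(2, dim=1)[tp_rank]            # (1, 10)
    zero_shard = tp_shard.reshape(-1).split([3, 7])[dp_rank]   # flat ZeRO slice
    local_shards.append(zero_shard)

# every element of the original weight ends up in exactly one local shard
assert sum(s.numel() for s in local_shards) == full_weight.numel()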
diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py
deleted file mode 100644
index 2abdd95a6481..000000000000
--- a/tests/test_utils/test_checkpoint_io/test_save.py
+++ /dev/null
@@ -1,149 +0,0 @@
-import os
-from functools import partial
-from tempfile import TemporaryDirectory
-from typing import Dict
-
-import pytest
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch import Tensor
-from torch.optim import Adam
-
-import colossalai
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.checkpoint_io.constant import (
- GLOBAL_META_FILE_NAME,
- META_CKPT_FILE_NAME,
- MODEL_CKPT_FILE_NAME,
- OTHER_CKPT_FILE_NAME,
-)
-from colossalai.utils.checkpoint_io.io import save
-from colossalai.utils.checkpoint_io.meta import ParamDistMeta
-
-
-def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
- assert set(a.keys()) == set(b.keys())
- for k, v in a.items():
- assert torch.equal(v, b[k])
-
-
-def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
- assert set(a['state'].keys()) == set(b['state'].keys())
- for k, state in a['state'].items():
- b_state = b['state'][k]
- for v1, v2 in zip(state.values(), b_state.values()):
- if isinstance(v1, Tensor):
- assert torch.equal(v1, v2)
- else:
- assert v1 == v2
- if not ignore_param_groups:
- assert a['param_groups'] == b['param_groups']
-
-
-class DummyModel(nn.Module):
-
- def __init__(self) -> None:
- super().__init__()
- self.fc = nn.Linear(20, 1)
-
-
-def prepare_model_optim():
- model = DummyModel()
- for p in model.parameters():
- p.grad = torch.ones_like(p)
- optimizer = Adam(model.parameters(), lr=1e-3)
- optimizer.step()
- return model, optimizer
-
-
-def test_overwrite():
- model = DummyModel()
- with TemporaryDirectory() as dir_name:
- with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f:
- pass
- with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'):
- save(dir_name, model)
-
-
-def test_save_global():
- model, optimizer = prepare_model_optim()
- with TemporaryDirectory() as dir_name:
- save(dir_name, model, optimizer)
- assert len(os.listdir(dir_name)) == 5
- global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
- assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME
- meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
- assert len(meta['model']) == 1
- assert len(meta['optimizer']) == 1
- model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
- check_model_state_dict(model.state_dict(), model_state_dict)
- optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
- check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict)
- other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME))
- assert len(other_state_dict) == 0
-
-
-def test_save_global_shard():
- model, optimizer = prepare_model_optim()
- with TemporaryDirectory() as dir_name:
- save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
- assert len(os.listdir(dir_name)) == 7
- meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
- assert len(meta['model']) == 2 and len(meta['optimizer']) == 2
- model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']]
- assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0
- check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]})
- optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']]
- assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0
- assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1]
- check_optim_state_dict(
- optimizer.state_dict(), {
- 'state': {
- **optimizer_state_dicts[0]['state'],
- **optimizer_state_dicts[1]['state']
- },
- 'param_groups': optimizer_state_dicts[0]['param_groups']
- })
-
-
-def run_dist(rank, world_size, port, test_fn):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- test_fn()
-
-
-def run_save_dist(dir_name):
- model, optimizer = prepare_model_optim()
- dist_metas = {
- 'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1),
- 'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1)
- }
- save(dir_name, model, optimizer, dist_meta=dist_metas)
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_save_dist():
- with TemporaryDirectory() as dir_name:
- fn = partial(run_save_dist, dir_name)
- world_size = 2
- spawn(run_dist, world_size, test_fn=fn)
- assert len(os.listdir(dir_name)) == 8
- global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
- assert len(global_meta['meta']) == 2
- for rank, meta_name in enumerate(global_meta['meta']):
- meta = torch.load(os.path.join(dir_name, meta_name))
- assert meta.get('dist_meta', None) is not None
- assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
- model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
- assert len(model_state_dict) == 2
- optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
- assert len(optimizer_state_dict['state']) == 2
- assert 'param_groups' in optimizer_state_dict
-
-
-if __name__ == '__main__':
- test_overwrite()
- test_save_global()
- test_save_global_shard()
- test_save_dist()
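test_save_global_shard above expects two model shards when max_shard_size_gb corresponds to 80 bytes, since the fp32 (1, 20) weight alone is 80 bytes. One way to obtain that behaviour is greedy packing of state-dict entries under a byte budget; the sketch below is an assumption made for illustration, not the removed save() implementation:

# Conceptual sketch (assumed greedy packing, not the removed colossalai code):
# pack state-dict tensors into shards so that each shard stays under a byte budget.
import torch

def shard_state_dict(state_dict, max_shard_bytes):
    shards, current, used = [], {}, 0
    for name, tensor in state_dict.items():
        size = tensor.numel() * tensor.element_size()
        if current and used + size > max_shard_bytes:
            shards.append(current)
            current, used = {}, 0
        current[name] = tensor
        used += size
    if current:
        shards.append(current)
    return shards

model = torch.nn.Linear(20, 1)
shards = shard_state_dict(model.state_dict(), max_shard_bytes=80)
assert len(shards) == 2   # the 80-byte fp32 weight and the bias land in separate shards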
diff --git a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py
deleted file mode 100644
index 8b83caa12359..000000000000
--- a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import torch
-from colossalai.utils.checkpoint_io.meta import ParamRedistMeta
-from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param
-
-
-def test_flatten_zero_param_even() -> None:
- redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12])
- orig_tensor = torch.rand(4, 4)
- tensors = list(orig_tensor.reshape(-1).chunk(4))
- flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
- assert len(tensors) == len(flat_tensors)
- for t, st in zip(tensors, flat_tensors):
- assert torch.equal(t, st)
- unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
- assert len(unmerged_tensors) == 1
- unmerged_tensors = unmerged_tensors[0]
- assert len(tensors) == len(unmerged_tensors)
- for t, tl in zip(tensors, unmerged_tensors):
- assert torch.equal(t, tl)
-
-
-def test_flatten_zero_param_uneven() -> None:
- redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13])
- orig_tensor = torch.rand(4, 4)
- tensors = list(orig_tensor.reshape(-1).split([13, 3]))
- flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
- assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0
- flat_tensors = flat_tensors[1:-1]
- assert len(tensors) == len(flat_tensors)
- for t, st in zip(tensors, flat_tensors):
- assert torch.equal(t, st)
- unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
- assert len(unmerged_tensors) == 1
- unmerged_tensors = unmerged_tensors[0]
- assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0
- unmerged_tensors = unmerged_tensors[1:-1]
- assert len(tensors) == len(unmerged_tensors)
- for t, tl in zip(tensors, unmerged_tensors):
- assert torch.equal(t, tl)
-
-
-def test_split_tp_param_1d_row() -> None:
- redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4])
- orig_tensor = torch.rand(4, 4)
- tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
- split_tensors = split_tp_param(orig_tensor, redist_meta)
- assert len(tensors) == len(split_tensors)
- for t, st in zip(tensors, split_tensors):
- assert torch.equal(t, st)
- unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
- assert len(tensors) == len(unmerged_tensors)
- for t, tl in zip(tensors, unmerged_tensors):
- assert len(tl) == 1
- assert torch.equal(t, tl[0])
-
-
-def test_split_tp_param_1d_col() -> None:
- redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4])
- orig_tensor = torch.rand(4, 4)
- tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
- split_tensors = split_tp_param(orig_tensor, redist_meta)
- assert len(tensors) == len(split_tensors)
- for t, st in zip(tensors, split_tensors):
- assert torch.equal(t, st)
- unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
- assert len(tensors) == len(unmerged_tensors)
- for t, tl in zip(tensors, unmerged_tensors):
- assert len(tl) == 1
- assert torch.equal(t, tl[0])
-
-
-def test_split_tp_param_2d() -> None:
- redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3])
- orig_tensor = torch.rand(4, 6)
- tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
- split_tensors = split_tp_param(orig_tensor, redist_meta)
- assert len(tensors) == len(split_tensors)
- for t, st in zip(tensors, split_tensors):
- assert torch.equal(t, st)
- unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
- assert len(tensors) == len(unmerged_tensors)
- for t, tl in zip(tensors, unmerged_tensors):
- assert len(tl) == 1
- assert torch.equal(t, tl[0])
-
-
-def test_split_tp_param_2d_reverse() -> None:
- redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2])
- orig_tensor = torch.rand(4, 6)
- tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
- split_tensors = split_tp_param(orig_tensor, redist_meta)
- assert len(tensors) == len(split_tensors)
- for t, st in zip(tensors, split_tensors):
- assert torch.equal(t, st)
- unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
- assert len(tensors) == len(unmerged_tensors)
- for t, tl in zip(tensors, unmerged_tensors):
- assert len(tl) == 1
- assert torch.equal(t, tl[0])
-
-
-def test_unmerge_param_hybrid() -> None:
- redist_meta = ParamRedistMeta(2,
- 6,
- tp_shard_dims=[1, 0],
- tp_num_parts=[3, 2],
- zero_start_dp_rank=0,
- zero_offsets=[0, 1])
- orig_tensor = torch.rand(4, 6)
- tensors = [
- chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
- for chunk in t.contiguous().reshape(-1).split([1, 3])
- ]
- unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
- assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2
- for tp_rank in range(6):
- for dp_rank in range(2):
- assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank])
-
-
-def test_unmerge_param_dummy() -> None:
- redist_meta = ParamRedistMeta(1, 1)
- orig_tensor = torch.rand(4, 6)
- unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
- assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1
- assert torch.equal(orig_tensor, unmerged_tensors[0][0])
-
-
-if __name__ == '__main__':
- test_flatten_zero_param_even()
- test_flatten_zero_param_uneven()
- test_split_tp_param_1d_row()
- test_split_tp_param_1d_col()
- test_split_tp_param_2d()
- test_split_tp_param_2d_reverse()
- test_unmerge_param_hybrid()
- test_unmerge_param_dummy()
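test_unmerge_param_hybrid above checks that unmerge_param first splits a parameter into a tensor-parallel grid and then cuts each TP shard into flat per-dp-rank ZeRO slices. A plain-PyTorch sketch of that layout for the same (4, 6) tensor (a conceptual stand-in for the removed helpers):

# Conceptual sketch: 2 x 3 tensor-parallel split followed by a [1, 3] ZeRO cut
# of each flattened TP shard, mirroring the layout asserted by the deleted test.
import torch

orig = torch.rand(4, 6)

tp_shards = [c.contiguous() for row in orig.chunk(2, dim=0) for c in row.chunk(3, dim=1)]
unmerged = [list(shard.reshape(-1).split([1, 3])) for shard in tp_shards]

assert len(unmerged) == 6 and all(len(dp_slices) == 2 for dp_slices in unmerged)
assert sum(t.numel() for dp_slices in unmerged for t in dp_slices) == orig.numel()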
diff --git a/tests/test_zero/test_legacy/common.py b/tests/test_zero/test_legacy/common.py
deleted file mode 100644
index 2c3d122c79af..000000000000
--- a/tests/test_zero/test_legacy/common.py
+++ /dev/null
@@ -1,140 +0,0 @@
-from functools import partial
-
-import torch
-import torch.distributed as dist
-
-from colossalai.logging import get_dist_logger
-from colossalai.utils import checkpoint
-from colossalai.zero.legacy.shard_utils import TensorShardStrategy
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
-
-LOGGER = get_dist_logger('zero_test')
-
-MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None)))
-
-_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
- fp32_reduce_scatter=False,
- tensor_placement_policy='cuda',
- gradient_predivide_factor=1.0,
- shard_strategy=TensorShardStrategy(),
- reuse_fp16_shard=False)
-
-_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5,
- min_scale=1,
- growth_factor=2,
- backoff_factor=0.5,
- growth_interval=1000,
- hysteresis=2,
- max_scale=2**32)
-
-ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
- zero=dict(
- model_config=_ZERO_MODEL_CONFIG,
- optimizer_config=_ZERO_OPTIMIZER_CONFIG,
- ),
- parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
-
-CONFIG = dict(fp16=dict(mode=None,),
- zero=dict(level=3,
- verbose=False,
- offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False),
- offload_param_config=dict(device='cpu',
- pin_memory=True,
- buffer_count=5,
- buffer_size=1e8,
- max_in_cpu=1e9)),
- parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
-
-
-def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
- model.train()
- with torch.cuda.amp.autocast(enabled=enable_autocast):
- if criterion:
- y = model(data)
- loss = criterion(y, label)
- else:
- loss = model(data, label)
- loss = loss.float()
- if isinstance(model, ShardedModelV2):
- model.backward(loss)
- else:
- loss.backward()
-
-
-def checkpoint_wrapper(module, enable=True):
- if enable:
- module.forward = partial(checkpoint, module.forward)
- return module
-
-
-def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
- if loose:
- return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
- return torch.allclose(tensor_a, tensor_b)
-
-
-def check_grads(model, zero_model, loose=False):
- for p, zero_p in zip(model.parameters(), zero_model.parameters()):
- zero_grad = zero_p.grad.clone().to(p.device)
- grad = p.grad.float()
- assert grad.dtype == zero_grad.dtype
- assert allclose(grad, zero_grad, loose=loose)
-
-
-def check_params(model, zero_model, loose=False):
- for p, zero_p in zip(model.parameters(), zero_model.parameters()):
- zero_p = zero_p.clone().to(p.device)
- # assert p.dtype == zero_p.dtype
- assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}"
-
-
-def check_grads_padding(model, zero_model, loose=False):
- rank = dist.get_rank()
- for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
- # zero_grad = zero_p.grad.clone().to(p.device)
- if zero_p.colo_attr.is_replicated:
- zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device)
- chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
- if rank >= len(chunks):
- continue
- grad = chunks[rank].float()
- if zero_grad.size(0) > grad.size(0):
- zero_grad = zero_grad[:grad.size(0)]
- else:
- zero_grad = zero_p.colo_attr.grad_payload
- grad = p.grad.to(zero_grad.dtype)
-
- assert grad.dtype == zero_grad.dtype
- assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'
-
-
-def check_params_padding(model, zero_model, loose=False):
- rank = dist.get_rank()
- for p, zero_p in zip(model.parameters(), zero_model.parameters()):
- zero_p = zero_p.clone().to(p.device)
- chunks = torch.flatten(p).chunk(dist.get_world_size())
- if rank >= len(chunks):
- continue
- p = chunks[rank]
- if zero_p.size(0) > p.size(0):
- zero_p = zero_p[:p.size(0)]
- assert p.dtype == zero_p.dtype
- assert allclose(p, zero_p, loose=loose)
-
-
-def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
- rank = dist.get_rank()
- for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
- if zero_p.colo_attr.param_is_sharded:
- zero_p = zero_p.colo_attr.data_payload.to(p.device).float()
- chunks = torch.flatten(p).chunk(dist.get_world_size())
- if rank >= len(chunks):
- continue
- p = chunks[rank].float()
- if zero_p.size(0) > p.size(0):
- zero_p = zero_p[:p.size(0)]
- else:
- zero_p = zero_p.colo_attr.data_payload.to(p.device)
-
- assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype)
- assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
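check_grads_padding above compares each rank's slice of a dense DDP gradient against its ZeRO shard, trimming any tail padding from the shard first. A single-process sketch of that comparison (plain PyTorch, with the padding assumption made explicit):

# Conceptual sketch: compare a rank's chunk of a flat gradient with a ZeRO shard
# that may carry tail padding, the way the deleted helper does.
import torch

def check_padded_shard(full_grad, zero_shard, world_size, rank):
    chunks = full_grad.flatten().chunk(world_size)
    if rank >= len(chunks):
        return True                              # this rank holds only padding
    ref = chunks[rank]
    if zero_shard.numel() > ref.numel():
        zero_shard = zero_shard[:ref.numel()]    # drop the padding tail
    return torch.allclose(ref, zero_shard)

grad = torch.randn(10)
# rank 3 owns a single real element plus two padding zeros
padded_shard = torch.cat([grad.flatten().chunk(4)[3], torch.zeros(2)])
assert check_padded_shard(grad, padded_shard, world_size=4, rank=3)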
diff --git a/tests/test_zero/test_legacy/test_found_inf.py b/tests/test_zero/test_legacy/test_found_inf.py
deleted file mode 100644
index e90158e0a43b..000000000000
--- a/tests/test_zero/test_legacy/test_found_inf.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import pytest
-import torch
-from common import CONFIG
-from test_sharded_optim_v2 import _run_step
-
-import colossalai
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
-from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
-from colossalai.zero.low_level._utils import has_inf_or_nan
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-@parameterize("cpu_offload", [True, False])
-@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
-@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
-def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
- test_models = ['repeated_computed_layers']
- shard_strategy = shard_strategy_class()
-
- for model_name in test_models:
- get_components_func = non_distributed_component_funcs.get_callable(model_name)
- model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-
- with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
- shard_strategy=shard_strategy,
- shard_param=True):
- zero_model = model_builder(checkpoint=True)
- zero_model = ShardedModelV2(
- zero_model,
- shard_strategy,
- tensor_placement_policy='cpu' if cpu_offload else 'cuda',
- reuse_fp16_shard=True,
- )
-
- sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3)
- sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio)
-
- for i, (data, label) in enumerate(train_dataloader):
- if i > 1:
- break
- assert zero_model.overflow_counter == 0
- data, label = data.cuda(), label.cuda()
- _run_step(zero_model, sharded_optim, data, label, criterion, False)
- for param in zero_model.parameters():
- assert not has_inf_or_nan(param.colo_attr.data_payload)
-
-
-def _run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- _run_test_found_inf()
-
-
-# use_cpuadam = True can be used with cpu_offload = False
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_if_address_is_in_use()
-def test_found_inf(world_size):
- spawn(_run_dist, world_size)
-
-
-if __name__ == '__main__':
- test_found_inf(world_size=2)
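The deleted test asserts that no parameter payload contains inf or nan after an overflow-free step. A plain-PyTorch stand-in for the has_inf_or_nan check it imports (a sketch equivalent in spirit, not colossalai's implementation):

# Conceptual sketch of an inf/nan check.
import torch

def has_inf_or_nan(t: torch.Tensor) -> bool:
    return not torch.isfinite(t).all().item()

assert not has_inf_or_nan(torch.randn(4))
assert has_inf_or_nan(torch.tensor([1.0, float('inf')]))
assert has_inf_or_nan(torch.tensor([float('nan')]))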
diff --git a/tests/test_zero/test_legacy/test_gemini_manager.py b/tests/test_zero/test_legacy/test_gemini_manager.py
deleted file mode 100644
index 0e956f7cc617..000000000000
--- a/tests/test_zero/test_legacy/test_gemini_manager.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import pytest
-import torch
-
-from colossalai.testing import clear_cache_before_run
-from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
-
-
-@pytest.mark.dist
-@clear_cache_before_run()
-def test_gemini_manager():
- # reset the manager, in case that there exists memory information left
- manager = StatefulTensor.GST_MGR
- manager.reset()
-
- # occupation 8
- st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
- # occupation 60
- st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))
-
- # occupation 28
- t1 = torch.empty(7, device='cuda')
- # occupation 12
- t2 = torch.empty(3, device='cpu')
- st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
- st4 = StatefulTensor(None, TensorState.FREE)
-
- assert manager.total_number == 4
- assert manager.total_mem['cpu'] == 60
- assert manager.total_mem['cuda'] == 36
- assert manager.state_mem['cpu'][TensorState.HOLD] == 60
- assert manager.state_mem['cuda'][TensorState.HOLD] == 8
- assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28
-
- st4.payload_reset(t2)
- st3.payload_reset(t2)
-
- assert manager.total_number == 4
- assert manager.total_mem['cpu'] == 84
- assert manager.total_mem['cuda'] == 8
- assert manager.state_mem['cpu'][TensorState.HOLD] == 72
- assert manager.state_mem['cuda'][TensorState.HOLD] == 8
- assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
- assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0
-
- st1.move_to(torch.device('cpu'))
- st2.move_to(torch.device('cpu'))
- st3.move_to(torch.device('cuda', 0))
-
- assert manager.total_number == 4
- assert manager.total_mem['cpu'] == 80
- assert manager.total_mem['cuda'] == 12
- assert manager.state_mem['cpu'][TensorState.HOLD] == 80
- assert manager.state_mem['cuda'][TensorState.HOLD] == 0
- assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
- assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
-
- st1.trans_state(TensorState.COMPUTE)
- st2.trans_state(TensorState.COMPUTE)
- st2.trans_state(TensorState.HOLD_AFTER_BWD)
-
- assert manager.total_number == 4
- assert manager.total_mem['cpu'] == 80
- assert manager.total_mem['cuda'] == 12
- assert manager.state_mem['cpu'][TensorState.HOLD] == 12
- assert manager.state_mem['cuda'][TensorState.HOLD] == 0
- assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
- assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
- assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
- assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
- assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
- assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0
-
-
-if __name__ == '__main__':
- test_gemini_manager()
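The occupation numbers asserted above are plain byte counts, numel times element size. A quick sketch of that arithmetic (CPU tensors only, since the byte count does not depend on the device):

# Conceptual sketch: the byte accounting behind the asserted occupations.
import torch

def tensor_bytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()

assert tensor_bytes(torch.empty(2, 2, dtype=torch.float16)) == 8    # st1
assert tensor_bytes(torch.empty(3, 5, dtype=torch.float32)) == 60   # st2
assert tensor_bytes(torch.empty(7)) == 28                           # t1, default fp32
assert tensor_bytes(torch.empty(3)) == 12                           # t2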
diff --git a/tests/test_zero/test_legacy/test_init_context.py b/tests/test_zero/test_legacy/test_init_context.py
deleted file mode 100644
index 84493827193e..000000000000
--- a/tests/test_zero/test_legacy/test_init_context.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import pytest
-import torch
-from common import CONFIG
-
-import colossalai
-from colossalai.logging import get_dist_logger
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory import colo_device_memory_used
-from colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-@parameterize("init_device_type", ['cpu', 'cuda'])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_model_test(init_device_type, shard_strategy_class):
- logger = get_dist_logger("test_zero_init")
-
- for name, get_components_func in non_distributed_component_funcs._registry.items():
- # because the ZeroInitContext automatically turns parameters to fp16
- # and the beit model use tensor.erfinv_() function to initialize weights
- # tensor.erfinv_() doesn't support Half in CPU, we omit the beit model
- if name == 'beit':
- continue
- model_builder, _, _, _, _ = get_components_func()
- if init_device_type == 'cuda':
- init_device = get_current_device()
- elif init_device_type == 'cpu':
- init_device = torch.device("cpu")
- else:
- continue
-
- model_numel_tensor = torch.zeros(1, dtype=torch.int)
- with ZeroInitContext(target_device=init_device,
- shard_strategy=shard_strategy_class(),
- shard_param=True,
- model_numel_tensor=model_numel_tensor):
- model = model_builder(checkpoint=True)
-
- for param in model.parameters():
- assert hasattr(param, 'colo_attr')
- assert param.colo_attr.sharded_data_tensor.dtype == torch.half
- assert param.colo_attr.sharded_data_tensor.is_sharded
- assert param.colo_attr.data_payload.device.type == init_device.type, \
- f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'
-
- cuda_mem_use, _ = colo_model_mem_usage(model)
- model_data_cuda_mem_MB = cuda_mem_use / 1e6
- logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
- sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6
- logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
- logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_model_test()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 4])
-@rerun_if_address_is_in_use()
-def test_zero_init_context(world_size):
- spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
- test_zero_init_context(1)
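The deleted test reports the number of parameters gathered by ZeroInitContext via model_numel_tensor. The reported figure is simply a sum of numel over all parameters, as in this small sketch (hypothetical layer sizes chosen for illustration):

# Conceptual sketch: counting model parameters the way the logged figure is derived.
import torch.nn as nn

model = nn.Sequential(nn.Linear(20, 64), nn.Linear(64, 1))
numel = sum(p.numel() for p in model.parameters())
assert numel == 20 * 64 + 64 + 64 * 1 + 1   # weights plus biases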
diff --git a/tests/test_zero/test_legacy/test_param_op.py b/tests/test_zero/test_legacy/test_param_op.py
deleted file mode 100644
index b91371b98922..000000000000
--- a/tests/test_zero/test_legacy/test_param_op.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import copy
-
-import torch
-
-from colossalai.testing import clear_cache_before_run
-from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
- if loose:
- return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
- return torch.allclose(tensor_a, tensor_b)
-
-
-def run_model(model, inputs, label, criterion, use_param_hook=False):
- if use_param_hook:
-
- class HooKWrapper:
-
- def __init__(self) -> None:
- self.hook_triggered_times = 0
-
- def wrapper_func(self):
-
- def hook(param, grad) -> torch.Tensor or None:
- self.hook_triggered_times += 1
- return grad
-
- return hook
-
- hookwrapper = HooKWrapper()
- param_list = [p for p in model.parameters()]
- hook_mgr = BaseParamHookMgr(param_list)
- hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
-
- model.zero_grad(set_to_none=True)
-
- with torch.cuda.amp.autocast():
- if criterion:
- y = model(inputs)
- loss = criterion(y, label)
- else:
- loss = model(inputs, label)
- loss = loss.float()
- loss.backward()
-
- if use_param_hook:
- hook_mgr.remove_hooks()
- return hookwrapper.hook_triggered_times
-
-
-@clear_cache_before_run()
-def test_base_param_hook():
- test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model']
- # test_models = ['bert']
-
- for model_name in test_models:
- get_components_func = non_distributed_component_funcs.get_callable(model_name)
- model_builder, train_dataloader, _, _, criterion = get_components_func()
-
- torch.manual_seed(0)
- model = model_builder(checkpoint=True).cuda()
- model.train()
-
- for i, (inputs, label) in enumerate(train_dataloader):
- if i > 0:
- break
- model_copy = copy.deepcopy(model)
-
- run_model(model, inputs.cuda(), label.cuda(), criterion, False)
- ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True)
-
- # Make sure param hook has only be fired once in case of parameter sharing
- assert ret2 == len(list(model.parameters()))
-
- for p, p_copy in zip(model.parameters(), model_copy.parameters()):
- assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
-
-
-if __name__ == '__main__':
- test_base_param_hook()
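test_base_param_hook above asserts that the gradient hook fires exactly once per parameter when there is no parameter sharing. The same property can be observed with plain PyTorch tensor hooks (a sketch, not the removed BaseParamHookMgr, whose hook signature also differs):

# Conceptual sketch: count gradient-hook invocations; one backward pass fires
# the hook exactly once per parameter of this small model.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
calls = {"n": 0}

def make_hook():
    def hook(grad):
        calls["n"] += 1
        return grad
    return hook

for p in model.parameters():
    p.register_hook(make_hook())

model(torch.randn(3, 4)).sum().backward()
assert calls["n"] == len(list(model.parameters()))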
diff --git a/tests/test_zero/test_legacy/test_shard_model_v2.py b/tests/test_zero/test_legacy/test_shard_model_v2.py
deleted file mode 100644
index 93d624aa2bbd..000000000000
--- a/tests/test_zero/test_legacy/test_shard_model_v2.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import pytest
-import torch
-from common import CONFIG, check_grads_padding, run_fwd_bwd
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-import colossalai
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
-from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
-from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-@parameterize("enable_autocast", [True])
-@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
-def run_model_test(enable_autocast, shard_strategy_class):
- test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model']
- shard_strategy = shard_strategy_class()
- for model_name in test_models:
- get_components_func = non_distributed_component_funcs.get_callable(model_name)
- model_builder, train_dataloader, _, _, criterion = get_components_func()
-
- with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
- shard_strategy=shard_strategy,
- shard_param=True):
- zero_model = model_builder(checkpoint=True)
- zero_model = ShardedModelV2(zero_model, shard_strategy)
-
- model = model_builder(checkpoint=True).half()
- col_model_deepcopy(zero_model, model)
- model = model.cuda()
-
- model = DDP(model, device_ids=[torch.cuda.current_device()])
-
- for i, (data, label) in enumerate(train_dataloader):
- if i > 5:
- break
-
- data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
- run_fwd_bwd(model, data, label, criterion, enable_autocast)
- run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
-
- check_grads_padding(model, zero_model, loose=True)
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_model_test()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_if_address_is_in_use()
-def test_shard_model_v2(world_size):
- spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
- test_shard_model_v2(world_size=2)
diff --git a/tests/test_zero/test_legacy/test_shard_param.py b/tests/test_zero/test_legacy/test_shard_param.py
deleted file mode 100644
index 4ba43edceb5d..000000000000
--- a/tests/test_zero/test_legacy/test_shard_param.py
+++ /dev/null
@@ -1,91 +0,0 @@
-from copy import deepcopy
-
-import pytest
-import torch
-from common import CONFIG, allclose
-
-import colossalai
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
-from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
-from colossalai.zero.legacy.sharded_param import ShardedTensor
-from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2
-
-
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_shard_tensor_with_strategy(shard_strategy_class, world_size):
- t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
- assert list(t.origin_shape) == [world_size * 2, 3]
- assert list(t.shape) == [world_size * 2, 3]
-
- shard_strategy = shard_strategy_class()
-
- # test shard strategy
- shard_strategy.shard([t])
- assert list(t.shape) == [6], f"{list(t.shape)} vs 6"
- shard_strategy.gather([t])
- assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}"
-
-
-def _run_shard_tensor(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_shard_tensor_with_strategy(world_size=world_size)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_if_address_is_in_use()
-def test_shard_tensor(world_size):
- spawn(_run_shard_tensor, world_size)
-
-
-def _run_shard_param_v2(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
- param = torch.nn.Parameter(torch.randn(2, 3))
- param_ref = deepcopy(param)
- sparam = ShardedParamV2(param=param)
-
- allclose(sparam.data_payload, param_ref.data)
-
- # Test get memory usage
- sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
- cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
- assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
-
- sparam.set_data_none()
- assert (param.data.numel() == 0)
- cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
- # 4 is size of dummy tensor of param.data
- assert cpu_mem_use == 2 * 3 * 4 * 2
-
- sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
- sparam.set_data_none()
- cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
- assert cpu_mem_use == 2 * 3 * 4 * 2
- assert cuda_mem_use == 0
-
- # append a grad to torch param
- param.data = sparam.data_payload
- param.grad = torch.randn(2, 3)
- cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
- assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
- assert cuda_mem_use == 0
-
- # reuse torch grad for sparam
- sparam.saved_grad = StatefulTensor(param.grad)
- cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
- assert cpu_mem_use == 2 * 3 * 4 * 2
- assert cuda_mem_use == 0
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_if_address_is_in_use()
-def test_shard_param_v2(world_size):
- spawn(_run_shard_param_v2, world_size)
-
-
-if __name__ == '__main__':
- # test_shard_tensor(2)
- test_shard_param_v2(2)
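run_shard_tensor_with_strategy above checks that sharding turns a (world_size * 2, 3) tensor into a flat per-rank shard and that gathering restores the original shape. A single-process sketch of that round trip with plain PyTorch (world_size fixed to 2 for illustration):

# Conceptual sketch: flatten-and-chunk sharding followed by a lossless gather.
import torch

world_size = 2
t = torch.randn(world_size * 2, 3)

shards = t.reshape(-1).chunk(world_size)   # one flat shard per rank
assert all(s.numel() == 6 for s in shards)

gathered = torch.cat(shards).reshape(world_size * 2, 3)
assert torch.equal(t, gathered)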
diff --git a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py
deleted file mode 100644
index 1ca144662722..000000000000
--- a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ProcessGroup
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import TensorShardStrategy
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
-from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
-from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import set_seed
-
-
-def init_zero(model_builder, placement_policy):
- device = get_current_device() if placement_policy == 'cuda' else torch.device('cpu')
- shard_strategy = TensorShardStrategy()
- with ZeroInitContext(target_device=device, shard_strategy=shard_strategy, shard_param=True):
- model = model_builder()
- model = ShardedModelV2(
- model,
- shard_strategy,
- tensor_placement_policy=placement_policy,
- reuse_fp16_shard=True,
- )
- optim = HybridAdam(model.parameters(), lr=1e-3)
- optim = ShardedOptimizerV2(model, optim, initial_scale=32)
- return model, optim
-
-
-def run_step(model, optim, criterion, data, label):
- optim.zero_grad()
- logits = model(data)
- loss = criterion(logits, label)
- optim.backward(loss)
- optim.step()
-
-
-def check_state_dict_eq(state_dict, other):
- for p, state in state_dict['state'].items():
- other_state = other['state'][p]
- for k, v in state.items():
- if isinstance(v, torch.Tensor):
- assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}'
- else:
- assert v == other_state[k]
-
-
-@parameterize('placement_policy', ['cuda', 'cpu'])
-def run_nested_model(placement_policy):
- get_components_func = non_distributed_component_funcs.get_callable('simple_net')
- model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-
- set_seed(42)
- model, optim = init_zero(model_builder, placement_policy)
- set_seed(42)
- model_copy, optim_copy = init_zero(model_builder, placement_policy)
-
- model.train()
- model_copy.train()
- pg = ProcessGroup()
- set_seed(pg.dp_local_rank())
- data_iter = iter(train_dataloader)
-
- data, label = map(lambda x: x.cuda(), next(data_iter))
- run_step(model, optim, criterion, data, label)
- optim_copy.load_state_dict(optim.state_dict())
- check_state_dict_eq(optim.state_dict(), optim_copy.state_dict())
-
- data, label = map(lambda x: x.cuda(), next(data_iter))
- run_step(model_copy, optim_copy, criterion, data, label)
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_nested_model()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
-@rerun_if_address_is_in_use()
-def test_sharded_optim_state_dist(world_size):
- spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
- test_sharded_optim_state_dist(2)
diff --git a/tests/test_zero/test_legacy/test_sharded_optim_v2.py b/tests/test_zero/test_legacy/test_sharded_optim_v2.py
deleted file mode 100644
index c6f77995ebcd..000000000000
--- a/tests/test_zero/test_legacy/test_sharded_optim_v2.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import pytest
-import torch
-import torch.distributed as dist
-from common import CONFIG, check_sharded_model_params
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-import colossalai
-from colossalai.amp import convert_to_apex_amp
-from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
-from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
-from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
-from colossalai.zero.low_level._utils import has_inf_or_nan
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
- model.train()
- optimizer.zero_grad()
- with torch.cuda.amp.autocast(enabled=enable_autocast):
- if criterion:
- y = model(data)
- loss = criterion(y, label)
- else:
- loss = model(data, label)
-
- loss = loss.float()
- if isinstance(model, ShardedModelV2):
- optimizer.backward(loss)
- else:
- loss.backward()
- optimizer.step()
-
-
-@parameterize("cpu_offload", [True, False])
-@parameterize("use_cpuadam", [True, False])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
-def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
- test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model']
- shard_strategy = shard_strategy_class()
-
- if use_cpuadam and cpu_offload is False:
- return
- if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
- return
-
- for model_name in test_models:
- get_components_func = non_distributed_component_funcs.get_callable(model_name)
- model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-
- with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
- shard_strategy=shard_strategy,
- shard_param=True):
- zero_model = model_builder(checkpoint=True)
- zero_model = ShardedModelV2(
- zero_model,
- shard_strategy,
- tensor_placement_policy='cpu' if cpu_offload else 'auto',
- reuse_fp16_shard=use_cpuadam,
- )
-
- model = model_builder(checkpoint=True).half()
- col_model_deepcopy(zero_model, model)
- model = model.cuda().float()
-
- if use_cpuadam:
- optimizer_class = CPUAdam
- optim = optimizer_class(model.parameters(), lr=1e-3)
- sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
- sharded_optim = ShardedOptimizerV2(zero_model,
- sharded_optim,
- initial_scale=2**5,
- gpu_margin_mem_ratio=gpu_margin_mem_ratio)
-
- amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
- apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
- if dist.get_world_size() > 1:
- apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()])
-
- for i, (data, label) in enumerate(train_dataloader):
- if i > 5:
- break
- data, label = data.cuda(), label.cuda()
- _run_step(apex_model, apex_optimizer, data, label, criterion, False)
- _run_step(zero_model, sharded_optim, data, label, criterion, False)
- check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
- for param in model.parameters():
- assert not has_inf_or_nan(param)
-
-
-def _run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- _run_test_sharded_optim_v2()
-
-
-# use_cpuadam = True can be used with cpu_offload = False
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_if_address_is_in_use()
-def test_sharded_optim_v2(world_size):
- spawn(_run_dist, world_size)
-
-
-if __name__ == '__main__':
- test_sharded_optim_v2(world_size=2)
diff --git a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py
deleted file mode 100644
index 0223f18c29d6..000000000000
--- a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py
+++ /dev/null
@@ -1,87 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import pytest
-import torch
-import torch.distributed as dist
-from torchvision.models import resnet50
-
-import colossalai
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import TensorShardStrategy
-
-
-def run_dist(rank, world_size, port):
- # this test only runs on resnet18
- # as this model has sync batch normalization
- # need to configure cudnn deterministic so that
- # randomness of convolution layers will be disabled
- zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy()))
- colossalai.launch(config=dict(zero=zero_config, cudnn_deterministic=True, cudnn_benchmark=False),
- rank=rank,
- world_size=world_size,
- host='localhost',
- port=port,
- backend='nccl')
-
- with ZeroInitContext(target_device=torch.cuda.current_device(),
- shard_strategy=gpc.config.zero.model_config.shard_strategy,
- shard_param=True):
- model = resnet50()
- optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
- criterion = torch.nn.CrossEntropyLoss()
-
- engine, *args = colossalai.initialize(model, optimizer, criterion)
-
- # train for dummy iterations
- engine.train()
- for _ in range(2):
- data = torch.rand(4, 3, 128, 128).cuda().half()
- label = torch.randint(0, 10, size=(4,)).cuda()
- engine.zero_grad()
- out = engine(data)
- loss = engine.criterion(out, label)
- engine.backward(loss)
- engine.step()
-
- # test
- # need to make sure the batch norm stats are synchronized
- # so that given the same input, the model will produce the same
- # output on different ranks
- engine.eval()
- data = torch.rand(4, 3, 128, 128).cuda().half()
- dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))
-
- # predict
- out = engine(data)
-
- # test if results are equal
- tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
- tensor_list.insert(rank, out)
- dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))
-
- assert torch.all(tensor_list[0] == tensor_list[1]), \
- 'expected the output from different ranks to be the same, but got different values'
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_sharded_optim_with_sync_bn():
- """
- This test is to make sure that buffers are synchronized between ranks
- when using ZeRO. An example of module buffer is the running stats of
- BatchNormalization layer, i.e. mean and var.
-
- If the buffers are not synchronized, the model will produce different
- output even though the input and parameters are the same. This is not
- wanted if we are doing predictions.
-
- """
- spawn(run_dist, 2)
-
-
-if __name__ == '__main__':
- test_sharded_optim_with_sync_bn()
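The docstring above explains the property under test: once batch-norm buffers are synchronized, identical inputs produce identical outputs on every rank. A single-process sketch of that property with two BatchNorm modules (no NCCL or ZeRO involved, purely illustrative):

# Conceptual sketch: after copying running statistics and parameters, two
# BatchNorm modules give identical eval-mode outputs for the same input.
import torch
import torch.nn as nn

bn_a = nn.BatchNorm2d(3)
bn_b = nn.BatchNorm2d(3)

with torch.no_grad():
    bn_a(torch.randn(8, 3, 4, 4))        # updates only bn_a's running stats
bn_b.load_state_dict(bn_a.state_dict())  # "synchronize" buffers and parameters

bn_a.eval()
bn_b.eval()
x = torch.randn(2, 3, 4, 4)
assert torch.equal(bn_a(x), bn_b(x))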
diff --git a/tests/test_zero/test_legacy/test_state_dict.py b/tests/test_zero/test_legacy/test_state_dict.py
deleted file mode 100644
index 5f76fff3e5c3..000000000000
--- a/tests/test_zero/test_legacy/test_state_dict.py
+++ /dev/null
@@ -1,55 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from functools import partial
-
-import pytest
-import torch
-from common import CONFIG
-
-import colossalai
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
-from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_zero_state_dict(shard_strategy_class):
- test_models = ['repeated_computed_layers', 'resnet18']
- shard_strategy = shard_strategy_class()
- for model_name in test_models:
- get_components_func = non_distributed_component_funcs.get_callable(model_name)
- model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-
- with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
- shard_strategy=shard_strategy,
- shard_param=True):
- zero_model = model_builder(checkpoint=True)
- zero_model = ShardedModelV2(zero_model, shard_strategy)
-
- model = model_builder(checkpoint=True).half()
- col_model_deepcopy(zero_model, model)
- model = model.cuda()
-
- zero_state_dict = zero_model.state_dict()
- for key, val in model.state_dict().items():
- assert torch.equal(val, zero_state_dict[key].to(val.device))
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_zero_state_dict()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_if_address_is_in_use()
-def test_zero_state_dict(world_size):
- spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
- test_zero_state_dict(2)
diff --git a/tests/test_zero/test_legacy/test_tensor_utils.py b/tests/test_zero/test_legacy/test_tensor_utils.py
deleted file mode 100644
index 238bc3fe1a98..000000000000
--- a/tests/test_zero/test_legacy/test_tensor_utils.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
-from colossalai.zero.legacy.gemini.tensor_utils import (
- colo_model_data_move_to_cpu,
- colo_model_data_tensor_move,
- colo_model_data_tensor_move_inline,
- colo_model_tensor_clone,
- colo_tensor_mem_usage,
-)
-
-
-def _run_colo_tensor_mem_usage():
- for i in range(1):
- if i == 1:
- t1 = StatefulTensor(torch.randn(2, 2))
- t2 = StatefulTensor(torch.randn(4, 4))
- c1, g1 = colo_tensor_mem_usage(t1)
- c2, g2 = colo_tensor_mem_usage(t2)
- assert c1 * 4 == c2
- assert g1 * 4 == g2
- else:
- t1 = torch.randn(2, 2)
- t2 = torch.randn(4, 4)
- c1, g1 = colo_tensor_mem_usage(t1)
- c2, g2 = colo_tensor_mem_usage(t2)
- assert c1 * 4 == c2
- assert g1 * 4 == g2
-
-
-def _run_colo_model_data_tensor_move_inline():
- for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]:
- colo_model_data_tensor_move_inline(t, get_current_device())
- assert t.device == get_current_device()
-
-
-def _run_colo_model_data_tensor_move():
- for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))),
- (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]:
- cpu_t, cuda_t = t
- colo_model_data_tensor_move(cpu_t, cuda_t)
- assert cuda_t.device == get_current_device()
-
-
-def _run_colo_model_data_move_to_cpu():
- for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]:
- colo_model_data_move_to_cpu(t)
- assert t.device == torch.device("cpu")
-
-
-def _run_colo_model_tensor_clone():
- for t in [
- StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())),
- torch.randn(4, 4).cuda(torch.cuda.current_device())
- ]:
- if issubclass(type(t), StatefulTensor):
- assert t.payload.device == get_current_device()
- else:
- assert t.device == get_current_device()
- p = colo_model_tensor_clone(t, get_current_device())
- assert p.device == get_current_device()
- for i in range(2):
- for j in range(2):
- if issubclass(type(t), StatefulTensor):
- assert t.payload.device == p.device
- assert t.payload[i][j] == p[i][j]
- else:
- assert t.device == p.device
- assert t[i][j] == p[i][j]
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
- _run_colo_tensor_mem_usage()
- _run_colo_model_data_tensor_move_inline()
- _run_colo_model_data_tensor_move()
- _run_colo_model_data_move_to_cpu()
- _run_colo_model_tensor_clone()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2, 4])
-@rerun_if_address_is_in_use()
-def test_zero_tensor_utils(world_size):
- spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
- test_zero_tensor_utils(world_size=2)
diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py
deleted file mode 100644
index 826a543db861..000000000000
--- a/tests/test_zero/test_legacy/test_zero_engine.py
+++ /dev/null
@@ -1,113 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import pytest
-import torch
-import torch.distributed as dist
-from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-import colossalai
-from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
-from colossalai.zero.low_level._utils import has_inf_or_nan
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-def run_dist(rank, world_size, port, parallel_config, bf16):
- is_mp_config = parallel_config == MP_PARALLEL_CONFIG
- is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG
- if bf16:
- parallel_config['zero']['model_config']['bf16'] = True
- colossalai.launch(config=parallel_config,
- rank=rank,
- world_size=world_size,
- host='localhost',
- port=port,
- backend='nccl')
-
- test_models = ['repeated_computed_layers', 'resnet18', 'bert']
- for model_name in test_models:
- get_components_func = non_distributed_component_funcs.get_callable(model_name)
- model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
- with ZeroInitContext(target_device=torch.cuda.current_device(),
- shard_strategy=gpc.config.zero.model_config.shard_strategy,
- shard_param=True,
- bf16=bf16):
- colo_model = model_builder(checkpoint=True)
-
- colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
- engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
- optimizer=colo_optimizer,
- criterion=criterion,
- train_dataloader=train_dataloader)
- dtype = torch.bfloat16 if bf16 else torch.float16
- torch_model = model_builder(checkpoint=True).to(dtype)
- col_model_deepcopy(engine.model, torch_model)
- torch_model = torch_model.cuda().float()
-
- engine.train()
- torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
-
- if dist.get_world_size() > 1:
- torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()])
-
- i = 0
- for data, label in train_dataloader:
- if i > 4:
- break
-
- data, label = data.cuda(), label.cuda()
-
- engine.zero_grad()
- torch_optimizer.zero_grad()
-
- if criterion:
- output = engine(data)
- loss = engine.criterion(output, label)
-
- torch_output = torch_model(data)
- torch_loss = engine.criterion(torch_output, label)
- else:
- loss = engine(data, label)
- torch_loss = torch_model(data, label)
-
- engine.backward(loss)
- engine.step()
-
- torch_loss.backward()
-
- for param in torch_model.parameters():
- if param.grad is not None:
- assert not has_inf_or_nan(param.grad)
-
- torch_optimizer.step()
- i += 1
-
- if is_mp_config:
- check_params(torch_model, colo_model, loose=True)
- elif is_zero_config:
- check_sharded_model_params(torch_model, colo_model, loose=True)
-
-
-# FIXME: enable this test in next PR
-@pytest.mark.skip
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2, 4])
-@rerun_if_address_is_in_use()
-def test_mp_engine(world_size):
- spawn(run_dist, world_size, parallel_config=MP_PARALLEL_CONFIG)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@pytest.mark.parametrize("bf16", [True, False])
-@rerun_if_address_is_in_use()
-def test_zero_engine(world_size, bf16):
- spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16)
-
-
-if __name__ == '__main__':
- test_zero_engine(world_size=4)
From da4f7b855f0074b374bbd26837c036f2cdfa9564 Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Wed, 2 Aug 2023 10:17:36 +0800
Subject: [PATCH 46/64] [chat] fix bugs and add unit tests (#4213)
* style: rename replay buffer
Experience replay is typically used in off-policy algorithms.
Using this name in PPO may be misleading.
* fix: fix wrong zero2 default arg
* test: update experience tests
* style: rename zero_pad fn
* fix: defer init in CycledDataLoader
* test: add benchmark test
* style: rename internal fn of generation
* style: rename internal fn of lora
* fix: remove unused loss fn
* fix: remove unused utils fn
* refactor: remove generate_with_actor fn
* fix: fix type annotation
* test: add models tests
* fix: skip llama due to long execution time
* style: modify dataset
* style: apply formatter
* perf: update reward dataset
* fix: fix wrong IGNORE_INDEX in sft dataset
* fix: remove DataCollatorForSupervisedDataset
* test: add dataset tests
* style: apply formatter
* style: rename test_ci to test_train
* feat: add llama in inference
* test: add inference tests
* test: change test scripts directory
* fix: update ci
* fix: fix typo
* fix: skip llama due to oom
* fix: fix file mod
* style: apply formatter
* refactor: remove duplicated llama_gptq
* style: apply formatter
* to: update rm test
* feat: add tokenizer arg
* feat: add download model script
* test: update train tests
* fix: modify gemini load and save pretrained
* test: update checkpoint io test
* to: modify nproc_per_node
* fix: do not remove existing dir
* fix: modify save path
* test: add random choice
* fix: fix sft path
* fix: enlarge nproc_per_node to avoid oom
* fix: add num_retry
* fix: make lora config of rm and critic consistent
* fix: add warning about lora weights
* fix: skip some gpt2 tests
* fix: remove grad ckpt in rm and critic due to errors
* refactor: directly use Actor in train_sft
* test: add more arguments
* fix: disable grad ckpt when using lora
* fix: fix save_pretrained and related tests
* test: enable zero2 tests
* revert: remove useless fn
* style: polish code
* test: modify test args
---
.github/workflows/run_chatgpt_examples.yml | 4 +-
applications/Chat/coati/dataset/__init__.py | 7 +-
.../Chat/coati/dataset/prompt_dataset.py | 18 +-
.../Chat/coati/dataset/reward_dataset.py | 130 ++++----
.../Chat/coati/dataset/sft_dataset.py | 243 +++++----------
.../Chat/coati/experience_buffer/__init__.py | 4 +
.../base.py | 4 +-
.../naive.py | 6 +-
.../utils.py | 10 +-
.../Chat/coati/experience_maker/naive.py | 30 +-
applications/Chat/coati/models/__init__.py | 4 +-
.../Chat/coati/models/bloom/bloom_critic.py | 5 +-
.../Chat/coati/models/bloom/bloom_rm.py | 5 +-
applications/Chat/coati/models/generation.py | 97 ++----
.../Chat/coati/models/gpt/gpt_critic.py | 5 +-
applications/Chat/coati/models/gpt/gpt_rm.py | 4 -
.../Chat/coati/models/llama/llama_critic.py | 6 -
.../Chat/coati/models/llama/llama_rm.py | 4 -
applications/Chat/coati/models/lora.py | 10 +-
applications/Chat/coati/models/loss.py | 25 --
.../Chat/coati/models/opt/opt_critic.py | 5 +-
applications/Chat/coati/models/opt/opt_rm.py | 4 -
applications/Chat/coati/models/utils.py | 52 +---
.../ray/callbacks/performance_evaluator.py | 22 +-
.../Chat/coati/ray/detached_replay_buffer.py | 4 +-
.../Chat/coati/ray/detached_trainer_base.py | 2 +-
.../Chat/coati/ray/experience_maker_holder.py | 24 +-
.../Chat/coati/ray/lora_constructor.py | 14 +-
.../Chat/coati/replay_buffer/__init__.py | 4 -
applications/Chat/coati/trainer/base.py | 14 +-
.../callbacks/performance_evaluator.py | 18 +-
applications/Chat/coati/trainer/ppo.py | 8 +-
.../Chat/coati/trainer/strategies/base.py | 4 +-
.../coati/trainer/strategies/colossalai.py | 14 +-
.../Chat/coati/trainer/strategies/ddp.py | 45 ++-
.../Chat/coati/trainer/strategies/sampler.py | 1 -
applications/Chat/coati/trainer/utils.py | 6 +-
applications/Chat/examples/download_model.py | 84 ++++++
.../Chat/examples/generate_prompt_dataset.py | 9 +-
applications/Chat/examples/inference.py | 39 ++-
applications/Chat/examples/test_ci.sh | 160 ----------
applications/Chat/examples/train_prompts.py | 34 ++-
.../Chat/examples/train_reward_model.py | 32 +-
applications/Chat/examples/train_rm.sh | 24 +-
applications/Chat/examples/train_sft.py | 67 +++--
applications/Chat/examples/train_sft.sh | 21 +-
applications/Chat/inference/benchmark.py | 9 +-
.../Chat/inference/llama_gptq/__init__.py | 5 -
.../Chat/inference/llama_gptq/loader.py | 41 ---
.../Chat/inference/llama_gptq/model_utils.py | 13 -
.../Chat/inference/llama_gptq/quant.py | 283 ------------------
applications/Chat/inference/locustfile.py | 3 +-
applications/Chat/inference/server.py | 15 +-
.../Chat/inference/tests/test_chat_prompt.py | 17 +-
applications/Chat/inference/utils.py | 10 +-
applications/Chat/tests/test_benchmarks.sh | 33 ++
applications/Chat/tests/test_checkpoint.py | 101 ++++---
applications/Chat/tests/test_dataset.py | 248 +++++++++++++++
.../{test_data.py => test_experience.py} | 30 +-
applications/Chat/tests/test_inference.sh | 11 +
applications/Chat/tests/test_models.py | 235 +++++++++++++++
applications/Chat/tests/test_train.sh | 228 ++++++++++++++
62 files changed, 1408 insertions(+), 1206 deletions(-)
create mode 100644 applications/Chat/coati/experience_buffer/__init__.py
rename applications/Chat/coati/{replay_buffer => experience_buffer}/base.py (91%)
rename applications/Chat/coati/{replay_buffer => experience_buffer}/naive.py (92%)
rename applications/Chat/coati/{replay_buffer => experience_buffer}/utils.py (83%)
delete mode 100644 applications/Chat/coati/replay_buffer/__init__.py
create mode 100644 applications/Chat/examples/download_model.py
delete mode 100755 applications/Chat/examples/test_ci.sh
delete mode 100644 applications/Chat/inference/llama_gptq/__init__.py
delete mode 100644 applications/Chat/inference/llama_gptq/loader.py
delete mode 100644 applications/Chat/inference/llama_gptq/model_utils.py
delete mode 100644 applications/Chat/inference/llama_gptq/quant.py
create mode 100755 applications/Chat/tests/test_benchmarks.sh
create mode 100644 applications/Chat/tests/test_dataset.py
rename applications/Chat/tests/{test_data.py => test_experience.py} (82%)
create mode 100755 applications/Chat/tests/test_inference.sh
create mode 100644 applications/Chat/tests/test_models.py
create mode 100755 applications/Chat/tests/test_train.sh
diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml
index 510f6b6f0985..650689498fda 100644
--- a/.github/workflows/run_chatgpt_examples.yml
+++ b/.github/workflows/run_chatgpt_examples.yml
@@ -43,7 +43,9 @@ jobs:
run: |
cd applications/Chat
rm -rf ~/.cache/colossalai
- ./examples/test_ci.sh
+ ./tests/test_inference.sh
+ ./tests/test_benchmarks.sh
+ ./tests/test_train.sh
env:
NCCL_SHM_DISABLE: 1
MAX_JOBS: 8
diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py
index f650668e90b0..bd4e5460d11e 100644
--- a/applications/Chat/coati/dataset/__init__.py
+++ b/applications/Chat/coati/dataset/__init__.py
@@ -1,9 +1,10 @@
from .prompt_dataset import PromptDataset
from .reward_dataset import HhRlhfDataset, RmStaticDataset
-from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
+from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0
__all__ = [
- 'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
- 'DataCollatorForSupervisedDataset', 'PromptDataset'
+ 'RmStaticDataset', 'HhRlhfDataset',
+ 'SFTDataset', 'SupervisedDataset',
+ 'PromptDataset', 'is_rank_0',
]
diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py
index 0bdcbbc5928e..2c953fffa513 100644
--- a/applications/Chat/coati/dataset/prompt_dataset.py
+++ b/applications/Chat/coati/dataset/prompt_dataset.py
@@ -1,20 +1,13 @@
-import copy
-import random
from collections import defaultdict
-from dataclasses import dataclass, field
-from typing import Callable, Dict, Sequence
+from typing import Dict
import torch
-import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
-from tqdm import tqdm
from colossalai.logging import get_dist_logger
-from .utils import is_rank_0, jload
-
-logger = get_dist_logger()
+from .utils import jload
class PromptDataset(Dataset):
@@ -27,12 +20,13 @@ def __init__(self,
max_length: int = 96):
super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list)
- logger.info("Loading data...")
+ self.logger = get_dist_logger()
+ self.logger.info("Loading data...")
list_data_dict = jload(data_path)
- logger.info(f"Loaded {len(list_data_dict)} examples.")
+ self.logger.info(f"Loaded {len(list_data_dict)} examples.")
if max_datasets_size is not None:
- logger.info(f"Limiting dataset to {max_datasets_size} examples.")
+ self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
list_data_dict = list_data_dict[:max_datasets_size]
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py
index 5dacf7e81464..3c4ec8b214bb 100644
--- a/applications/Chat/coati/dataset/reward_dataset.py
+++ b/applications/Chat/coati/dataset/reward_dataset.py
@@ -20,44 +20,44 @@ class RmStaticDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
- self.chosen = []
- self.reject = []
- if special_token is None:
- self.end_token = tokenizer.eos_token
- else:
- self.end_token = special_token
- for data in tqdm(dataset, disable=not is_rank_0()):
- prompt = data['prompt']
-
- chosen = prompt + data['chosen'] + self.end_token
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen.append({
- "input_ids": chosen_token['input_ids'],
- "attention_mask": chosen_token['attention_mask']
- })
-
- reject = prompt + data['rejected'] + self.end_token
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject.append({
- "input_ids": reject_token['input_ids'],
- "attention_mask": reject_token['attention_mask']
- })
+ self.end_token = tokenizer.eos_token \
+ if special_token is None else special_token
+
+ chosen = [
+ data["prompt"] + data["chosen"] + self.end_token
+ for data in tqdm(dataset, disable=not is_rank_0())
+ ]
+ chosen_token = tokenizer(chosen,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+ self.chosen = {
+ "input_ids": chosen_token["input_ids"],
+ "attention_mask": chosen_token["attention_mask"]
+ }
+
+ reject = [
+ data["prompt"] + data["rejected"] + self.end_token
+ for data in tqdm(dataset, disable=not is_rank_0())
+ ]
+ reject_token = tokenizer(reject,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+ self.reject = {
+ "input_ids": reject_token["input_ids"],
+ "attention_mask": reject_token["attention_mask"]
+ }
def __len__(self):
- length = len(self.chosen)
+ length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
- return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
- "input_ids"], self.reject[idx]["attention_mask"]
+ return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
+ self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
# Anthropic/hh-rlhf
@@ -74,39 +74,41 @@ class HhRlhfDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
- self.chosen = []
- self.reject = []
- if special_token is None:
- self.end_token = tokenizer.eos_token
- else:
- self.end_token = special_token
- for data in tqdm(dataset, disable=not is_rank_0()):
- chosen = data['chosen'] + self.end_token
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen.append({
- "input_ids": chosen_token['input_ids'],
- "attention_mask": chosen_token['attention_mask']
- })
-
- reject = data['rejected'] + self.end_token
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject.append({
- "input_ids": reject_token['input_ids'],
- "attention_mask": reject_token['attention_mask']
- })
+ self.end_token = tokenizer.eos_token \
+ if special_token is None else special_token
+
+ chosen = [
+ data["chosen"] + self.end_token
+ for data in tqdm(dataset, disable=not is_rank_0())
+ ]
+ chosen_token = tokenizer(chosen,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+ self.chosen = {
+ "input_ids": chosen_token["input_ids"],
+ "attention_mask": chosen_token["attention_mask"]
+ }
+
+ reject = [
+ data["rejected"] + self.end_token
+ for data in tqdm(dataset, disable=not is_rank_0())
+ ]
+ reject_token = tokenizer(reject,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+ self.reject = {
+ "input_ids": reject_token["input_ids"],
+ "attention_mask": reject_token["attention_mask"]
+ }
def __len__(self):
- length = len(self.chosen)
+ length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
- return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
- "input_ids"], self.reject[idx]["attention_mask"]
+ return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
+ self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
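For reference, a minimal sketch of the batched-tokenization pattern the reworked RmStaticDataset/HhRlhfDataset rely on: one tokenizer call per split, then plain tensor indexing in __getitem__. The toy records and the "gpt2" checkpoint are assumptions for illustration, not part of the patch.

# Sketch only: batched tokenization, then per-sample indexing.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # gpt2 has no pad token by default

pairs = [
    {"prompt": "Q: 1+1?\nA:", "chosen": " 2", "rejected": " 3"},
    {"prompt": "Q: capital of France?\nA:", "chosen": " Paris", "rejected": " Rome"},
]
end_token = tokenizer.eos_token
chosen = [d["prompt"] + d["chosen"] + end_token for d in pairs]
chosen_token = tokenizer(chosen,
                         max_length=32,
                         padding="max_length",
                         truncation=True,
                         return_tensors="pt")

# One tensor per field; a sample is recovered by plain indexing,
# mirroring the new __getitem__ above.
sample_ids = chosen_token["input_ids"][0]          # shape: (32,)
sample_mask = chosen_token["attention_mask"][0]    # shape: (32,)
print(sample_ids.shape, int(sample_mask.sum()))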
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 0b04cf79ee54..636b4e6772cb 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -13,44 +13,64 @@
# limitations under the License.
import copy
-import random
-from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Sequence, Tuple
+from typing import Dict, Sequence, Tuple
import torch
-import torch.distributed as dist
-import transformers
from torch.utils.data import Dataset
from tqdm import tqdm
+from transformers import PreTrainedTokenizer
from colossalai.logging import get_dist_logger
-from .conversation import default_conversation
from .utils import is_rank_0, jload
-# The following is a template prompt for a 4-round conversation.
-"""
-A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
-
-Human: xxxAssistant: xxxHuman: xxxAssistant: xxxHuman: xxxAssistant: xxxHuman: xxxAssistant: xxx
-"""
-# Please note that we only calculate loss on assistant's answer tokens.
-
logger = get_dist_logger()
IGNORE_INDEX = -100
-DEFAULT_EOS_TOKEN = ""
PROMPT_DICT = {
- "prompt_input":
- ("Below is an instruction that describes a task, paired with an input that provides further context. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
+ "prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
"prompt_no_input": ("Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"),
}
+def _preprocess(sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Preprocess the data by tokenizing."""
+ sequences = [s + t for s, t in zip(sources, targets)]
+ sequences_token = tokenizer(sequences,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+ sources_token = tokenizer(sources,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+
+ labels = copy.deepcopy(sequences_token["input_ids"])
+ for i in range(labels.shape[0]):
+ source_len = sources_token["attention_mask"][i].sum().item()
+ pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
+ if tokenizer.padding_side == "right":
+ # |prompt|completion|eos|pad|
+ labels[i][:source_len] = IGNORE_INDEX
+ elif tokenizer.padding_side == "left":
+ # |pad|prompt|completion|eos|
+ labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
+ else:
+ raise RuntimeError()
+
+ return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
+
+
class SFTDataset(Dataset):
"""
Dataset for sft model
@@ -61,115 +81,31 @@ class SFTDataset(Dataset):
max_length: max length of input
"""
- def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
+ def __init__(self,
+ dataset: Dict,
+ tokenizer: PreTrainedTokenizer,
+ max_length: int = 512
+ ) -> None:
super().__init__()
self.input_ids = []
- for data in tqdm(dataset, disable=not is_rank_0()):
- prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
- prompt_token = tokenizer(prompt,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
+ sources = [data["prompt"] for data in dataset]
+ targets = [
+ data["completion"] + tokenizer.eos_token
+ for data in tqdm(dataset, disable=not is_rank_0())
+ ]
- self.input_ids.append(prompt_token['input_ids'][0])
- self.labels = copy.deepcopy(self.input_ids)
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
- length = len(self.input_ids)
+ length = self.input_ids.shape[0]
return length
def __getitem__(self, idx):
- return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
-
-
-def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
- max_length: int) -> Dict[str, torch.Tensor]:
- """Tokenize a list of strings."""
- tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
- input_ids = labels = tokenized_list["input_ids"]
- input_ids_lens = labels_lens = \
- tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
- return dict(
- input_ids=input_ids,
- labels=labels,
- input_ids_lens=input_ids_lens,
- labels_lens=labels_lens,
- )
-
-
-def preprocess(
- sources: Sequence[str],
- targets: Sequence[str],
- tokenizer: transformers.PreTrainedTokenizer,
- max_length: int,
-) -> Dict:
- """Preprocess the data by tokenizing."""
- examples = [s + t for s, t in zip(sources, targets)]
- examples_tokenized, sources_tokenized = [
- _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
- ]
- input_ids = examples_tokenized["input_ids"]
- labels = copy.deepcopy(input_ids)
- for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
- label[:source_len] = IGNORE_INDEX
- return dict(input_ids=input_ids, labels=labels)
-
-
-def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
- max_length: int) -> Dict:
- """Preprocess the conversation data by tokenizing."""
- conversations = []
- intermediates = []
- for source in sources:
- header = f"{default_conversation.system}"
- conversation, intermediate = _add_speaker_and_signal(header, source)
- conversations.append(conversation)
- intermediates.append(intermediate)
-
- conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
- input_ids = conversations_tokenized["input_ids"]
- targets = copy.deepcopy(input_ids)
-
- assert len(targets) == len(intermediates)
- for target, inters in zip(targets, intermediates):
- mask = torch.zeros_like(target, dtype=torch.bool)
- for inter in inters:
- tokenized = _tokenize_fn(inter, tokenizer, max_length)
-
- start_idx = tokenized["input_ids"][0].size(0) - 1
- end_idx = tokenized["input_ids"][1].size(0)
-
- mask[start_idx:end_idx] = True
- target[~mask] = IGNORE_INDEX
-
- return dict(input_ids=input_ids, labels=targets)
-
-
-def _add_speaker_and_signal(header: str,
- source: List[Dict],
- get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
- END_SIGNAL = DEFAULT_EOS_TOKEN
- conversation = header
- intermediate = []
- for sentence in source:
- from_str = sentence["from"]
- if from_str.lower() == "human":
- from_str = default_conversation.roles[0]
- elif from_str.lower() == "gpt":
- from_str = default_conversation.roles[1]
- else:
- from_str = 'unknown'
-
- value = from_str + ": " + sentence["value"] + END_SIGNAL
- if sentence["from"].lower() == "gpt":
- start = conversation + from_str + ": "
- end = conversation + value
- intermediate.append([start, end])
- if get_conversation:
- conversation += value
- return conversation, intermediate
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx],
+ attention_mask=self.attention_mask[idx])
class SupervisedDataset(Dataset):
@@ -177,10 +113,10 @@ class SupervisedDataset(Dataset):
def __init__(self,
data_path: str,
- tokenizer: transformers.PreTrainedTokenizer,
+ tokenizer: PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 512):
- super(SupervisedDataset, self).__init__()
+ super().__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
logger.info(f"Loaded {len(list_data_dict)} examples.")
@@ -190,52 +126,25 @@ def __init__(self,
list_data_dict = list_data_dict[:max_datasets_size]
logger.info("Formatting inputs...")
- if "conversations" not in list_data_dict[0]:
- prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
- sources = [
- prompt_input.format_map(example)
- if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
- ]
- targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
-
- if is_rank_0():
- logger.info("Tokenizing inputs... This may take some time...")
-
- data_dict = preprocess(sources, targets, tokenizer, max_length)
- else:
- if is_rank_0():
- logger.info("Tokenizing inputs... This may take some time...")
-
- sources = [conv["conversations"] for conv in list_data_dict]
- data_dict = preprocess_conversation(sources, tokenizer, max_length)
-
- if is_rank_0():
- logger.info("Tokenizing finish.")
-
- self.input_ids = data_dict["input_ids"]
- self.labels = data_dict["labels"]
+ prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+ sources = [
+ prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
+ for example in list_data_dict
+ ]
+ targets = [
+ example['output'] + tokenizer.eos_token
+ for example in list_data_dict
+ ]
+
+ logger.info("Tokenizing inputs... This may take some time...")
+ self.input_ids, self.labels, self.attention_mask = \
+ _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
- return len(self.input_ids)
-
- def __getitem__(self, i) -> Dict[str, torch.Tensor]:
- return dict(input_ids=self.input_ids[i], labels=self.labels[i])
-
-
-@dataclass
-class DataCollatorForSupervisedDataset(object):
- """Collate examples for supervised fine-tuning."""
-
- tokenizer: transformers.PreTrainedTokenizer
+ length = self.input_ids.shape[0]
+ return length
- def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
- input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
- input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
- batch_first=True,
- padding_value=self.tokenizer.pad_token_id)
- labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
- return dict(
- input_ids=input_ids,
- labels=labels,
- attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
- )
+ def __getitem__(self, idx):
+ return dict(input_ids=self.input_ids[idx],
+ labels=self.labels[idx],
+ attention_mask=self.attention_mask[idx])
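A worked toy example of the label-masking rule in the new _preprocess: prompt tokens are overwritten with IGNORE_INDEX so the loss covers completion tokens only. Token ids below are invented; only the right-padding branch is executed, the left-padding branch is noted in a comment.

import copy
import torch

IGNORE_INDEX = -100
# |prompt|completion|eos|pad| (right padding); ids are invented
sequence_ids = torch.tensor([[11, 12, 13, 21, 22, 2, 0, 0]])
sequence_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
source_mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0]])   # prompt tokens only

labels = copy.deepcopy(sequence_ids)
source_len = sou
rce_mask[0].sum().item() if False else source_mask[0].sum().item()  # 3 prompt tokens
pad_len = sequence_ids.size(1) - sequence_mask[0].sum().item()      # 2 pad tokens
labels[0][:source_len] = IGNORE_INDEX                               # right-padding branch
# left padding would instead mask labels[0][pad_len:pad_len + source_len]
print(labels)   # tensor([[-100, -100, -100, 21, 22, 2, 0, 0]])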
diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/Chat/coati/experience_buffer/__init__.py
new file mode 100644
index 000000000000..c0188dc4a471
--- /dev/null
+++ b/applications/Chat/coati/experience_buffer/__init__.py
@@ -0,0 +1,4 @@
+from .base import ExperienceBuffer
+from .naive import NaiveExperienceBuffer
+
+__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
diff --git a/applications/Chat/coati/replay_buffer/base.py b/applications/Chat/coati/experience_buffer/base.py
similarity index 91%
rename from applications/Chat/coati/replay_buffer/base.py
rename to applications/Chat/coati/experience_buffer/base.py
index 4c3812461a10..9ccdc935d506 100644
--- a/applications/Chat/coati/replay_buffer/base.py
+++ b/applications/Chat/coati/experience_buffer/base.py
@@ -4,8 +4,8 @@
from coati.experience_maker.base import Experience
-class ReplayBuffer(ABC):
- """Replay buffer base class. It stores experience.
+class ExperienceBuffer(ABC):
+ """Experience buffer base class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
diff --git a/applications/Chat/coati/replay_buffer/naive.py b/applications/Chat/coati/experience_buffer/naive.py
similarity index 92%
rename from applications/Chat/coati/replay_buffer/naive.py
rename to applications/Chat/coati/experience_buffer/naive.py
index 938f500643c9..bd5213b38993 100644
--- a/applications/Chat/coati/replay_buffer/naive.py
+++ b/applications/Chat/coati/experience_buffer/naive.py
@@ -4,12 +4,12 @@
import torch
from coati.experience_maker.base import Experience
-from .base import ReplayBuffer
+from .base import ExperienceBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch
-class NaiveReplayBuffer(ReplayBuffer):
- """Naive replay buffer class. It stores experience.
+class NaiveExperienceBuffer(ExperienceBuffer):
+ """Naive experience buffer class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
diff --git a/applications/Chat/coati/replay_buffer/utils.py b/applications/Chat/coati/experience_buffer/utils.py
similarity index 83%
rename from applications/Chat/coati/replay_buffer/utils.py
rename to applications/Chat/coati/experience_buffer/utils.py
index 6ad0db2c3b60..c2a34212e2f4 100644
--- a/applications/Chat/coati/replay_buffer/utils.py
+++ b/applications/Chat/coati/experience_buffer/utils.py
@@ -33,7 +33,8 @@ class BufferItem:
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
- keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
+ keys = ('sequences', 'action_log_probs', 'values',
+ 'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
@@ -48,7 +49,7 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
return items
-def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
+def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
assert side in ('left', 'right')
max_len = max(seq.size(0) for seq in sequences)
padded_sequences = []
@@ -62,11 +63,12 @@ def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> tor
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
to_pad_keys = set(('action_log_probs', 'action_mask'))
- keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
+ keys = ('sequences', 'action_log_probs', 'values',
+ 'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
- batch_data = zero_pad_sequences(vals)
+ batch_data = _zero_pad_sequences(vals)
else:
batch_data = torch.stack(vals, dim=0)
kwargs[key] = batch_data
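The renamed helper _zero_pad_sequences pads variable-length tensors to the batch maximum before stacking; below is a self-contained re-implementation sketch of that behaviour (not the module itself), using only torch.

from typing import List

import torch
import torch.nn.functional as F

def zero_pad(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
    # Pad every 1-D sequence to the batch max length, on the requested side.
    assert side in ('left', 'right')
    max_len = max(seq.size(0) for seq in sequences)
    padded = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        padding = (pad_len, 0) if side == 'left' else (0, pad_len)
        padded.append(F.pad(seq, padding))
    return torch.stack(padded, dim=0)

batch = zero_pad([torch.tensor([1, 2, 3]), torch.tensor([4])], side='left')
print(batch)   # tensor([[1, 2, 3], [0, 0, 4]])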
diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py
index e5bb029e63d0..496f8ab445fc 100644
--- a/applications/Chat/coati/experience_maker/naive.py
+++ b/applications/Chat/coati/experience_maker/naive.py
@@ -1,6 +1,7 @@
import torch
-from coati.models.generation import generate_with_actor
-from coati.models.utils import calc_action_log_probs, compute_reward, normalize
+import torch.nn.functional as F
+from coati.models.generation import generate
+from coati.models.utils import calc_action_log_probs, compute_reward
from .base import Experience, ExperienceMaker
@@ -17,10 +18,27 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experie
self.initial_model.eval()
self.reward_model.eval()
- sequences, attention_mask, action_mask = generate_with_actor(self.actor,
- input_ids,
- return_action_mask=True,
- **generate_kwargs)
+ # generate sequences
+ sequences = generate(self.actor, input_ids, **generate_kwargs)
+
+ # calculate auxiliary tensors
+ attention_mask = None
+ pad_token_id = generate_kwargs.get('pad_token_id', None)
+ if pad_token_id is not None:
+ attention_mask = sequences.not_equal(pad_token_id)\
+ .to(dtype=torch.long, device=sequences.device)
+
+ input_len = input_ids.size(1)
+ eos_token_id = generate_kwargs.get('eos_token_id', None)
+ if eos_token_id is None:
+ action_mask = torch.ones_like(sequences, dtype=torch.bool)
+ else:
+ # left padding may be applied, only mask action
+ action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
+ action_mask[:, :input_len] = False
+ action_mask = action_mask[:, 1:]
+ action_mask = action_mask[:, -(sequences.size(1) - input_len):]
num_actions = action_mask.size(1)
actor_output = self.actor(sequences, attention_mask)
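The action mask inlined into NaiveExperienceMaker above relies on a cumsum trick over eos positions; a toy check of that trick (ids invented) is shown here.

import torch

eos_token_id = 2
# two generated continuations (prompt already stripped); 2 is eos
actions = torch.tensor([[5, 7, 2, 9, 9],
                        [6, 2, 8, 2, 0]])
# Positions before the first eos stay True; after the pad/shift in the
# patch the eos itself is also kept, and later tokens are masked out.
mask_before_eos = (actions == eos_token_id).cumsum(dim=-1) == 0
print(mask_before_eos)
# tensor([[ True,  True, False, False, False],
#         [ True, False, False, False, False]])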
diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py
index 709bc5ac0948..0a296a863756 100644
--- a/applications/Chat/coati/models/__init__.py
+++ b/applications/Chat/coati/models/__init__.py
@@ -1,8 +1,8 @@
from .base import Actor, Critic, RewardModel
from .lora import LoRAModule, convert_to_lora_module
-from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
+from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
__all__ = [
- 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
+ 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
'LoRAModule', 'convert_to_lora_module'
]
diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py
index a32fb2e102f9..a3716ca94138 100644
--- a/applications/Chat/coati/models/bloom/bloom_critic.py
+++ b/applications/Chat/coati/models/bloom/bloom_critic.py
@@ -14,7 +14,6 @@ class BLOOMCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -22,7 +21,6 @@ class BLOOMCritic(Critic):
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
@@ -32,7 +30,6 @@ def __init__(self,
model = BloomModel(config)
else:
model = BloomModel(BloomConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
+
value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py
index 22cfab441abb..e6ca9b1d4851 100644
--- a/applications/Chat/coati/models/bloom/bloom_rm.py
+++ b/applications/Chat/coati/models/bloom/bloom_rm.py
@@ -13,7 +13,6 @@ class BLOOMRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -21,7 +20,6 @@ class BLOOMRM(RewardModel):
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
@@ -30,8 +28,7 @@ def __init__(self,
model = BloomModel(config)
else:
model = BloomModel(BloomConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
+
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py
index d96ad78a89ce..de0d63f95f50 100644
--- a/applications/Chat/coati/models/generation.py
+++ b/applications/Chat/coati/models/generation.py
@@ -1,9 +1,9 @@
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
+
+from .base import Actor
try:
from transformers.generation_logits_process import (
@@ -16,9 +16,9 @@
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
-def prepare_logits_processor(top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None) -> LogitsProcessorList:
+def _prepare_logits_processor(top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@@ -37,22 +37,22 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
-def sample(model: nn.Module,
- input_ids: torch.Tensor,
- max_length: int,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> torch.Tensor:
+def _sample(model: Actor,
+ input_ids: torch.Tensor,
+ max_length: int,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids
- logits_processor = prepare_logits_processor(top_k, top_p, temperature)
+ logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
@@ -89,7 +89,8 @@ def sample(model: nn.Module,
return input_ids
-def generate(model: nn.Module,
+@torch.no_grad()
+def generate(model: Actor,
input_ids: torch.Tensor,
max_length: int,
num_beams: int = 1,
@@ -128,51 +129,19 @@ def generate(model: nn.Module,
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
- return sample(model,
- input_ids,
- max_length,
- early_stopping=early_stopping,
- eos_token_id=eos_token_id,
- pad_token_id=pad_token_id,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- prepare_inputs_fn=prepare_inputs_fn,
- update_model_kwargs_fn=update_model_kwargs_fn,
- **model_kwargs)
+ return _sample(model,
+ input_ids,
+ max_length,
+ early_stopping=early_stopping,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ prepare_inputs_fn=prepare_inputs_fn,
+ update_model_kwargs_fn=update_model_kwargs_fn,
+ **model_kwargs)
elif is_beam_gen_mode:
raise NotImplementedError
else:
raise ValueError("Unsupported generation mode")
-
-
-@torch.no_grad()
-def generate_with_actor(
- actor_model: nn.Module,
- input_ids: torch.Tensor,
- return_action_mask: bool = True,
- **kwargs
-) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
- """Generate token sequence with actor model. Refer to `generate` for more details.
- """
- # generate sequences
- sequences = generate(actor_model, input_ids, **kwargs)
-
- # calculate auxiliary tensors
- attention_mask = None
- pad_token_id = kwargs.get('pad_token_id', None)
- if pad_token_id is not None:
- attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
- if not return_action_mask:
- return sequences, attention_mask, None
- input_len = input_ids.size(1)
- eos_token_id = kwargs.get('eos_token_id', None)
- if eos_token_id is None:
- action_mask = torch.ones_like(sequences, dtype=torch.bool)
- else:
- # left padding may be applied, only mask action
- action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
- action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
- action_mask[:, :input_len] = False
- action_mask = action_mask[:, 1:]
- return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
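The renamed _prepare_logits_processor assembles standard Hugging Face logits warpers; below is a standalone sketch of the same chain, assuming the newer transformers import path from the try/except above (exact module layout varies by version).

import torch
from transformers.generation import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Same order as in the patch: temperature, then top-k, then top-p.
processors = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),
    TopKLogitsWarper(top_k=50),
    TopPLogitsWarper(top_p=0.9),
])

logits = torch.randn(1, 32000)            # fake next-token logits
input_ids = torch.tensor([[1, 2, 3]])     # fake prefix
warped = processors(input_ids, logits)    # same call the sampling loop makes
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)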
diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py
index 2e70f5f1fc96..01e1cd10ef57 100644
--- a/applications/Chat/coati/models/gpt/gpt_critic.py
+++ b/applications/Chat/coati/models/gpt/gpt_critic.py
@@ -14,7 +14,6 @@ class GPTCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the LoRA decomposition.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -22,7 +21,6 @@ class GPTCritic(Critic):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
- checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
@@ -32,7 +30,6 @@ def __init__(self,
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
- if checkpoint:
- model.gradient_checkpointing_enable()
+
value_head = nn.Linear(model.config.n_embd, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py
index 054432e1ce86..e52a5a14c1da 100644
--- a/applications/Chat/coati/models/gpt/gpt_rm.py
+++ b/applications/Chat/coati/models/gpt/gpt_rm.py
@@ -14,7 +14,6 @@ class GPTRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -22,7 +21,6 @@ class GPTRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
- checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
@@ -31,8 +29,6 @@ def __init__(self,
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
- if checkpoint:
- model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py
index dd9e5e7bfa1a..a67e5de5def6 100644
--- a/applications/Chat/coati/models/llama/llama_critic.py
+++ b/applications/Chat/coati/models/llama/llama_critic.py
@@ -13,7 +13,6 @@ class LlamaCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -21,7 +20,6 @@ class LlamaCritic(Critic):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
- checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
@@ -33,9 +31,5 @@ def __init__(self,
else:
model = LlamaModel(LlamaConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
-
value_head = nn.Linear(model.config.hidden_size, 1)
-
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py
index f936019d62d2..d6b62922686e 100644
--- a/applications/Chat/coati/models/llama/llama_rm.py
+++ b/applications/Chat/coati/models/llama/llama_rm.py
@@ -13,7 +13,6 @@ class LlamaRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -21,7 +20,6 @@ class LlamaRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
- checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
@@ -32,8 +30,6 @@ def __init__(self,
else:
model = LlamaModel(LlamaConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py
index 2a9059e6901e..546f675d7d37 100644
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/Chat/coati/models/lora.py
@@ -98,18 +98,18 @@ def T(w):
return F.linear(x, T(self.weight), bias=self.bias)
-def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
+def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
-def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
+def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
- setattr(module, name, lora_linear_wrapper(child, lora_rank))
+ setattr(module, name, _lora_linear_wrapper(child, lora_rank))
else:
- convert_to_lora_recursively(child, lora_rank)
+ _convert_to_lora_recursively(child, lora_rank)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
@@ -124,7 +124,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s
"""
if lora_rank <= 0:
return module
- convert_to_lora_recursively(module, lora_rank)
+ _convert_to_lora_recursively(module, lora_rank)
lora.mark_only_lora_as_trainable(module, lora_train_bias)
return module
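The now-private _convert_to_lora_recursively swaps nn.Linear children in place; here is a self-contained sketch of that recursive-replacement pattern, with a placeholder wrapper standing in for LoraLinear.

import torch.nn as nn

class WrappedLinear(nn.Module):
    """Placeholder for LoraLinear: wraps an existing nn.Linear."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear(x)

def convert_recursively(module: nn.Module) -> None:
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, WrappedLinear(child))   # replace in place
        else:
            convert_recursively(child)                    # descend into containers

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 4)))
convert_recursively(model)
print(model)   # both Linear layers are now WrappedLinear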
diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py
index 926c6e2a4e41..05a0b4821797 100644
--- a/applications/Chat/coati/models/loss.py
+++ b/applications/Chat/coati/models/loss.py
@@ -68,31 +68,6 @@ def forward(self,
return 0.5 * loss
-class PPOPtxActorLoss(nn.Module):
- """
- To Do:
-
- PPO-ptx Actor Loss
- """
-
- def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
- super().__init__()
- self.pretrain_coef = pretrain_coef
- self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
- self.pretrain_loss_fn = pretrain_loss_fn
-
- def forward(self,
- log_probs: torch.Tensor,
- old_log_probs: torch.Tensor,
- advantages: torch.Tensor,
- lm_logits: torch.Tensor,
- lm_input_ids: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
- policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
- lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
- return policy_loss + self.pretrain_coef * lm_loss
-
-
class LogSigLoss(nn.Module):
"""
Pairwise Loss for Reward Model
diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py
index fcfebd8a8b03..f66c4173fa52 100644
--- a/applications/Chat/coati/models/opt/opt_critic.py
+++ b/applications/Chat/coati/models/opt/opt_critic.py
@@ -14,7 +14,6 @@ class OPTCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -22,7 +21,6 @@ class OPTCritic(Critic):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
- checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
@@ -32,7 +30,6 @@ def __init__(self,
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
+
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py
index 50fc0dee8568..6f75344e6aae 100644
--- a/applications/Chat/coati/models/opt/opt_rm.py
+++ b/applications/Chat/coati/models/opt/opt_rm.py
@@ -13,7 +13,6 @@ class OPTRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -21,7 +20,6 @@ class OPTRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
- checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
@@ -30,8 +28,6 @@ def __init__(self,
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1))
diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py
index 8769fb7a8c43..97637d3523b0 100644
--- a/applications/Chat/coati/models/utils.py
+++ b/applications/Chat/coati/models/utils.py
@@ -1,14 +1,12 @@
from typing import Optional, Union
-import loralib as lora
import torch
-import torch.nn as nn
import torch.nn.functional as F
-def compute_approx_kl(log_probs: torch.Tensor,
- log_probs_base: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def _compute_approx_kl(log_probs: torch.Tensor,
+ log_probs_base: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
@@ -35,12 +33,12 @@ def compute_reward(r: Union[torch.Tensor, float],
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if kl_coef <= 0.0:
return r
- kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
+ kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
reward = r - kl_coef * kl
return reward
-def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1)
@@ -58,7 +56,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
torch.Tensor: Action log probs.
"""
logits = output['logits']
- log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+ log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
@@ -68,41 +66,3 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean
-
-
-def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
- tensor = tensor * mask
- mean = masked_mean(tensor, mask, dim=dim)
- mean_centered = tensor - mean
- var = masked_mean(mean_centered**2, mask, dim=dim)
- return mean_centered * var.clamp(min=eps).rsqrt()
-
-
-def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
- mean = tensor.mean(dim)
- mean_centered = tensor - mean
- var = (mean_centered**2).mean(dim)
- norm = mean_centered * var.clamp(min=eps).rsqrt()
- return norm
-
-
-def convert_to_lora(model: nn.Module,
- input_size: int,
- output_size: int,
- lora_rank: int = 16,
- lora_alpha: int = 1,
- lora_dropout: float = 0.,
- fan_in_fan_out: bool = False,
- merge_weights: bool = True):
- if lora_rank > min(input_size, output_size):
- raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
-
- for name, module in model.named_modules():
- if isinstance(module, nn.Linear):
- module._modules[name] = lora.Linear(input_size,
- output_size,
- r=lora_rank,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- fan_in_fan_out=fan_in_fan_out,
- merge_weights=merge_weights)
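The log-prob helpers kept in models/utils.py gather per-token log probabilities and slice out the generated actions; a shape-level toy check of that path (random values) follows.

import torch
import torch.nn.functional as F

batch, seq_len, vocab, num_actions = 2, 6, 11, 3
logits = torch.randn(batch, seq_len, vocab)
sequences = torch.randint(0, vocab, (batch, seq_len))

# Position t predicts token t+1, hence the [:-1] / [1:] offset.
log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
per_token = log_probs.gather(dim=-1, index=sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
action_log_probs = per_token[:, -num_actions:]   # keep only generated actions
print(per_token.shape, action_log_probs.shape)   # torch.Size([2, 5]) torch.Size([2, 3])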
diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
index cd3517609e7a..d3df8f9ae3e0 100644
--- a/applications/Chat/coati/ray/callbacks/performance_evaluator.py
+++ b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
@@ -115,12 +115,12 @@ def on_loop_end(self) -> None:
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
- 'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
- f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
- f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
- f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
- +
- f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ 'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ + f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
+ + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ + f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
+ + f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
)
@@ -204,9 +204,9 @@ def on_fit_end(self) -> None:
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
- 'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
- f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
- f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
- +
- f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ 'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ + f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ + f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
+ + f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
)
diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py
index 2f765281178a..7b9df2ee139b 100644
--- a/applications/Chat/coati/ray/detached_replay_buffer.py
+++ b/applications/Chat/coati/ray/detached_replay_buffer.py
@@ -6,9 +6,9 @@
import ray
import torch
+from coati.experience_buffer import ExperienceBuffer
+from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience
-from coati.replay_buffer import ReplayBuffer
-from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py
index ac2d35e9da19..90399781187a 100644
--- a/applications/Chat/coati/ray/detached_trainer_base.py
+++ b/applications/Chat/coati/ray/detached_trainer_base.py
@@ -4,8 +4,8 @@
import ray
import torch
+from coati.experience_buffer.utils import BufferItem
from coati.experience_maker import Experience
-from coati.replay_buffer.utils import BufferItem
from torch.utils.data import DataLoader
from tqdm import tqdm
diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py
index 07d9c3e4f396..13314bdafd5f 100644
--- a/applications/Chat/coati/ray/experience_maker_holder.py
+++ b/applications/Chat/coati/ray/experience_maker_holder.py
@@ -8,9 +8,9 @@
import ray
import torch
import torch.nn as nn
+from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
-from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
@@ -19,13 +19,9 @@
from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
-from .utils import (get_model_numel,
- get_rank,
- get_world_size,
- is_rank_0,
- set_dist_env,
- state_dict_to)
from .lora_constructor import LoRAConstructor
+from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
+
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
@@ -41,7 +37,7 @@ def __init__(
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
- # a function returns (actor, critic, reward_model, initial_model)
+ # a function returns (actor, critic, reward_model, initial_model)
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
@@ -205,15 +201,19 @@ def update_experience_maker(self,
self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
- state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
- self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
+ state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
+ new_actor_state_dict, new_actor_lora_config_dict)
+ self.actor_lora_constructor.load_state_dict_increase(
+ self.experience_maker.actor.model, state_dict_increase)
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
- state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
- self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)
+ state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
+ new_critic_state_dict, new_critic_lora_config_dict)
+ self.critic_lora_constructor.load_state_dict_increase(
+ self.experience_maker.critic, state_dict_increase)
# the lock must be released after both actor and critic being updated
if chunk_end:
diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py
index 4809617f647b..a98545d4d751 100644
--- a/applications/Chat/coati/ray/lora_constructor.py
+++ b/applications/Chat/coati/ray/lora_constructor.py
@@ -1,11 +1,11 @@
-from typing import Any, Callable, Dict, List, Optional
from collections import OrderedDict
from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional
import torch
import torch.nn as nn
-from loralib.layers import LoRALayer
from coati.models.lora import LoraLinear
+from loralib.layers import LoRALayer
@dataclass
@@ -23,19 +23,19 @@ class LoRAConstructor:
Usage:
Step 1 (Sender):
filter_state_dict_lora()
-
+
Step 2 (Sender, Optional):
extract_lora_config()
-
+
Step 3 (Sender):
send state_dict_lora and lora_config_dict
-
+
Step 4 (Receiver):
reconstruct_increase()
-
+
Step 5 (Receiver):
load_state_dict_increase()
-
+
'''
def __init__(self):
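
A minimal standalone sketch (plain PyTorch, not the coati LoRAConstructor API) of the five-step protocol documented above: the sender ships only the low-rank LoRA factors, and the receiver reconstructs the dense weight increase and adds it into the base weights in place.

import torch
import torch.nn as nn

# Toy base module standing in for a LoraLinear-wrapped layer.
base = nn.Linear(8, 8, bias=False)

# Sender side: rank-2 LoRA factors are all that gets serialized and sent (Steps 1-3).
rank = 2
lora_A = torch.randn(rank, 8) * 0.01
lora_B = torch.randn(8, rank) * 0.01
payload = {"weight.lora_A": lora_A, "weight.lora_B": lora_B}

# Receiver side: rebuild the dense increase and apply it in place (Steps 4-5).
increase = payload["weight.lora_B"] @ payload["weight.lora_A"]   # shape (8, 8)
with torch.no_grad():
    base.weight.add_(increase)
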
diff --git a/applications/Chat/coati/replay_buffer/__init__.py b/applications/Chat/coati/replay_buffer/__init__.py
deleted file mode 100644
index 1ebf60382913..000000000000
--- a/applications/Chat/coati/replay_buffer/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .base import ReplayBuffer
-from .naive import NaiveReplayBuffer
-
-__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py
index b4d168a563d9..0629c9c00cca 100644
--- a/applications/Chat/coati/trainer/base.py
+++ b/applications/Chat/coati/trainer/base.py
@@ -4,8 +4,8 @@
import torch.nn as nn
import tqdm
+from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
-from coati.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -62,7 +62,7 @@ class OnPolicyTrainer(ABC):
Args:
strategy (Strategy):the strategy to use for training
- buffer (NaiveReplayBuffer): the buffer to collect experiences
+ data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
@@ -70,13 +70,13 @@ class OnPolicyTrainer(ABC):
def __init__(self,
strategy: Strategy,
- buffer: NaiveReplayBuffer,
+ data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = []) -> None:
super().__init__()
self.strategy = strategy
- self.buffer = buffer
+ self.data_buffer = data_buffer
self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
@@ -144,7 +144,7 @@ def _collect_phase(self, collect_step: int):
self._on_make_experience_start()
experience = self._make_experience(collect_step)
self._on_make_experience_end(experience)
- self.buffer.append(experience)
+ self.data_buffer.append(experience)
def _update_phase(self, update_step: int):
self._on_learn_epoch_start(update_step)
@@ -181,8 +181,8 @@ def fit(
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
# I only call strategy.setup_dataloader() to setup dataloader.
- self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
+ self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
- self.buffer.clear()
+ self.data_buffer.clear()
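
For orientation, a self-contained sketch of the collect/update cycle that the renamed data_buffer drives; TinyExperienceBuffer is a hypothetical stand-in for NaiveExperienceBuffer, not the coati class.

from torch.utils.data import DataLoader

class TinyExperienceBuffer:
    """Minimal stand-in: append experiences, iterate as a dataset, clear when stale."""
    def __init__(self):
        self.items = []
    def append(self, experience):
        self.items.append(experience)
    def clear(self):
        self.items.clear()
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

buffer = TinyExperienceBuffer()
for step in range(4):                       # collect phase appends experiences
    buffer.append({"step": step})
loader = DataLoader(buffer, batch_size=2)   # update phase consumes the buffer
for batch in loader:
    pass
buffer.clear()                              # on-policy: drop stale experiences
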
diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
index 925455444597..9b44dafa7eaa 100644
--- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
+++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
@@ -171,13 +171,13 @@ def on_fit_end(self) -> None:
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0(
- f'Performance summary:\n' +
- f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
- +
- f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
- + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' +
- f'Overall time per sample: {overall_time_per_sample:.2f} s\n' +
- f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
- +
- f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
+ f'Performance summary:\n'
+ + f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
+
+ + f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
+ + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
+ + f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
+ + f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
+
+ + f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
)
diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index 4c4a1002e96d..ef625a1c1b3d 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -1,11 +1,11 @@
from typing import Dict, List
import torch.nn as nn
+from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
-from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
@@ -86,9 +86,9 @@ def __init__(self,
assert not offload_inference_models, \
"GeminiPlugin is not compatible with manual model.to('cpu')"
- buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
+ data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(
- strategy, buffer,
+ strategy, data_buffer,
sample_buffer, dataloader_pin_memory,
callbacks
)
@@ -170,7 +170,7 @@ def _learn(self, update_step: int):
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
- experience = self.buffer.sample()
+ experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self._training_step(experience)
diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py
index 3d1dfaf784cf..c20b2b16e396 100644
--- a/applications/Chat/coati/trainer/strategies/base.py
+++ b/applications/Chat/coati/trainer/strategies/base.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn
-from coati.replay_buffer import ReplayBuffer
+from coati.experience_buffer import ExperienceBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -45,7 +45,7 @@ def setup_distributed(self) -> None:
pass
@abstractmethod
- def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
+ def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
pass
def model_init_context(self):
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index 1b59d704eec3..fa55f97ad661 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -4,7 +4,6 @@
import torch
import torch.distributed as dist
import torch.nn as nn
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
@@ -44,7 +43,7 @@ class LowLevelZeroStrategy(DDPStrategy):
"""
def __init__(self,
- stage: int = 3,
+ stage: int = 2,
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
@@ -214,14 +213,3 @@ def unwrap_model(self, model: nn.Module) -> nn.Module:
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module
-
- def save_pretrained(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
- raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
-
- def get_model_state_dict_shard(self, model: nn.Module, **config):
- assert isinstance(self.plugin, GeminiPlugin)
- yield from super().get_model_state_dict_shard(model, **config)
diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py
index e1c1bbf19f35..a52b0460daa8 100644
--- a/applications/Chat/coati/trainer/strategies/ddp.py
+++ b/applications/Chat/coati/trainer/strategies/ddp.py
@@ -7,7 +7,8 @@
import torch
import torch.distributed as dist
import torch.nn as nn
-from coati.replay_buffer import ReplayBuffer
+from coati.experience_buffer import ExperienceBuffer
+from coati.models import Actor, Critic, RewardModel
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -71,13 +72,13 @@ def set_seed(self, seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
- def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
- return self.plugin.prepare_dataloader(replay_buffer,
- batch_size=replay_buffer.sample_batch_size,
+ def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
+ return self.plugin.prepare_dataloader(data_buffer,
+ batch_size=data_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
- collate_fn=replay_buffer.collate_fn)
+ collate_fn=data_buffer.collate_fn)
def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
@@ -92,13 +93,33 @@ def save_pretrained(self,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
- if only_rank0 and dist.get_rank() != 0:
- return
- unwrapped_model = self.unwrap_model(model)
- assert isinstance(unwrapped_model, PreTrainedModel)
- unwrapped_model.save_pretrained(path)
- if tokenizer is not None:
- tokenizer.save_pretrained(path)
+ if not only_rank0 or dist.get_rank() == 0:
+ unwrapped_model = self.unwrap_model(model)
+ assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
+ pretrained_model = unwrapped_model.model
+ assert isinstance(pretrained_model, PreTrainedModel)
+ # HACK: only use hf save_pretrained to save config
+ pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
+ if tokenizer is not None:
+ tokenizer.save_pretrained(path)
+ model_path = os.path.join(path, "pytorch_model.bin")
+ self.save_model(model,
+ model_path,
+ only_rank0=only_rank0)
+
+ def _replace_keys(model_path: str,
+ replace_fn: Callable):
+ state_dict = torch.load(model_path, map_location="cpu")
+ state_dict = {
+ replace_fn(k): v
+ for k, v in state_dict.items()
+ }
+ torch.save(state_dict, model_path)
+
+ # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
+ # HACK: rename keys of pytorch_model.bin
+ if dist.get_rank() == 0:
+ _replace_keys(model_path, lambda k: k.replace("model.", "", 1))
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
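
The _replace_keys hack above boils down to stripping the first "model." prefix from every key in pytorch_model.bin. A small self-contained demonstration (Wrapper is a hypothetical module used only to produce prefixed keys):

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    """Produces state_dict keys with a leading 'model.' prefix, like the Actor/Critic wrappers."""
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 4)

path = "pytorch_model.bin"
torch.save(Wrapper().state_dict(), path)    # keys: "model.weight", "model.bias"

state_dict = torch.load(path, map_location="cpu")
state_dict = {k.replace("model.", "", 1): v for k, v in state_dict.items()}
torch.save(state_dict, path)                # keys: "weight", "bias"
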
diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py
index 65e199dbf029..d726fa640fa2 100644
--- a/applications/Chat/coati/trainer/strategies/sampler.py
+++ b/applications/Chat/coati/trainer/strategies/sampler.py
@@ -27,7 +27,6 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None:
assert len(indices) == self.num_samples
self.indices = indices
-
def sample(self, batch_size: int) -> list:
sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
return [self.dataset[idx] for idx in sampled_indices]
diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py
index 4d45061bab09..7e2cb9c634f7 100644
--- a/applications/Chat/coati/trainer/utils.py
+++ b/applications/Chat/coati/trainer/utils.py
@@ -21,9 +21,13 @@ def __init__(
self.dataloader = dataloader
self.count = 0
- self.dataloader_iter = iter(dataloader)
+ self.dataloader_iter = None
def next(self):
+ # defer initialization
+ if self.dataloader_iter is None:
+ self.dataloader_iter = iter(self.dataloader)
+
self.count += 1
try:
return next(self.dataloader_iter)
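
The change above defers building the dataloader iterator until the first next() call, which matters when the loader is constructed before the data is actually ready. A minimal sketch of the same pattern; the restart-on-StopIteration branch is assumed here, mirroring the cycled behaviour the class name suggests.

class CycledIterator:
    """Deferred-init iterator: build on first use, restart when exhausted."""
    def __init__(self, iterable):
        self.iterable = iterable
        self.it = None              # deferred until the first next() call
    def next(self):
        if self.it is None:
            self.it = iter(self.iterable)
        try:
            return next(self.it)
        except StopIteration:
            self.it = iter(self.iterable)
            return next(self.it)

cycled = CycledIterator([1, 2, 3])
print([cycled.next() for _ in range(5)])    # [1, 2, 3, 1, 2]
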
diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py
new file mode 100644
index 000000000000..c2b5f9a859a9
--- /dev/null
+++ b/applications/Chat/examples/download_model.py
@@ -0,0 +1,84 @@
+import argparse
+import dataclasses
+import os
+from typing import List
+
+import tqdm
+from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.gpt import GPTRM, GPTActor, GPTCritic
+from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from huggingface_hub import hf_hub_download, snapshot_download
+from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer
+
+
+@dataclasses.dataclass
+class HFRepoFiles:
+ repo_id: str
+ files: List[str]
+
+ def download(self, dir_path: str):
+ for file in self.files:
+ file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path)
+
+ def download_all(self):
+ file_path = snapshot_download(self.repo_id)
+
+
+def test_init(model: str, dir_path: str):
+ if model == "gpt2":
+ config = GPT2Config.from_pretrained(dir_path)
+ actor = GPTActor(config=config)
+ critic = GPTCritic(config=config)
+ reward_model = GPTRM(config=config)
+ tokenizer = GPT2Tokenizer.from_pretrained(dir_path)
+ elif model == "bloom":
+ config = BloomConfig.from_pretrained(dir_path)
+ actor = BLOOMActor(config=config)
+ critic = BLOOMCritic(config=config)
+ reward_model = BLOOMRM(config=config)
+ tokenizer = BloomTokenizerFast.from_pretrained(dir_path)
+ elif model == "opt":
+ config = AutoConfig.from_pretrained(dir_path)
+ actor = OPTActor(config=config)
+ critic = OPTCritic(config=config)
+ reward_model = OPTRM(config=config)
+ tokenizer = AutoTokenizer.from_pretrained(dir_path)
+ else:
+ raise NotImplementedError(f"Model {model} not implemented")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-dir", type=str, default="test_models")
+ parser.add_argument("--config-only", default=False, action="store_true")
+ args = parser.parse_args()
+
+ if os.path.exists(args.model_dir):
+ print(f"[INFO]: {args.model_dir} already exists")
+ exit(0)
+
+ repo_list = {
+ "gpt2": HFRepoFiles(
+ repo_id="gpt2",
+ files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]
+ ),
+ "bloom": HFRepoFiles(
+ repo_id="bigscience/bloom-560m",
+ files=["config.json", "tokenizer.json", "tokenizer_config.json"]
+ ),
+ "opt": HFRepoFiles(
+ repo_id="facebook/opt-350m",
+ files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
+ ),
+ }
+
+ os.mkdir(args.model_dir)
+ for model_name in tqdm.tqdm(repo_list):
+ dir_path = os.path.join(args.model_dir, model_name)
+ if args.config_only:
+ os.mkdir(dir_path)
+ repo_list[model_name].download(dir_path)
+ else:
+ repo_list[model_name].download_all()
+ test_init(model_name, dir_path)
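
As a usage note, the --config-only path only pulls the small config/tokenizer files per repo. The snippet below shows the same hf_hub_download call in isolation (repo id and file names taken from repo_list above; requires a reasonably recent huggingface_hub for local_dir).

from huggingface_hub import hf_hub_download

# Download only lightweight gpt2 files into a local directory (no weights).
for filename in ["config.json", "vocab.json", "merges.txt"]:
    path = hf_hub_download("gpt2", filename, local_dir="test_models/gpt2")
    print(path)
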
diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py
index 95e40fefe7ff..2abb31c09f82 100644
--- a/applications/Chat/examples/generate_prompt_dataset.py
+++ b/applications/Chat/examples/generate_prompt_dataset.py
@@ -1,7 +1,6 @@
import argparse
-
-import random
import json
+import random
random.seed(42)
@@ -10,8 +9,10 @@ def sample(args):
with open(args.dataset_path, mode='r') as f:
dataset_list = json.load(f)
- sampled_dataset = [{"instruction": sample["instruction"], "id":idx}
- for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))]
+ sampled_dataset = [
+ {"instruction": sample["instruction"], "id": idx}
+ for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
+ ]
with open(args.save_path, mode='w') as f:
json.dump(sampled_dataset, f, indent=4,
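
A tiny self-contained example of the reformatted sampling logic, using a synthetic dataset_list and a sample size of 3:

import json
import random

random.seed(42)
dataset_list = [{"instruction": f"task {i}", "output": "..."} for i in range(10)]
sampled_dataset = [
    {"instruction": sample["instruction"], "id": idx}
    for idx, sample in enumerate(random.sample(dataset_list, 3))
]
print(json.dumps(sampled_dataset, indent=4))
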
diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py
index 4b49e76088bc..e1e57e3cd376 100644
--- a/applications/Chat/examples/inference.py
+++ b/applications/Chat/examples/inference.py
@@ -4,40 +4,50 @@
from coati.models.bloom import BLOOMActor
from coati.models.generation import generate
from coati.models.gpt import GPTActor
+from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
-from transformers import AutoTokenizer
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
def eval(args):
# configure model
if args.model == 'gpt2':
- actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+ actor = GPTActor(pretrained=args.pretrain)
elif args.model == 'bloom':
- actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+ actor = BLOOMActor(pretrained=args.pretrain)
elif args.model == 'opt':
- actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+ actor = OPTActor(pretrained=args.pretrain)
+ elif args.model == 'llama':
+ actor = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
- state_dict = torch.load(args.model_path)
- actor.load_state_dict(state_dict)
+ actor.to(torch.cuda.current_device())
+ if args.model_path is not None:
+ state_dict = torch.load(args.model_path)
+ actor.load_state_dict(state_dict)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
- tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
+ tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'llama':
+ tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+ tokenizer.eos_token = '<\s>'
+ tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
actor.eval()
- input = args.input
- input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device())
+ input_ids = tokenizer.encode(args.input,
+ return_tensors='pt')\
+ .to(torch.cuda.current_device())
outputs = generate(actor,
input_ids,
max_length=args.max_length,
@@ -45,13 +55,14 @@ def eval(args):
top_k=50,
top_p=0.95,
num_return_sequences=1)
- output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
- print(output)
+ output = tokenizer.batch_decode(outputs[0],
+ skip_special_tokens=True)
+ print(f"[Output]: {''.join(output)}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
+ parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
# We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None)
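
The eval path above is: build the actor, optionally load --model_path weights, pick the matching tokenizer, then generate. A standalone approximation with a plain Hugging Face GPT-2 in place of the coati GPTActor (downloads the model on first run):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()

input_ids = tokenizer.encode("Explain RLHF in one sentence.", return_tensors='pt')
with torch.no_grad():
    outputs = model.generate(input_ids,
                             max_length=64,
                             do_sample=True,
                             top_k=50,
                             top_p=0.95,
                             num_return_sequences=1)
print(f"[Output]: {tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]}")
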
diff --git a/applications/Chat/examples/test_ci.sh b/applications/Chat/examples/test_ci.sh
deleted file mode 100755
index fe2af471017e..000000000000
--- a/applications/Chat/examples/test_ci.sh
+++ /dev/null
@@ -1,160 +0,0 @@
-#!/usr/bin/env bash
-
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
- local n=${1:-"9999"}
- echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
- tail -n +2 |
- nl -v 0 |
- tee /dev/tty |
- sort -g -k 2 |
- awk '{print $1}' |
- head -n $n)
- export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
- echo "Now CUDA_VISIBLE_DEVICES is set to:"
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 4
-
-set -xue
-
-if [ -z "$SFT_DATASET" ]; then
- echo "Please set \$SFT_DATASET to the path to sft dataset."
- exit 1
-fi
-
-if [ -z "$PROMPT_PATH" ]; then
- echo "Please set \$PROMPT_PATH to the path to prompts csv."
- exit 1
-fi
-
-if [ -z "$PRETRAIN_DATASET" ]; then
- echo "Please set \$PRETRAIN_DATASET to the path to alpaca data."
- exit 1
-fi
-
-BASE=$(realpath $(dirname $0))
-
-export OMP_NUM_THREADS=8
-
-# install requirements
-pip install -r ${BASE}/requirements.txt
-
-wandb init -m offline
-
-# FIXME: This is a hack to skip tests that are not working
-# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
-# - llama-*: These tests can be passed locally, skipped for long execution time
-SKIPPED_TESTS=(
- "gpt2-ddp"
- "llama-ddp"
- "llama-colossalai_gemini"
- "llama-colossalai_zero2"
-)
-
-# These tests are quick and do not have any dependencies
-for model in 'gpt2' 'bloom' 'opt' 'llama'; do
- for strategy in 'ddp' 'colossalai_gemini' 'colossalai_zero2'; do
- if [[ " ${SKIPPED_TESTS[*]} " =~ " ${model}-${strategy} " ]]; then
- echo "[Test]: Skipped $model-$strategy"
- continue
- fi
- torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
- --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
- --strategy $strategy --model $model \
- --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
- --train_batch_size 2 --lora_rank 4
- done
-done
-
-# train sft
-torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \
- --model 'bloom' --strategy colossalai_zero2 --lora_rank 4 \
- --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
- --save_path ${BASE}/output
-rm -rf ${BASE}/output
-
-torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
- --model 'gpt2' --strategy colossalai_zero2 \
- --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
- --save_path ${BASE}/output
-rm -rf ${BASE}/output
-
-torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
- --model 'opt' --strategy colossalai_zero2 --lora_rank 4 \
- --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
- --save_path ${BASE}/output
-rm -rf ${BASE}/output
-
-torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
- --model 'gpt2' --strategy ddp --lora_rank 4 \
- --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
- --save_path ${BASE}/output
-rm -rf ${BASE}/output
-
-# train rm
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'facebook/opt-350m' --model 'opt' \
- --strategy colossalai_zero2 --loss_fn 'log_sig' \
- --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
- --test True --lora_rank 0 \
- --save_path ${BASE}/rm_ckpt_opt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'gpt2' --model 'gpt2' \
- --strategy colossalai_zero2 --loss_fn 'log_exp' \
- --dataset 'Dahoas/rm-static' \
- --test True --lora_rank 0 \
- --save_path ${BASE}/rm_ckpt_gpt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'gpt2' --model 'gpt2' \
- --strategy ddp --loss_fn 'log_exp' \
- --dataset 'Dahoas/rm-static' \
- --test True --lora_rank 4 \
- --save_path ${BASE}/rm_ckpt.pt
-rm -rf ${BASE}/rm_ckpt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'bigscience/bloom-560m' --model 'bloom' \
- --strategy colossalai_zero2 --loss_fn 'log_sig' \
- --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
- --test True --lora_rank 4 \
- --save_path ${BASE}/rm_ckpt.pt
-rm -rf ${BASE}/rm_ckpt.pt
-
-# train rl
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
- --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
- --strategy colossalai_zero2 --num_episodes 1 \
- --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
- --pretrain 'facebook/opt-350m' --model opt \
- --rm_pretrain 'facebook/opt-350m' \
- --rm_path ${BASE}/rm_ckpt_opt.pt \
- --save_path ${BASE}/actor_checkpoint_prompts.pt
-rm -rf ${BASE}/rm_ckpt_opt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
- --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
- --strategy colossalai_zero2 --num_episodes 1 \
- --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
- --pretrain 'gpt2' --model gpt2 \
- --rm_pretrain 'gpt2' \
- --rm_path ${BASE}/rm_ckpt_gpt.pt \
- --save_path ${BASE}/actor_checkpoint_prompts.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
- --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
- --strategy colossalai_gemini --num_episodes 1 \
- --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
- --pretrain 'gpt2' --model gpt2 \
- --rm_pretrain 'gpt2' \
- --rm_path ${BASE}/rm_ckpt_gpt.pt \
- --save_path ${BASE}/actor_checkpoint_prompts.pt
-rm -rf ${BASE}/rm_ckpt_gpt.pt
-
-rm -rf ${BASE}/actor_checkpoint_prompts.pt
-
-# 3080 doesn't support P2P, skip this test
-# cd ${BASE}/ray && bash test_ci.sh && cd ${BASE}
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index 7338a6d51142..d27a70a3fef6 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -1,8 +1,9 @@
import argparse
+import warnings
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.dataset import PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
@@ -29,6 +30,7 @@ def main(args):
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
+ warnings.warn('LoRA weights should be merged with the model weights')
state_dict = torch.load(args.rm_path, map_location='cpu')
with strategy.model_init_context():
@@ -50,18 +52,18 @@ def main(args):
rm_model_name = args.rm_model
if rm_model_name == 'gpt2':
- reward_model = GPTRM(pretrained=args.rm_pretrain)
+ reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'bloom':
- reward_model = BLOOMRM(pretrained=args.rm_pretrain)
+ reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'opt':
- reward_model = OPTRM(pretrained=args.rm_pretrain)
+ reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'llama':
- reward_model = LlamaRM(pretrained=args.rm_pretrain)
+ reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
- reward_model.load_state_dict(state_dict)
+ reward_model.load_state_dict(state_dict, strict=False)
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
@@ -89,7 +91,7 @@ def main(args):
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
- critic.load_state_dict(state_dict)
+ critic.load_state_dict(state_dict, strict=False)
del state_dict
if args.strategy != 'colossalai_gemini':
@@ -106,23 +108,25 @@ def main(args):
# configure tokenizer
if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer = GPT2Tokenizer.from_pretrained(
+ 'gpt2' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
+ tokenizer = BloomTokenizerFast.from_pretrained(
+ 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ tokenizer = AutoTokenizer.from_pretrained(
+ "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
- tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
+ tokenizer = LlamaTokenizer.from_pretrained(
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '<\s>'
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
-
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
@@ -144,8 +148,7 @@ def main(args):
pretrain_dataloader = DataLoader(pretrain_dataset,
shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler,
- batch_size=args.ptx_batch_size,
- collate_fn=data_collator)
+ batch_size=args.ptx_batch_size)
# NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
@@ -197,6 +200,7 @@ def main(args):
default='colossalai_zero2',
help='strategy to use')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
+ parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--rm_path', type=str, default=None)
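
The new --tokenizer flag follows one pattern everywhere: use the per-model default unless the flag overrides it. A compact sketch of that resolution; the mapping values are the defaults used in the scripts, and resolve_tokenizer_name is a hypothetical helper added only for illustration.

DEFAULT_TOKENIZER = {
    "gpt2": "gpt2",
    "bloom": "bigscience/bloom-560m",
    "opt": "facebook/opt-350m",
    "llama": "hf-internal-testing/llama-tokenizer",
}

def resolve_tokenizer_name(model: str, tokenizer_arg=None) -> str:
    # Fall back to the per-model default when --tokenizer is not given.
    return DEFAULT_TOKENIZER[model] if tokenizer_arg is None else tokenizer_arg

assert resolve_tokenizer_name("opt") == "facebook/opt-350m"
assert resolve_tokenizer_name("opt", "my/local-tokenizer") == "my/local-tokenizer"
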
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index fb9802e38542..190460bc20f6 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -36,34 +36,39 @@ def train(args):
# configure model
with strategy.model_init_context():
if args.model == 'bloom':
- model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'opt':
- model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'gpt2':
- model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'llama':
- model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported model "{args.model}"')
+ model.to(torch.float16).to(torch.cuda.current_device())
+
if args.model_path is not None:
state_dict = torch.load(args.model_path)
model.load_state_dict(state_dict)
- model = model.to(torch.float16)
-
# configure tokenizer
if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer = GPT2Tokenizer.from_pretrained(
+ 'gpt2' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
+ tokenizer = BloomTokenizerFast.from_pretrained(
+ 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ tokenizer = AutoTokenizer.from_pretrained(
+ "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
- tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
+ tokenizer = LlamaTokenizer.from_pretrained(
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
+ tokenizer.eos_token = '<\s>'
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -89,8 +94,8 @@ def train(args):
data = load_dataset(args.dataset)
if args.test:
- train_data = data['train'].select(range(100))
- eval_data = data['test'].select(range(10))
+ train_data = data['train'].select(range(20))
+ eval_data = data['test'].select(range(5))
else:
train_data = data['train']
eval_data = data['test']
@@ -177,6 +182,7 @@ def train(args):
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='colossalai_zero2')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
+ parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
@@ -184,7 +190,7 @@ def train(args):
type=str,
choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
default='Dahoas/rm-static')
- parser.add_argument('--subset', type=str, default=None)
+ parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None)
parser.add_argument('--save_path', type=str, default='rm_ckpt')
parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=1)
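
The --subset change lets shell scripts pass the literal string "None" and still get Python None back; a small argparse demo of that trick:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)

print(parser.parse_args(["--subset", "None"]).subset)            # None
print(parser.parse_args(["--subset", "harmless-base"]).subset)   # harmless-base
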
diff --git a/applications/Chat/examples/train_rm.sh b/applications/Chat/examples/train_rm.sh
index 80abe62d2a3f..cc1b7be2815f 100755
--- a/applications/Chat/examples/train_rm.sh
+++ b/applications/Chat/examples/train_rm.sh
@@ -1,13 +1,13 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
- | tail -n +2 \
- | nl -v 0 \
- | tee /dev/tty \
- | sort -g -k 2 \
- | awk '{print $1}' \
- | head -n $n)
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
@@ -16,9 +16,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 2
torchrun --standalone --nproc_per_node=2 train_reward_model.py \
- --pretrain \
- --model 'bloom' \
- --strategy colossalai_zero2 \
- --loss_fn 'log_sig'\
- --save_path \
- --dataset 'Anthropic/hh-rlhf'\
+ --model 'bloom' \
+ --strategy colossalai_zero2 \
+ --loss_fn 'log_sig' \
+ --dataset 'Anthropic/hh-rlhf'
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index 4676d47dd331..7585cf3ed0da 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -1,24 +1,22 @@
import argparse
import math
-import os
+import warnings
-import loralib as lora
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
-from coati.models import convert_to_lora_module
+from coati.dataset import SFTDataset, SupervisedDataset
+from coati.models.bloom import BLOOMActor
+from coati.models.gpt import GPTActor
+from coati.models.llama import LlamaActor
+from coati.models.opt import OPTActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
-from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-from transformers.models.opt.configuration_opt import OPTConfig
-from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.trainer import get_scheduler
from colossalai.logging import get_dist_logger
@@ -31,8 +29,6 @@ def train(args):
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
- raise NotImplementedError(
- 'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.')
strategy = GeminiStrategy(placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
@@ -42,40 +38,49 @@ def train(args):
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
+ if args.lora_rank > 0:
+ warnings.warn("Gradient checkpoint is disabled when using LoRA")
+ args.grad_checkpoint = False
with strategy.model_init_context():
if args.model == 'bloom':
- model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain),
- args.lora_rank).half().cuda()
+ model = BLOOMActor(pretrained=args.pretrain,
+ lora_rank=args.lora_rank,
+ checkpoint=args.grad_checkpoint)
elif args.model == 'opt':
- model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
+ model = OPTActor(pretrained=args.pretrain,
+ lora_rank=args.lora_rank,
+ checkpoint=args.grad_checkpoint)
elif args.model == 'gpt2':
- model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
+ model = GPTActor(pretrained=args.pretrain,
+ lora_rank=args.lora_rank,
+ checkpoint=args.grad_checkpoint)
elif args.model == 'llama':
- model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain),
- args.lora_rank).half().cuda()
+ model = LlamaActor(pretrained=args.pretrain,
+ lora_rank=args.lora_rank,
+ checkpoint=args.grad_checkpoint)
else:
raise ValueError(f'Unsupported model "{args.model}"')
- if args.grad_checkpoint:
- model.gradient_checkpointing_enable()
+
+ model.to(torch.float16).to(torch.cuda.current_device())
# configure tokenizer
if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer = GPT2Tokenizer.from_pretrained(
+ 'gpt2' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
+ tokenizer = BloomTokenizerFast.from_pretrained(
+ 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'llama':
tokenizer = AutoTokenizer.from_pretrained(
- args.pretrain,
- padding_side="right",
- use_fast=False,
- )
- tokenizer.eos_token = ''
+ "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'llama':
+ tokenizer = LlamaTokenizer.from_pretrained(
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
+ tokenizer.eos_token = '<\s>'
+ tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -111,7 +116,6 @@ def train(args):
max_datasets_size=args.max_datasets_size,
max_length=args.max_len)
eval_dataset = None
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset,
@@ -135,14 +139,12 @@ def train(args):
shuffle=(train_sampler is None),
sampler=train_sampler,
batch_size=args.batch_size,
- collate_fn=data_collator,
pin_memory=True)
if eval_dataset is not None:
eval_dataloader = DataLoader(eval_dataset,
shuffle=(eval_sampler is None),
sampler=eval_sampler,
batch_size=args.batch_size,
- collate_fn=data_collator,
pin_memory=True)
else:
eval_dataloader = None
@@ -184,6 +186,7 @@ def train(args):
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
default='colossalai_zero2')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
+ parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default=None)
parser.add_argument('--max_datasets_size', type=int, default=None)
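
The LoRA guard added above simply turns gradient checkpointing off (with a warning) whenever lora_rank > 0, sketched here as a hypothetical helper:

import warnings

def resolve_grad_checkpoint(lora_rank: int, grad_checkpoint: bool) -> bool:
    # LoRA and gradient checkpointing are treated as mutually exclusive.
    if lora_rank > 0 and grad_checkpoint:
        warnings.warn("Gradient checkpointing is disabled when using LoRA")
        return False
    return grad_checkpoint

assert resolve_grad_checkpoint(lora_rank=4, grad_checkpoint=True) is False
assert resolve_grad_checkpoint(lora_rank=0, grad_checkpoint=True) is True
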
diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh
index c880f85825a7..1a5cd069011d 100755
--- a/applications/Chat/examples/train_sft.sh
+++ b/applications/Chat/examples/train_sft.sh
@@ -1,12 +1,29 @@
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
torchrun --standalone --nproc_per_node=4 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy colossalai_zero2 \
--log_interval 10 \
- --save_path /path/to/Coati-7B \
+ --save_path /path/to/Coati-7B \
--dataset /path/to/data.json \
--batch_size 4 \
--accumulation_steps 8 \
--lr 2e-5 \
--max_datasets_size 512 \
- --max_epochs 1 \
+ --max_epochs 1
diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py
index a8485f588705..438a1e3ef1c7 100644
--- a/applications/Chat/inference/benchmark.py
+++ b/applications/Chat/inference/benchmark.py
@@ -4,8 +4,8 @@
from time import time
import torch
-from llama_gptq import load_quant
-from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
+from coati.quant import llama_load_quant, low_resource_init
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
def generate_prompt(instruction, input=None):
@@ -106,7 +106,10 @@ def evaluate(
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
if args.quant == '4bit':
- model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
+ with low_resource_init():
+ config = LlamaConfig.from_pretrained(args.pretrained)
+ model = LlamaForCausalLM(config)
+ model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
model.cuda()
else:
model = LlamaForCausalLM.from_pretrained(
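
low_resource_init from coati.quant replaces the deleted load_quant helper; its core trick (also visible in the removed loader.py below) is to skip the expensive random weight initialization before the quantized weights are loaded. A rough, self-contained sketch of that idea; skip_weight_init is a hypothetical context manager, not the coati implementation.

from contextlib import contextmanager

import torch

@contextmanager
def skip_weight_init():
    # Temporarily replace the init kernels used by nn.Linear and friends with no-ops.
    saved = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_)
    noop = lambda tensor, *args, **kwargs: tensor
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop
    try:
        yield
    finally:
        torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_ = saved

with skip_weight_init():
    layer = torch.nn.Linear(1024, 1024)   # built without running the usual init kernels
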
diff --git a/applications/Chat/inference/llama_gptq/__init__.py b/applications/Chat/inference/llama_gptq/__init__.py
deleted file mode 100644
index 51c8d6316290..000000000000
--- a/applications/Chat/inference/llama_gptq/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .loader import load_quant
-
-__all__ = [
- 'load_quant',
-]
diff --git a/applications/Chat/inference/llama_gptq/loader.py b/applications/Chat/inference/llama_gptq/loader.py
deleted file mode 100644
index a5c6ac7d1589..000000000000
--- a/applications/Chat/inference/llama_gptq/loader.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import torch
-import torch.nn as nn
-import transformers
-from transformers import LlamaConfig, LlamaForCausalLM
-
-from .model_utils import find_layers
-from .quant import make_quant
-
-
-def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int):
- config = LlamaConfig.from_pretrained(pretrained)
-
- def noop(*args, **kwargs):
- pass
-
- torch.nn.init.kaiming_uniform_ = noop
- torch.nn.init.uniform_ = noop
- torch.nn.init.normal_ = noop
-
- torch.set_default_dtype(torch.half)
- transformers.modeling_utils._init_weights = False
- torch.set_default_dtype(torch.half)
- model = LlamaForCausalLM(config)
- torch.set_default_dtype(torch.float)
- model = model.eval()
- layers = find_layers(model)
- for name in ['lm_head']:
- if name in layers:
- del layers[name]
- make_quant(model, layers, wbits, groupsize)
-
- print(f'Loading model with {wbits} bits...')
- if checkpoint.endswith('.safetensors'):
- from safetensors.torch import load_file as safe_load
- model.load_state_dict(safe_load(checkpoint))
- else:
- model.load_state_dict(torch.load(checkpoint))
- model.seqlen = 2048
- print('Done.')
-
- return model
diff --git a/applications/Chat/inference/llama_gptq/model_utils.py b/applications/Chat/inference/llama_gptq/model_utils.py
deleted file mode 100644
index 62db171abb52..000000000000
--- a/applications/Chat/inference/llama_gptq/model_utils.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
-
-import torch
-import torch.nn as nn
-
-
-def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
- if type(module) in layers:
- return {name: module}
- res = {}
- for name1, child in module.named_children():
- res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
- return res
diff --git a/applications/Chat/inference/llama_gptq/quant.py b/applications/Chat/inference/llama_gptq/quant.py
deleted file mode 100644
index f7d5b7ce4bd8..000000000000
--- a/applications/Chat/inference/llama_gptq/quant.py
+++ /dev/null
@@ -1,283 +0,0 @@
-# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/quant.py
-
-import math
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-
-def quantize(x, scale, zero, maxq):
- q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
- return scale * (q - zero)
-
-
-class Quantizer(nn.Module):
-
- def __init__(self, shape=1):
- super(Quantizer, self).__init__()
- self.register_buffer('maxq', torch.tensor(0))
- self.register_buffer('scale', torch.zeros(shape))
- self.register_buffer('zero', torch.zeros(shape))
-
- def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
- self.maxq = torch.tensor(2**bits - 1)
- self.perchannel = perchannel
- self.sym = sym
- self.mse = mse
- self.norm = norm
- self.grid = grid
- self.maxshrink = maxshrink
-
- def find_params(self, x, weight=False):
- dev = x.device
- self.maxq = self.maxq.to(dev)
-
- shape = x.shape
- if self.perchannel:
- if weight:
- x = x.flatten(1)
- else:
- if len(shape) == 4:
- x = x.permute([1, 0, 2, 3])
- x = x.flatten(1)
- if len(shape) == 3:
- x = x.reshape((-1, shape[-1])).t()
- if len(shape) == 2:
- x = x.t()
- else:
- x = x.flatten().unsqueeze(0)
-
- tmp = torch.zeros(x.shape[0], device=dev)
- xmin = torch.minimum(x.min(1)[0], tmp)
- xmax = torch.maximum(x.max(1)[0], tmp)
-
- if self.sym:
- xmax = torch.maximum(torch.abs(xmin), xmax)
- tmp = xmin < 0
- if torch.any(tmp):
- xmin[tmp] = -xmax[tmp]
- tmp = (xmin == 0) & (xmax == 0)
- xmin[tmp] = -1
- xmax[tmp] = +1
-
- self.scale = (xmax - xmin) / self.maxq
- if self.sym:
- self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
- else:
- self.zero = torch.round(-xmin / self.scale)
-
- if self.mse:
- best = torch.full([x.shape[0]], float('inf'), device=dev)
- for i in range(int(self.maxshrink * self.grid)):
- p = 1 - i / self.grid
- xmin1 = p * xmin
- xmax1 = p * xmax
- scale1 = (xmax1 - xmin1) / self.maxq
- zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
- q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
- q -= x
- q.abs_()
- q.pow_(self.norm)
- err = torch.sum(q, 1)
- tmp = err < best
- if torch.any(tmp):
- best[tmp] = err[tmp]
- self.scale[tmp] = scale1[tmp]
- self.zero[tmp] = zero1[tmp]
- if not self.perchannel:
- if weight:
- tmp = shape[0]
- else:
- tmp = shape[1] if len(shape) != 3 else shape[2]
- self.scale = self.scale.repeat(tmp)
- self.zero = self.zero.repeat(tmp)
-
- if weight:
- shape = [-1] + [1] * (len(shape) - 1)
- self.scale = self.scale.reshape(shape)
- self.zero = self.zero.reshape(shape)
- return
- if len(shape) == 4:
- self.scale = self.scale.reshape((1, -1, 1, 1))
- self.zero = self.zero.reshape((1, -1, 1, 1))
- if len(shape) == 3:
- self.scale = self.scale.reshape((1, 1, -1))
- self.zero = self.zero.reshape((1, 1, -1))
- if len(shape) == 2:
- self.scale = self.scale.unsqueeze(0)
- self.zero = self.zero.unsqueeze(0)
-
- def quantize(self, x):
- if self.ready():
- return quantize(x, self.scale, self.zero, self.maxq)
- return x
-
- def enabled(self):
- return self.maxq > 0
-
- def ready(self):
- return torch.all(self.scale != 0)
-
-
-try:
- import quant_cuda
-except:
- print('CUDA extension not installed.')
-
-# Assumes layer is perfectly divisible into 256 * 256 blocks
-
-
-class QuantLinear(nn.Module):
-
- def __init__(self, bits, groupsize, infeatures, outfeatures):
- super().__init__()
- if bits not in [2, 3, 4, 8]:
- raise NotImplementedError("Only 2,3,4,8 bits are supported.")
- self.infeatures = infeatures
- self.outfeatures = outfeatures
- self.bits = bits
- if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))):
- raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
- groupsize = groupsize if groupsize != -1 else infeatures
- self.groupsize = groupsize
- self.register_buffer(
- 'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
- dtype=torch.int))
- self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
- self.register_buffer('bias', torch.zeros(outfeatures))
- self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
- self._initialized_quant_state = False
-
- def pack(self, linear, scales, zeros):
- scales = scales.t().contiguous()
- zeros = zeros.t().contiguous()
- scale_zeros = zeros * scales
- self.scales = scales.clone()
- if linear.bias is not None:
- self.bias = linear.bias.clone()
-
- intweight = []
- for idx in range(self.infeatures):
- g_idx = idx // self.groupsize
- intweight.append(
- torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
- None])
- intweight = torch.cat(intweight, dim=1)
- intweight = intweight.t().contiguous()
- intweight = intweight.numpy().astype(np.uint32)
- qweight = np.zeros((intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32)
- i = 0
- row = 0
- while row < qweight.shape[0]:
- if self.bits in [2, 4, 8]:
- for j in range(i, i + (32 // self.bits)):
- qweight[row] |= intweight[j] << (self.bits * (j - i))
- i += 32 // self.bits
- row += 1
- elif self.bits == 3:
- for j in range(i, i + 10):
- qweight[row] |= intweight[j] << (3 * (j - i))
- i += 10
- qweight[row] |= intweight[i] << 30
- row += 1
- qweight[row] |= (intweight[i] >> 2) & 1
- i += 1
- for j in range(i, i + 10):
- qweight[row] |= intweight[j] << (3 * (j - i) + 1)
- i += 10
- qweight[row] |= intweight[i] << 31
- row += 1
- qweight[row] |= (intweight[i] >> 1) & 0x3
- i += 1
- for j in range(i, i + 10):
- qweight[row] |= intweight[j] << (3 * (j - i) + 2)
- i += 10
- row += 1
- else:
- raise NotImplementedError("Only 2,3,4,8 bits are supported.")
-
- qweight = qweight.astype(np.int32)
- self.qweight = torch.from_numpy(qweight)
-
- zeros -= 1
- zeros = zeros.numpy().astype(np.uint32)
- qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
- i = 0
- col = 0
- while col < qzeros.shape[1]:
- if self.bits in [2, 4, 8]:
- for j in range(i, i + (32 // self.bits)):
- qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
- i += 32 // self.bits
- col += 1
- elif self.bits == 3:
- for j in range(i, i + 10):
- qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
- i += 10
- qzeros[:, col] |= zeros[:, i] << 30
- col += 1
- qzeros[:, col] |= (zeros[:, i] >> 2) & 1
- i += 1
- for j in range(i, i + 10):
- qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
- i += 10
- qzeros[:, col] |= zeros[:, i] << 31
- col += 1
- qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
- i += 1
- for j in range(i, i + 10):
- qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
- i += 10
- col += 1
- else:
- raise NotImplementedError("Only 2,3,4,8 bits are supported.")
-
- qzeros = qzeros.astype(np.int32)
- self.qzeros = torch.from_numpy(qzeros)
-
- def forward(self, x):
- intermediate_dtype = torch.float32
-
- if not self._initialized_quant_state:
- # Do we even have a bias? Check for at least one non-zero element.
- if self.bias is not None and bool(torch.any(self.bias != 0)):
- # Then make sure it's the right type.
- self.bias.data = self.bias.data.to(intermediate_dtype)
- else:
- self.bias = None
-
- outshape = list(x.shape)
- outshape[-1] = self.outfeatures
- x = x.reshape(-1, x.shape[-1])
- if self.bias is None:
- y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
- else:
- y = self.bias.clone().repeat(x.shape[0], 1)
-
- output_dtype = x.dtype
- x = x.to(intermediate_dtype)
- if self.bits == 2:
- quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
- elif self.bits == 3:
- quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
- elif self.bits == 4:
- quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
- elif self.bits == 8:
- quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
- else:
- raise NotImplementedError("Only 2,3,4,8 bits are supported.")
- y = y.to(output_dtype)
- return y.reshape(outshape)
-
-
-def make_quant(module, names, bits, groupsize, name=''):
- if isinstance(module, QuantLinear):
- return
- for attr in dir(module):
- tmp = getattr(module, attr)
- name1 = name + '.' + attr if name != '' else attr
- if name1 in names:
- setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
- for name1, child in module.named_children():
- make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
diff --git a/applications/Chat/inference/locustfile.py b/applications/Chat/inference/locustfile.py
index 51cdc68125bb..9443d4b99180 100644
--- a/applications/Chat/inference/locustfile.py
+++ b/applications/Chat/inference/locustfile.py
@@ -5,8 +5,7 @@
samples = [[
dict(
instruction='Who is the best player in the history of NBA?',
- response=
- 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
+ response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
),
dict(instruction='continue this talk', response=''),
], [
diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py
index e23f0fceb2fa..9d6b7fabef54 100644
--- a/applications/Chat/inference/server.py
+++ b/applications/Chat/inference/server.py
@@ -1,19 +1,19 @@
import argparse
import os
from threading import Lock
-from typing import Dict, Generator, List, Optional
+from typing import Generator, List, Optional
import torch
import uvicorn
-from fastapi import FastAPI, HTTPException, Request
+from coati.quant import llama_load_quant, low_resource_init
+from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
-from llama_gptq import load_quant
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
-from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
@@ -56,7 +56,7 @@ class GenerationTaskReq(BaseModel):
def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
- #TODO(ver217): streaming generation does not support repetition_penalty now
+ # TODO(ver217): streaming generation does not support repetition_penalty now
model_kwargs = {
'max_generate_tokens': max_new_tokens,
'early_stopping': True,
@@ -162,7 +162,10 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
if args.quant == '4bit':
- model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
+ with low_resource_init():
+ config = LlamaConfig.from_pretrained(args.pretrained)
+ model = LlamaForCausalLM(config)
+ model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
model.cuda()
else:
model = LlamaForCausalLM.from_pretrained(
diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py
index f5737ebe8c09..23028d4959cb 100644
--- a/applications/Chat/inference/tests/test_chat_prompt.py
+++ b/applications/Chat/inference/tests/test_chat_prompt.py
@@ -10,37 +10,34 @@
([
Dialogue(
instruction='Who is the best player in the history of NBA?',
- response=
- 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
+ response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
),
Dialogue(instruction='continue this talk', response=''),
], 128,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+ 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
),
([
Dialogue(
instruction='Who is the best player in the history of NBA?',
- response=
- 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
+ response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
),
Dialogue(instruction='continue this talk', response=''),
], 200,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+ 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
),
([
Dialogue(
instruction='Who is the best player in the history of NBA?',
- response=
- 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
+ response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
),
Dialogue(instruction='continue this talk', response=''),
], 211,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
+ 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
),
([
Dialogue(instruction='Who is the best player in the history of NBA?', response=''),
], 128,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
+ 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
),
]
diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py
index 37944be70a3b..e8e7b05ac719 100644
--- a/applications/Chat/inference/utils.py
+++ b/applications/Chat/inference/utils.py
@@ -1,9 +1,9 @@
+import json
import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional
-import json
-import jieba
+import jieba
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -127,7 +127,7 @@ def _format_dialogue(instruction: str, response: str = ''):
class ChatPromptProcessor:
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
- def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
+ def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
self.tokenizer = tokenizer
self.context = context
self.max_len = max_len
@@ -182,6 +182,7 @@ def has_censored_words(self, text: str) -> bool:
intersection = set(jieba.cut(text.lower())) & self.censored_words
return len(intersection) > 0
+
class LockedIterator:
def __init__(self, it, lock: Lock) -> None:
@@ -195,6 +196,7 @@ def __next__(self):
with self.lock:
return next(self.it)
+
def load_json(path: str):
with open(path) as f:
- return json.load(f)
\ No newline at end of file
+ return json.load(f)
diff --git a/applications/Chat/tests/test_benchmarks.sh b/applications/Chat/tests/test_benchmarks.sh
new file mode 100755
index 000000000000..3fdb25181342
--- /dev/null
+++ b/applications/Chat/tests/test_benchmarks.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+set -xue
+
+echo "Hint: You can run this script with 'verbose' as the first argument to run all strategies."
+
+if [[ $# -ne 0 && "$1" == "verbose" ]]; then
+ STRATEGIES=(
+ 'ddp'
+ 'colossalai_gemini'
+ 'colossalai_gemini_cpu'
+ 'colossalai_zero2'
+ 'colossalai_zero2_cpu'
+ 'colossalai_zero1'
+ 'colossalai_zero1_cpu'
+ )
+else
+ STRATEGIES=(
+ 'colossalai_zero2'
+ )
+fi
+
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+BENCHMARKS_DIR=$BASE_DIR/benchmarks
+
+echo "[Test]: testing benchmarks ..."
+
+for strategy in ${STRATEGIES[@]}; do
+ torchrun --standalone --nproc_per_node 1 $BENCHMARKS_DIR/benchmark_opt_lora_dummy.py \
+ --model 125m --critic_model 125m --strategy ${strategy} --lora_rank 4 \
+ --num_episodes 2 --num_collect_steps 4 --num_update_steps 2 \
+ --train_batch_size 2 --experience_batch_size 4
+done
diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py
index 19338da437ab..3a3bf5b19cb8 100644
--- a/applications/Chat/tests/test_checkpoint.py
+++ b/applications/Chat/tests/test_checkpoint.py
@@ -7,7 +7,7 @@
import torch.distributed as dist
from coati.models.gpt import GPTActor
from coati.models.utils import calc_action_log_probs
-from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.nn.optimizer import HybridAdam
@@ -17,40 +17,41 @@
def get_data(batch_size: int, seq_len: int = 10) -> dict:
- input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
+ input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask)
-def run_test_checkpoint(strategy):
- BATCH_SIZE = 2
+def train_step(strategy: Strategy,
+ actor: GPTActor,
+ actor_optim: HybridAdam,
+ batch_size: int = 8):
+ data = get_data(batch_size)
+ action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
+ actor_output = actor(data["input_ids"], data["attention_mask"])
+ action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1))
+ loss = action_log_probs.sum()
+ strategy.backward(loss, actor, actor_optim)
+ strategy.optimizer_step(actor_optim)
- if strategy == 'ddp':
+
+def run_test_checkpoint(strategy_name: str,
+ shard: bool):
+ if strategy_name == "ddp":
strategy = DDPStrategy()
- elif strategy == 'colossalai_gemini':
- strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
- elif strategy == 'colossalai_zero2':
- strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
+ elif strategy_name == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+ elif strategy_name == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
- raise ValueError(f'Unsupported strategy "{strategy}"')
+ raise ValueError(f"Unsupported strategy '{strategy_name}'")
with strategy.model_init_context():
actor = GPTActor(config=GPT_CONFIG).cuda()
-
actor_optim = HybridAdam(actor.parameters())
-
actor, actor_optim = strategy.prepare((actor, actor_optim))
- def run_step():
- data = get_data(BATCH_SIZE)
- action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
- actor_output = actor(data['input_ids'], data['attention_mask'])
- action_log_probs = calc_action_log_probs(actor_output, data['input_ids'], action_mask.size(1))
- loss = action_log_probs.sum()
- strategy.backward(loss, actor, actor_optim)
- strategy.optimizer_step(actor_optim)
-
- run_step()
+ train_step(strategy, actor, actor_optim)
ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
@@ -59,43 +60,47 @@ def run_step():
dist.broadcast_object_list(rank0_dirname)
rank0_dirname = rank0_dirname[0]
- model_path = os.path.join(rank0_dirname, 'model.pt')
- strategy.save_model(actor, model_path, only_rank0=True)
-
- optim_path = os.path.join(rank0_dirname, f'optim.pt')
- strategy.save_optimizer(actor_optim, optim_path, only_rank0=True)
-
- # FIXME(cwher): Sharded optimizer checkpoint is not supported yet.
- # at "ColossalAI/colossalai/checkpoint_io/general_checkpoint_io.py", line 62
- # optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
- # strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
-
+ model_path = os.path.join(
+ rank0_dirname, "model" if shard else f"model.pt")
+ strategy.save_model(actor, model_path, only_rank0=not shard)
+ optim_path = os.path.join(
+ rank0_dirname, "optim" if shard else "optim.pt")
+ strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
dist.barrier()
strategy.load_model(actor, model_path, strict=False)
strategy.load_optimizer(actor_optim, optim_path)
-
dist.barrier()
- run_step()
+ train_step(strategy, actor, actor_optim)
-def run_dist(rank, world_size, port, strategy):
- os.environ['RANK'] = str(rank)
- os.environ['LOCAL_RANK'] = str(rank)
- os.environ['WORLD_SIZE'] = str(world_size)
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = str(port)
- run_test_checkpoint(strategy)
+def run_dist(rank: int,
+ world_size: int,
+ port: int,
+ strategy_name: str,
+ shard: bool):
+ os.environ["RANK"] = str(rank)
+ os.environ["LOCAL_RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(port)
+ run_test_checkpoint(strategy_name, shard)
@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
+@pytest.mark.parametrize("world_size", [4])
+@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
+@pytest.mark.parametrize("shard", [False, True])
@rerun_if_address_is_in_use()
-def test_checkpoint(world_size, strategy):
- spawn(run_dist, world_size, strategy=strategy)
+def test_checkpoint(world_size: int,
+ strategy_name: str,
+ shard: bool):
+ spawn(run_dist,
+ world_size,
+ strategy_name=strategy_name,
+ shard=shard)
-if __name__ == '__main__':
- test_checkpoint(2, 'colossalai_zero2')
+if __name__ == "__main__":
+ test_checkpoint(2, "colossalai_gemini", shard=False)
diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py
new file mode 100644
index 000000000000..64ea1178cd0d
--- /dev/null
+++ b/applications/Chat/tests/test_dataset.py
@@ -0,0 +1,248 @@
+import json
+import os
+import tempfile
+from typing import Optional
+
+import pytest
+import torch
+from coati.dataset.prompt_dataset import PromptDataset
+from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset
+from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset
+from datasets import load_dataset
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+SFT_DATASET = [
+ {
+ "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
+ "input": "",
+ "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id": 0
+ },
+ {
+ "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
+ "input": "",
+ "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
+ "id": 1
+ },
+ {
+ "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
+ "input": "",
+ "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
+ "id": 2
+ },
+]
+
+PROMPT_DATASET = [
+ {
+ "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
+ "id": 0
+ },
+ {
+ "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
+ "id": 1
+ },
+ {
+ "instruction": "Write a persuasive essay arguing why homework should be banned in schools",
+ "id": 2
+ },
+ {
+ "instruction": "Create a chart comparing the statistics on student debt in the United States.",
+ "id": 3
+ },
+]
+
+
+def make_tokenizer(model: str):
+ if model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ tokenizer.pad_token = tokenizer.eos_token
+ elif model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
+ tokenizer.pad_token = tokenizer.eos_token
+ elif model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ tokenizer.pad_token = tokenizer.eos_token
+ elif model == "llama":
+ tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+ tokenizer.pad_token = tokenizer.unk_token
+ else:
+ raise ValueError(f"Unsupported model '{model}'")
+ return tokenizer
+
+
+def check_content(input_ids_stripped: torch.Tensor,
+ tokenizer: PreTrainedTokenizer,
+ model: str):
+ if model == "opt":
+ # NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt.
+ assert input_ids_stripped[0] == tokenizer.eos_token_id
+ input_ids_stripped = input_ids_stripped[1:]
+ elif model == "llama":
+ assert input_ids_stripped[0] == tokenizer.bos_token_id
+ input_ids_stripped = input_ids_stripped[1:]
+
+ assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
+ assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
+ assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
+ assert input_ids_stripped != tokenizer.sep_token_id
+ assert input_ids_stripped != tokenizer.cls_token_id
+ assert input_ids_stripped != tokenizer.mask_token_id
+
+
+@pytest.mark.cpu
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
+@pytest.mark.parametrize("max_length", [32, 1024])
+@pytest.mark.parametrize("max_datasets_size", [2])
+def test_prompt_dataset(model: str,
+ max_datasets_size: int,
+ max_length: int):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dataset_name = "prompt_dataset.json"
+ with open(os.path.join(tmp_dir, dataset_name), "w") as f:
+ json.dump(PROMPT_DATASET, f)
+ tokenizer = make_tokenizer(model)
+ assert tokenizer.padding_side in ("left", "right")
+ prompt_dataset = PromptDataset(data_path=os.path.join(tmp_dir, dataset_name),
+ tokenizer=tokenizer,
+ max_datasets_size=max_datasets_size,
+ max_length=max_length)
+ assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET))
+ for i in range(len(prompt_dataset)):
+ assert isinstance(prompt_dataset[i], dict)
+ assert list(prompt_dataset[i].keys()) == ["input_ids", "attention_mask"]
+ input_ids = prompt_dataset[i]["input_ids"]
+ attention_mask = prompt_dataset[i]["attention_mask"]
+ attention_mask = attention_mask.bool()
+ assert input_ids.shape == attention_mask.shape == torch.Size([max_length])
+ assert torch.all(input_ids[torch.logical_not(attention_mask)] == tokenizer.pad_token_id)
+ check_content(input_ids.masked_select(attention_mask), tokenizer, model)
+
+
+@pytest.mark.cpu
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
+@pytest.mark.parametrize(["dataset_path", "subset"], [
+ ("Anthropic/hh-rlhf", "harmless-base"),
+ ("Dahoas/rm-static", None)
+])
+@pytest.mark.parametrize("max_datasets_size", [32])
+@pytest.mark.parametrize("max_length", [32, 1024])
+def test_reward_dataset(model: str,
+ dataset_path: str,
+ subset: Optional[str],
+ max_datasets_size: int,
+ max_length: int):
+ data = load_dataset(dataset_path, data_dir=subset)
+ assert max_datasets_size <= len(data["train"]) \
+ and max_datasets_size <= len(data["test"])
+ train_data = data["train"].select(range(max_datasets_size))
+ test_data = data["test"].select(range(max_datasets_size))
+ tokenizer = make_tokenizer(model)
+ assert tokenizer.padding_side in ("left", "right")
+
+ if dataset_path == "Anthropic/hh-rlhf":
+ train_dataset = HhRlhfDataset(train_data, tokenizer, max_length)
+ test_dataset = HhRlhfDataset(test_data, tokenizer, max_length)
+ elif dataset_path == "Dahoas/rm-static":
+ train_dataset = RmStaticDataset(train_data, tokenizer, max_length)
+ test_dataset = RmStaticDataset(test_data, tokenizer, max_length)
+ else:
+ raise ValueError(f'Unsupported dataset "{dataset_path}"')
+
+ assert len(train_dataset) == len(test_dataset) == max_datasets_size
+ for i in range(max_datasets_size):
+ chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i]
+ assert chosen_ids.shape == c_mask.shape == \
+ reject_ids.shape == r_mask.shape == torch.Size([max_length])
+ c_mask = c_mask.to(torch.bool)
+ r_mask = r_mask.to(torch.bool)
+ if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
+ check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
+ assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
+ assert torch.all(c_mask)
+ if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
+ check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
+ assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(reject_ids.masked_select(r_mask), tokenizer, model)
+ assert torch.all(r_mask)
+
+ chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i]
+ assert chosen_ids.shape == c_mask.shape == \
+ reject_ids.shape == r_mask.shape == torch.Size([max_length])
+ c_mask = c_mask.to(torch.bool)
+ r_mask = r_mask.to(torch.bool)
+ if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
+ check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
+ assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
+ assert torch.all(c_mask)
+ if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
+ check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
+ assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(reject_ids.masked_select(r_mask), tokenizer, model)
+ assert torch.all(r_mask)
+
+
+@pytest.mark.cpu
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
+@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
+@pytest.mark.parametrize("max_dataset_size", [2])
+@pytest.mark.parametrize("max_length", [32, 1024])
+def test_sft_dataset(model: str,
+ dataset_path: Optional[str],
+ max_dataset_size: int,
+ max_length: int):
+ tokenizer = make_tokenizer(model)
+ if dataset_path == "yizhongw/self_instruct":
+ data = load_dataset(dataset_path, "super_natural_instructions")
+ train_data = data["train"].select(range(max_dataset_size))
+ sft_dataset = SFTDataset(train_data, tokenizer, max_length)
+ else:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dataset_name = "sft_dataset.json"
+ with open(os.path.join(tmp_dir, dataset_name), "w") as f:
+ json.dump(SFT_DATASET, f)
+ sft_dataset = SupervisedDataset(tokenizer=tokenizer,
+ data_path=os.path.join(tmp_dir, dataset_name),
+ max_datasets_size=max_dataset_size,
+ max_length=max_length)
+ assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
+
+ for i in range(max_dataset_size):
+ assert isinstance(sft_dataset[i], dict)
+ assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
+ input_ids = sft_dataset[i]["input_ids"]
+ labels = sft_dataset[i]["labels"]
+ attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool)
+ assert input_ids.shape == labels.shape == \
+ attention_mask.shape == torch.Size([max_length])
+ if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id:
+ check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model)
+ assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(input_ids.masked_select(attention_mask), tokenizer, model)
+ assert torch.all(attention_mask)
+ ignore_mask = labels == IGNORE_INDEX
+ check_content(input_ids.masked_select(ignore_mask), tokenizer, model)
+
+
+if __name__ == "__main__":
+ test_sft_dataset(model="bloom",
+ dataset_path="yizhongw/self_instruct",
+ max_dataset_size=2,
+ max_length=256)
+
+ test_reward_dataset(model="gpt2",
+ dataset_path="Anthropic/hh-rlhf",
+ subset="harmless-base",
+ max_datasets_size=8,
+ max_length=256)
+
+ test_prompt_dataset(model="opt",
+ max_datasets_size=2,
+ max_length=128)
diff --git a/applications/Chat/tests/test_data.py b/applications/Chat/tests/test_experience.py
similarity index 82%
rename from applications/Chat/tests/test_data.py
rename to applications/Chat/tests/test_experience.py
index db641a6218b1..071e50b90e8e 100644
--- a/applications/Chat/tests/test_data.py
+++ b/applications/Chat/tests/test_experience.py
@@ -4,11 +4,12 @@
import pytest
import torch
import torch.distributed as dist
+from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
-from coati.replay_buffer import NaiveReplayBuffer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
+from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -32,13 +33,15 @@ def gather_and_equal(tensor: torch.Tensor) -> bool:
return True
-def run_test_data(strategy):
+def make_and_consume_experience(strategy):
EXPERIENCE_BATCH_SIZE = 4
SAMPLE_BATCH_SIZE = 2
if strategy == 'ddp':
strategy = DDPStrategy()
- elif strategy == 'colossalai':
+ elif strategy == 'colossalai-zero2':
+ strategy = LowLevelZeroStrategy()
+ elif strategy == 'colossalai-gemini':
strategy = GeminiStrategy(placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
@@ -50,7 +53,7 @@ def run_test_data(strategy):
reward_model = RewardModel(deepcopy(critic.model)).cuda()
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
- replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
+ data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
# experience of all ranks should be the same
for _ in range(2):
@@ -69,12 +72,12 @@ def run_test_data(strategy):
assert gather_and_equal(experience.advantages)
assert gather_and_equal(experience.action_mask)
assert gather_and_equal(experience.attention_mask)
- replay_buffer.append(experience)
+ data_buffer.append(experience)
- # replay buffer's data should be the same
- buffer_size = torch.tensor([len(replay_buffer)], device='cuda')
+ # data buffer's data should be the same
+ buffer_size = torch.tensor([len(data_buffer)], device='cuda')
assert gather_and_equal(buffer_size)
- for item in replay_buffer.items:
+ for item in data_buffer.items:
assert gather_and_equal(item.sequences)
assert gather_and_equal(item.action_log_probs)
assert gather_and_equal(item.values)
@@ -84,7 +87,7 @@ def run_test_data(strategy):
assert gather_and_equal(item.attention_mask)
# dataloader of each rank should have the same size and different batch
- dataloader = strategy.setup_dataloader(replay_buffer)
+ dataloader = strategy.setup_dataloader(data_buffer)
dataloader_size = torch.tensor([len(dataloader)], device='cuda')
assert gather_and_equal(dataloader_size)
for experience in dataloader:
@@ -102,17 +105,16 @@ def run_dist(rank, world_size, port, strategy):
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str(port)
- run_test_data(strategy)
+ make_and_consume_experience(strategy)
-@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
-@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
+@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini'])
@rerun_if_address_is_in_use()
-def test_data(world_size, strategy):
+def test_experience(world_size, strategy):
spawn(run_dist, world_size, strategy=strategy)
if __name__ == '__main__':
- test_data(2, 'colossalai')
+    test_experience(2, 'colossalai-gemini')
diff --git a/applications/Chat/tests/test_inference.sh b/applications/Chat/tests/test_inference.sh
new file mode 100755
index 000000000000..849db06e58ab
--- /dev/null
+++ b/applications/Chat/tests/test_inference.sh
@@ -0,0 +1,11 @@
+set -xue
+
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+EXAMPLES_DIR=$BASE_DIR/examples
+
+echo "[Test]: testing inference ..."
+
+# HACK: skip llama due to oom
+for model in 'gpt2' 'bloom' 'opt'; do
+ python $EXAMPLES_DIR/inference.py --model $model
+done
diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py
new file mode 100644
index 000000000000..bd6b3e8a5ad1
--- /dev/null
+++ b/applications/Chat/tests/test_models.py
@@ -0,0 +1,235 @@
+import copy
+from typing import Any, Callable, Dict, Tuple
+
+import pytest
+import torch
+import torch.nn as nn
+from coati.models.base import Actor, Critic, RewardModel, get_base_model
+from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.generation import generate
+from coati.models.gpt import GPTRM, GPTActor, GPTCritic
+from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+from coati.models.lora import LoraLinear, convert_to_lora_module
+from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
+from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seq_len", [32])
+@pytest.mark.parametrize("actor_maker", [
+ lambda: BLOOMActor(),
+ lambda: GPTActor(),
+ # HACK: skip llama due to long execution time
+ # lambda: LlamaActor(),
+ lambda: OPTActor()
+])
+@pytest.mark.parametrize("generate_kwargs", [{
+ "max_length": 64,
+ "use_cache": True,
+ "do_sample": True,
+ "temperature": 1.0,
+ "top_k": 50,
+}])
+def test_generation(actor_maker: Callable[[], Actor],
+ batch_size: int,
+ seq_len: int,
+ generate_kwargs: Dict[str, Any]
+ ):
+ actor = actor_maker()
+ input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
+ sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
+ assert sequences.shape == (batch_size, generate_kwargs["max_length"])
+
+
+@pytest.mark.cpu
+def test_utils():
+ fn_input = {
+ "tensor": torch.ones((10, )),
+ "mask": torch.randint(0, 2, (10, ))
+ }
+ fn_output = masked_mean(dim=0, **fn_input)
+ assert fn_output.dim() == 0
+ assert torch.allclose(fn_output, torch.tensor(1.0))
+
+ batch_size = 4
+ num_labels = 10
+ fn_input = {
+ "r": torch.ones((batch_size, )),
+ "kl_coef": 1.0,
+ "log_probs": torch.randn((batch_size, num_labels)),
+ "log_probs_base": torch.randn((batch_size, num_labels)),
+ "action_mask": torch.randint(0, 2, (batch_size, num_labels))
+ }
+ fn_output = compute_reward(**fn_input)
+ assert fn_output.shape == (batch_size, )
+
+ batch_size = 4
+ seq_len = 32
+ num_labels = 10
+ num_actions = 2
+ fn_input = {
+ "output": {
+ "logits": torch.randn((batch_size, seq_len, num_labels))
+ },
+ "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
+ "num_actions": num_actions,
+ }
+ fn_output = calc_action_log_probs(**fn_input)
+ assert fn_output.shape == (batch_size, num_actions)
+
+
+@pytest.mark.cpu
+@pytest.mark.parametrize("lora_rank", [4])
+@pytest.mark.parametrize("num_dim", [32])
+@pytest.mark.parametrize("num_layers", [4])
+def test_lora(lora_rank: int,
+ num_dim: int,
+ num_layers: int):
+ model = nn.ModuleList(
+ [nn.Linear(num_dim, num_dim)
+ for _ in range(num_layers)]
+ )
+ lora_model = convert_to_lora_module(model, lora_rank)
+ assert isinstance(lora_model, nn.ModuleList)
+ for i in range(num_layers):
+ assert isinstance(lora_model[i], LoraLinear)
+ assert lora_model[i].lora_A.shape == (lora_rank, num_dim)
+ assert lora_model[i].lora_B.shape == (num_dim, lora_rank)
+
+ old_model = copy.deepcopy(lora_model)
+ for i in range(num_layers):
+ assert isinstance(lora_model[i], LoraLinear)
+ assert torch.allclose(old_model[i].weight, lora_model[i].weight)
+ assert torch.allclose(old_model[i].bias, lora_model[i].bias)
+ assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
+ lora_model[i].lora_B @ lora_model[i].lora_A)
+ optimizer = torch.optim.Adam(lora_model.parameters())
+ x = torch.randn(8, num_dim)
+ for i in range(num_layers):
+ x = lora_model[i](x)
+ loss = x.sum()
+ loss.backward()
+ optimizer.step()
+ for i in range(num_layers):
+ assert isinstance(lora_model[i], LoraLinear)
+ assert torch.allclose(old_model[i].weight, lora_model[i].weight)
+ assert torch.allclose(old_model[i].bias, lora_model[i].bias)
+ assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
+ lora_model[i].lora_B @ lora_model[i].lora_A)
+
+
+@pytest.mark.cpu
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seq_len", [128])
+@pytest.mark.parametrize("models_maker", [
+ lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
+ lambda: (GPTActor(), GPTCritic(), GPTRM()),
+ # HACK: skip llama due to long execution time
+ # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
+ lambda: (OPTActor(), OPTCritic(), OPTRM()),
+])
+@torch.no_grad()
+def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
+ batch_size: int,
+ seq_len: int):
+
+ actor_input = {
+ "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
+ }
+ critic_input = {
+ "sequences": torch.randint(0, 100, (batch_size, seq_len)),
+ "action_mask": torch.randint(0, 2, (batch_size, seq_len)),
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
+ }
+ rm_input = {
+ "sequences": torch.randint(0, 100, (batch_size, seq_len)),
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
+ }
+
+ actor, critic, rm = models_maker()
+ assert isinstance(actor, Actor)
+ base_actor_model = get_base_model(actor)
+ assert isinstance(critic, Critic)
+ base_critic_model = get_base_model(critic)
+ assert isinstance(rm, RewardModel)
+ base_rm_model = get_base_model(rm)
+
+ actor_output = actor(**actor_input)
+ critic_output = critic(**critic_input)
+ rm_output = rm(**rm_input)
+
+ assert actor_output.logits.shape[:2] == (batch_size, seq_len)
+ assert critic_output.shape == (batch_size, )
+ assert rm_output.shape == (batch_size, )
+
+
+@pytest.mark.cpu
+@pytest.mark.parametrize("batch_size", [16])
+@pytest.mark.parametrize("seq_len", [128])
+@pytest.mark.parametrize("num_labels", [100])
+def test_loss(batch_size: int,
+ seq_len: int,
+ num_labels: int):
+ loss = GPTLMLoss()
+ loss_input = {
+ "logits": torch.randn(batch_size, seq_len, num_labels),
+ "labels": torch.randint(0, num_labels, (batch_size, seq_len))
+ }
+ loss_output = loss(**loss_input)
+
+ loss = PolicyLoss()
+ loss_input = {
+ "log_probs": torch.randn(batch_size, ),
+ "old_log_probs": torch.randn(batch_size, ),
+ "advantages": torch.randn(batch_size, )
+ }
+ loss_output = loss(**loss_input)
+
+ loss = ValueLoss()
+ loss_input = {
+ "values": torch.randn(batch_size, ),
+ "old_values": torch.randn(batch_size, ),
+ "reward": torch.randn(batch_size, )
+ }
+ loss_output = loss(**loss_input)
+
+ loss = LogSigLoss()
+ loss_input = {
+ "chosen_reward": torch.randn(batch_size, ),
+ "reject_reward": torch.randn(batch_size, ),
+ }
+ loss_output = loss(**loss_input)
+
+ loss = LogExpLoss()
+ loss_input = {
+ "chosen_reward": torch.randn(batch_size, ),
+ "reject_reward": torch.randn(batch_size, ),
+ }
+ loss_output = loss(**loss_input)
+
+
+if __name__ == "__main__":
+ generate_kwargs = dict(max_length=40,
+ use_cache=True,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50)
+ test_generation(lambda: LlamaActor(),
+ batch_size=4,
+ seq_len=32,
+ generate_kwargs=generate_kwargs)
+
+ test_utils()
+
+ test_lora(lora_rank=2, num_dim=8, num_layers=2)
+
+ test_models(models_maker=lambda: (BLOOMActor(),
+ BLOOMCritic(),
+ BLOOMRM()),
+ batch_size=8,
+ seq_len=128)
+
+ test_loss(batch_size=8, seq_len=128, num_labels=100)
diff --git a/applications/Chat/tests/test_train.sh b/applications/Chat/tests/test_train.sh
new file mode 100755
index 000000000000..c5127c188612
--- /dev/null
+++ b/applications/Chat/tests/test_train.sh
@@ -0,0 +1,228 @@
+#!/usr/bin/env bash
+
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+    # Select the $1 GPUs with the least used memory and expose only them via CUDA_VISIBLE_DEVICES.
+    local n=${1:-"9999"}
+    echo "GPU Memory Usage:"
+    # nvidia-smi reports per-GPU memory usage; number the rows from 0, sort by usage, keep the first n indices.
+    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+        tail -n +2 |
+        nl -v 0 |
+        tee /dev/tty |
+        sort -g -k 2 |
+        awk '{print $1}' |
+        head -n $n)
+    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+    echo "Now CUDA_VISIBLE_DEVICES is set to:"
+    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
+set -xu
+
+if [ -z "$SFT_DATASET" ]; then
+ echo "Please set \$SFT_DATASET to the path to sft dataset."
+ exit 1
+fi
+
+if [ -z "$PROMPT_PATH" ]; then
+ echo "Please set \$PROMPT_PATH to the path to prompts csv."
+ exit 1
+fi
+
+if [ -z "$PRETRAIN_DATASET" ]; then
+ echo "Please set \$PRETRAIN_DATASET to the path to alpaca data."
+ exit 1
+fi
+
+NUM_RETRY=3
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+EXAMPLES_DIR=$BASE_DIR/examples
+MODELS_DIR=$BASE_DIR/examples/models_config
+MODELS=('gpt2' 'bloom' 'opt' 'llama')
+STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')
+
+export OMP_NUM_THREADS=8
+
+# install requirements
+pip install -r $EXAMPLES_DIR/requirements.txt
+
+python $EXAMPLES_DIR/download_model.py --model-dir $MODELS_DIR --config-only
+
+get_pretrain() {
+ local model=$1
+ if [[ $model == "gpt2" ]]; then
+ echo "gpt2"
+ elif [[ $model == "bloom" ]]; then
+ echo "bigscience/bloom-560m"
+ elif [[ $model == "opt" ]]; then
+ echo "facebook/opt-350m"
+ else
+ echo "Unknown model $model"
+ exit 1
+ fi
+}
+
+random_choice() {
+ local arr=("$@")
+ local len=${#arr[@]}
+ local idx=$((RANDOM % len))
+ echo ${arr[$idx]}
+}
+
+echo "[Test]: testing sft ..."
+
+# FIXME: This is a hack to skip tests that are not working
+# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+# - llama-*: these tests pass locally but are skipped here due to long execution time
+SKIPPED_TESTS=(
+ "gpt2-ddp"
+ "llama-ddp"
+ "llama-colossalai_gemini"
+ "llama-colossalai_zero2"
+)
+
+GRAD_CKPTS=('' '--grad_checkpoint')
+for lora_rank in '0' '4'; do
+ for model in ${MODELS[@]}; do
+ strategies=($(shuf -e "${STRATEGIES[@]}"))
+ for strategy in ${strategies[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$strategy-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
+ echo "[Test]: Skipped $model-$strategy"
+ continue
+ fi
+ pretrain=$(get_pretrain $model)
+ pretrain_model=""
+ if [[ $lora_rank -gt 0 ]]; then
+ pretrain_model="--pretrain $pretrain"
+ fi
+ grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
+ torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_sft.py \
+ $pretrain_model --tokenizer $MODELS_DIR/$model \
+ --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
+ --dataset $SFT_DATASET --max_datasets_size 8 \
+ --max_epochs 1 --batch_size 1 --accumulation_steps 1 \
+ --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$strategy-$lora_rank"
+ exit 1
+ fi
+ done
+ done
+done
+
+echo "[Test]: testing reward model ..."
+
+# FIXME: This is a hack to skip tests that are not working
+# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+# - llama-*: these tests pass locally but are skipped here due to long execution time
+SKIPPED_TESTS=(
+ "gpt2-ddp"
+ "llama-ddp"
+ "llama-colossalai_gemini"
+ "llama-colossalai_zero2"
+)
+
+LOSS_FNS=('log_sig' 'log_exp')
+DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
+for lora_rank in '0' '4'; do
+ for model in ${MODELS[@]}; do
+ strategies=($(shuf -e "${STRATEGIES[@]}"))
+ for strategy in ${strategies[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$strategy-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
+ echo "[Test]: Skipped $model-$strategy"
+ continue
+ fi
+ pretrain=$(get_pretrain $model)
+ pretrain_model=""
+ if [[ $lora_rank -gt 0 ]]; then
+ pretrain_model="--pretrain $pretrain"
+ fi
+ loss_fn=$(random_choice "${LOSS_FNS[@]}")
+ dataset=$(random_choice "${DATASETS[@]}")
+ subset=$(if [[ $dataset == "Dahoas/rm-static" ]]; then echo "None"; else echo "harmless-base"; fi)
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
+ torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
+ $pretrain_model --tokenizer $MODELS_DIR/$model \
+ --model $model --strategy $strategy --lora_rank $lora_rank --loss_fn $loss_fn \
+ --dataset $dataset --subset $subset --test True --batch_size 1 \
+ --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed to train reward model $model-$strategy-$lora_rank"
+ exit 1
+ fi
+ done
+ done
+done
+
+echo "[Test]: testing RLHF ..."
+
+# FIXME: This is a hack to skip tests that are not working
+# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+# - llama-*: these tests pass locally but are skipped here due to long execution time
+SKIPPED_TESTS=(
+ "gpt2-ddp"
+ "llama-ddp"
+ "llama-colossalai_gemini"
+ "llama-colossalai_zero2"
+)
+
+for model in ${MODELS[@]}; do
+ for lora_rank in '0' '4'; do
+ strategies=($(shuf -e "${STRATEGIES[@]}"))
+ for strategy in ${strategies[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$strategy-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
+ echo "[Test]: Skipped $model-$strategy"
+ continue
+ fi
+ rm_pretrain=$(get_pretrain $model)
+ rm_pretrain_model=""
+ if [[ $lora_rank -gt 0 ]]; then
+ rm_pretrain_model="--rm_pretrain $rm_pretrain"
+ fi
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
+ torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
+ --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+ --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
+ --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 \
+ --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
+ --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
+ $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
+ --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed to train RLHF $model-$strategy-$lora_rank"
+ exit 1
+ fi
+ done
+ rm -rf $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
+ rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
+ done
+done
+rm $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
From 25c57b9fb44e3499cb9e82bb461c3aa5a2d81a2a Mon Sep 17 00:00:00 2001
From: flybird1111 <1829166702@qq.com>
Date: Fri, 4 Aug 2023 13:46:22 +0800
Subject: [PATCH 47/64] [fix] coloattention support flash attention 2 (#4347)
Improved the ColoAttention interface to support FlashAttention 2. Solves #4322.
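For orientation, a minimal sketch of the backend-dispatch idea (illustrative only: the class
name SimpleColoAttention is made up here, and the real implementation with mask, bias and
padding handling lives in colossalai/kernel/cuda_native/mha/mha.py): prefer the
FlashAttention-2 kernel when the flash-attn package is available, otherwise fall back to the
xformers memory-efficient kernel.

    # Sketch only -- not the actual ColoAttention code.
    import math

    import torch

    try:
        from flash_attn import flash_attn_func    # FlashAttention 2
        HAS_FLASH_ATTN = True
    except ImportError:
        HAS_FLASH_ATTN = False

    try:
        from xformers.ops import memory_efficient_attention
        from xformers.ops.fmha.attn_bias import LowerTriangularMask
        HAS_MEM_EFF_ATTN = True
    except ImportError:
        HAS_MEM_EFF_ATTN = False


    class SimpleColoAttention(torch.nn.Module):
        """Dispatch q/k/v of shape (batch, seq_len, num_heads, head_dim) to an available kernel."""

        def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
            super().__init__()
            assert embed_dim % num_heads == 0
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
            self.dropout = dropout

        def forward(self, q, k, v, causal: bool = False) -> torch.Tensor:
            if HAS_FLASH_ATTN and q.dtype in (torch.float16, torch.bfloat16):
                # FlashAttention 2 kernels require fp16/bf16 inputs
                out = flash_attn_func(q, k, v, dropout_p=self.dropout,
                                      softmax_scale=self.scale, causal=causal)
            elif HAS_MEM_EFF_ATTN:
                attn_bias = LowerTriangularMask() if causal else None
                out = memory_efficient_attention(q, k, v, attn_bias=attn_bias,
                                                 p=self.dropout, scale=self.scale)
            else:
                raise RuntimeError("install flash-attn>=2.0 or xformers to run this sketch")
            # merge heads: (b, s, h, d) -> (b, s, h * d)
            return out.flatten(start_dim=2)

A causal call would then look like SimpleColoAttention(1024, 16)(q, k, v, causal=True) with
half-precision CUDA tensors of shape (batch, seq_len, 16, 64).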
---
colossalai/kernel/cuda_native/__init__.py | 5 +-
.../kernel/cuda_native/flash_attention.py | 635 ------------------
.../kernel/cuda_native/mha/flash_attn_2.py | 68 ++
.../kernel/cuda_native/mha/mem_eff_attn.py | 70 ++
colossalai/kernel/cuda_native/mha/mha.py | 107 +++
colossalai/kernel/cuda_native/mha/utils.py | 82 +++
.../kernel/cuda_native/scaled_softmax.py | 5 +-
tests/test_utils/test_flash_attention.py | 49 +-
8 files changed, 367 insertions(+), 654 deletions(-)
delete mode 100644 colossalai/kernel/cuda_native/flash_attention.py
create mode 100644 colossalai/kernel/cuda_native/mha/flash_attn_2.py
create mode 100644 colossalai/kernel/cuda_native/mha/mem_eff_attn.py
create mode 100644 colossalai/kernel/cuda_native/mha/mha.py
create mode 100644 colossalai/kernel/cuda_native/mha/utils.py
diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index 1d5a6ce495bd..4910717b5723 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -1,5 +1,8 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
+from .mha.mha import ColoAttention
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
-__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax']
+__all__ = [
+ 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention'
+]
diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
deleted file mode 100644
index 3db7374509a0..000000000000
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ /dev/null
@@ -1,635 +0,0 @@
-"""
-A general attention module using the flash attention kernels from xformers:
-https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
-"""
-
-import math
-import os
-import subprocess
-
-import torch
-
-try:
- from xformers.ops.fmha import memory_efficient_attention
- HAS_MEM_EFF_ATTN = True
-except ImportError:
- HAS_MEM_EFF_ATTN = False
- print('please install xformers from https://github.com/facebookresearch/xformers')
-
-if HAS_MEM_EFF_ATTN:
-
- from typing import Optional
-
- from einops import rearrange
- from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias
-
- from .scaled_softmax import AttnMaskType
-
- allow_alibi = True
- for op in MemoryEfficientAttentionCutlassOp:
- allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
-
- class Unpad(torch.autograd.Function):
- """
- Adapted from
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
- """
-
- @staticmethod
- def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
- ctx.save_for_backward(indices)
- # [b, s, ...]
- assert tensor.ndim >= 3
- ctx.bsz = tensor.shape[0]
- out = rearrange(tensor, 'b s ... -> (b s) ...')
- ctx.shape = out.shape
- # [1, ntokens, ...]
- return out[indices].unsqueeze(0)
-
- @staticmethod
- def backward(ctx, grad_output):
- indices, = ctx.saved_tensors
- # [b*s, ...]
- grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
- grad[indices] = grad_output.squeeze(0)
- grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
- # [b, s, ...]
- return grad, None
-
- class Repad(torch.autograd.Function):
- """
- Adapted from
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
- """
-
- @staticmethod
- def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
- ctx.save_for_backward(indices)
- # [ntokens, ...]
- tensor = tensor.squeeze(0)
- out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
- # [b*s, ...]
- out[indices] = tensor
- # [b, s, ...]
- out = rearrange(out, '(b s) ... -> b s ...', b=batch_size)
- return out
-
- @staticmethod
- def backward(ctx, grad_output):
- indices, = ctx.saved_tensors
- # [b*s, ...]
- grad_output = rearrange(grad_output, 'b s ... -> (b s) ...')
- grad = grad_output[indices]
- # [1, ntokens, ...]
- return grad.unsqueeze(0), None, None, None
-
- class ColoAttention(torch.nn.Module):
-
- def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
- super().__init__()
- assert embed_dim % num_heads == 0, \
- f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
- self.scale = 1 / math.sqrt(embed_dim // num_heads)
- self.dropout = dropout
-
- @staticmethod
- def get_seq_info_from_mask(attn_mask: torch.Tensor):
- indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
- seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist()
- return indices, seqlens
-
- @staticmethod
- def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
- return Unpad.apply(tensor, indices)
-
- @staticmethod
- def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
- return Repad.apply(tensor, indices, batch_size, seq_len)
-
- def forward(self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- attn_mask_type: Optional[AttnMaskType] = None,
- bias: Optional[torch.Tensor] = None):
- batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
- attn_bias = None
- if attn_mask_type == AttnMaskType.padding: # bert style
- assert attn_mask is not None, \
- f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
- assert attn_mask.dim() == 2, \
- "attention mask is supposed to have shape (batch_size, seq_len), " + \
- f"but got {attn_mask.dim()} dimensions."
- if tgt_len == src_len:
- q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask)
- kv_seqlen = None
- if batch_size > 1:
- query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2)
- else:
- q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device)
- q_seqlen = torch.LongTensor([tgt_len] * batch_size, device=query.device)
- kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask)
- if batch_size > 1:
- query = rearrange(query, "b s ... -> c (b s) ...", c=1)
- key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
- attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
- elif attn_mask_type == AttnMaskType.causal: # gpt style
- attn_bias = LowerTriangularMask()
-
- if bias is not None: # alibi / relative position embedding
- assert allow_alibi, "flash attention with bias is not supported in this system."
- assert attn_mask_type == AttnMaskType.causal, \
- "attention with bias is only supported for causal attention so far."
- attn_bias = attn_bias.add_bias(bias)
-
- out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)
-
- if attn_mask_type == AttnMaskType.padding and batch_size > 1:
- out = self.repad(out, q_indices, batch_size, tgt_len)
-
- out = rearrange(out, 'b s h d -> b s (h d)')
- return out
-
-
-##########################################################################
-# the flash attention functions below that are copied
-# from the OpenAI/triton repository will be deprecated
-# You can find the repository in Triton https://github.com/openai/triton
-# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
-# Reference:
-# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
-# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
-
-
-def triton_cuda_check():
- cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
- cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
- cuda_version = cuda_version.split('release ')[1]
- cuda_version = cuda_version.split(',')[0]
- cuda_version = cuda_version.split('.')
- if len(cuda_version) == 2 and \
- (int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \
- int(cuda_version[0]) > 11:
- return True
- return False
-
-
-try:
- import triton
- import triton.language as tl
- if triton_cuda_check():
- HAS_TRITON = True
- else:
- print("triton requires cuda >= 11.4")
- HAS_TRITON = False
-except ImportError:
- print('please install triton from https://github.com/openai/triton')
- HAS_TRITON = False
-try:
- from flash_attn.flash_attention import FlashAttention
- from flash_attn.flash_attn_interface import (
- flash_attn_unpadded_func,
- flash_attn_unpadded_kvpacked_func,
- flash_attn_unpadded_qkvpacked_func,
- )
- HAS_FLASH_ATTN = True
-except ImportError:
- HAS_FLASH_ATTN = False
- print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-
-if HAS_TRITON:
- # the following functions are adapted from the OpenAI Triton tutorial
- # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
- @triton.jit
- def _fwd_kernel(
- Q,
- K,
- V,
- sm_scale,
- TMP,
- L,
- M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
- Out,
- stride_qz,
- stride_qh,
- stride_qm,
- stride_qk,
- stride_kz,
- stride_kh,
- stride_kn,
- stride_kk,
- stride_vz,
- stride_vh,
- stride_vk,
- stride_vn,
- stride_oz,
- stride_oh,
- stride_om,
- stride_on,
- Z,
- H,
- N_CTX,
- BLOCK_M: tl.constexpr,
- BLOCK_DMODEL: tl.constexpr,
- BLOCK_N: tl.constexpr,
- ):
- start_m = tl.program_id(0)
- off_hz = tl.program_id(1)
- # initialize offsets
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = tl.arange(0, BLOCK_N)
- offs_d = tl.arange(0, BLOCK_DMODEL)
- off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
- off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
- off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
- # Initialize pointers to Q, K, V
- q_ptrs = Q + off_q
- k_ptrs = K + off_k
- v_ptrs = V + off_v
- # initialize pointer to m and l
- t_ptrs = TMP + off_hz * N_CTX + offs_m
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
- l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
- acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
- # load q: it will stay in SRAM throughout
- q = tl.load(q_ptrs)
- # loop over k, v and update accumulator
- for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
- start_n = tl.multiple_of(start_n, BLOCK_N)
- # -- compute qk ----
- k = tl.load(k_ptrs + start_n * stride_kn)
- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
- qk += tl.dot(q, k, trans_b=True)
- qk *= sm_scale
- qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
- # -- compute m_ij, p, l_ij
- m_ij = tl.max(qk, 1)
- p = tl.exp(qk - m_ij[:, None])
- l_ij = tl.sum(p, 1)
- # -- update m_i and l_i
- m_i_new = tl.maximum(m_i, m_ij)
- alpha = tl.exp(m_i - m_i_new)
- beta = tl.exp(m_ij - m_i_new)
- l_i_new = alpha * l_i + beta * l_ij
- # -- update output accumulator --
- # scale p
- p_scale = beta / l_i_new
- p = p * p_scale[:, None]
- # scale acc
- acc_scale = l_i / l_i_new * alpha
- tl.store(t_ptrs, acc_scale)
- acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
- acc = acc * acc_scale[:, None]
- # update acc
- v = tl.load(v_ptrs + start_n * stride_vk)
- p = p.to(tl.float16)
- acc += tl.dot(p, v)
- # update m_i and l_i
- l_i = l_i_new
- m_i = m_i_new
- # rematerialize offsets to save registers
- start_m = tl.program_id(0)
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- # write back l and m
- l_ptrs = L + off_hz * N_CTX + offs_m
- m_ptrs = M + off_hz * N_CTX + offs_m
- tl.store(l_ptrs, l_i)
- tl.store(m_ptrs, m_i)
- # initialize pointers to output
- offs_n = tl.arange(0, BLOCK_DMODEL)
- off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
- out_ptrs = Out + off_o
- tl.store(out_ptrs, acc)
-
- @triton.jit
- def _bwd_preprocess(
- Out,
- DO,
- L,
- NewDO,
- Delta,
- BLOCK_M: tl.constexpr,
- D_HEAD: tl.constexpr,
- ):
- off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
- off_n = tl.arange(0, D_HEAD)
- # load
- o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
- do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
- denom = tl.load(L + off_m).to(tl.float32)
- # compute
- do = do / denom[:, None]
- delta = tl.sum(o * do, axis=1)
- # write-back
- tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
- tl.store(Delta + off_m, delta)
-
- @triton.jit
- def _bwd_kernel(
- Q,
- K,
- V,
- sm_scale,
- Out,
- DO,
- DQ,
- DK,
- DV,
- L,
- M,
- D,
- stride_qz,
- stride_qh,
- stride_qm,
- stride_qk,
- stride_kz,
- stride_kh,
- stride_kn,
- stride_kk,
- stride_vz,
- stride_vh,
- stride_vk,
- stride_vn,
- Z,
- H,
- N_CTX,
- num_block,
- BLOCK_M: tl.constexpr,
- BLOCK_DMODEL: tl.constexpr,
- BLOCK_N: tl.constexpr,
- ):
- off_hz = tl.program_id(0)
- off_z = off_hz // H
- off_h = off_hz % H
- # offset pointers for batch/head
- Q += off_z * stride_qz + off_h * stride_qh
- K += off_z * stride_qz + off_h * stride_qh
- V += off_z * stride_qz + off_h * stride_qh
- DO += off_z * stride_qz + off_h * stride_qh
- DQ += off_z * stride_qz + off_h * stride_qh
- DK += off_z * stride_qz + off_h * stride_qh
- DV += off_z * stride_qz + off_h * stride_qh
- for start_n in range(0, num_block):
- lo = start_n * BLOCK_M
- # initialize row/col offsets
- offs_qm = lo + tl.arange(0, BLOCK_M)
- offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_m = tl.arange(0, BLOCK_N)
- offs_k = tl.arange(0, BLOCK_DMODEL)
- # initialize pointers to value-like data
- q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
- v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- # pointer to row-wise quantities in value-like data
- D_ptrs = D + off_hz * N_CTX
- m_ptrs = M + off_hz * N_CTX
- # initialize dv amd dk
- dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
- dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
- # k and v stay in SRAM throughout
- k = tl.load(k_ptrs)
- v = tl.load(v_ptrs)
- # loop over rows
- for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
- offs_m_curr = start_m + offs_m
- # load q, k, v, do on-chip
- q = tl.load(q_ptrs)
- # recompute p = softmax(qk, dim=-1).T
- # NOTE: `do` is pre-divided by `l`; no normalization here
- qk = tl.dot(q, k, trans_b=True)
- qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
- m = tl.load(m_ptrs + offs_m_curr)
- p = tl.exp(qk * sm_scale - m[:, None])
- # compute dv
- do = tl.load(do_ptrs)
- dv += tl.dot(p.to(tl.float16), do, trans_a=True)
- # compute dp = dot(v, do)
- Di = tl.load(D_ptrs + offs_m_curr)
- dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
- dp += tl.dot(do, v, trans_b=True)
- # compute ds = p * (dp - delta[:, None])
- ds = p * dp * sm_scale
- # compute dk = dot(ds.T, q)
- dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
- # # compute dq
- dq = tl.load(dq_ptrs, eviction_policy="evict_last")
- dq += tl.dot(ds.to(tl.float16), k)
- tl.store(dq_ptrs, dq, eviction_policy="evict_last")
- # # increment pointers
- dq_ptrs += BLOCK_M * stride_qm
- q_ptrs += BLOCK_M * stride_qm
- do_ptrs += BLOCK_M * stride_qm
- # write-back
- dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
- tl.store(dv_ptrs, dv)
- tl.store(dk_ptrs, dk)
-
- class _TritonFlashAttention(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, q, k, v, sm_scale):
- BLOCK = 128
- # shape constraints
- Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
- assert Lq == Lk and Lk == Lv
- assert Lk in {16, 32, 64, 128}
- o = torch.empty_like(q)
- grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
- tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
- L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
- m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
- num_warps = 4 if Lk <= 64 else 8
-
- _fwd_kernel[grid](
- q,
- k,
- v,
- sm_scale,
- tmp,
- L,
- m,
- o,
- q.stride(0),
- q.stride(1),
- q.stride(2),
- q.stride(3),
- k.stride(0),
- k.stride(1),
- k.stride(2),
- k.stride(3),
- v.stride(0),
- v.stride(1),
- v.stride(2),
- v.stride(3),
- o.stride(0),
- o.stride(1),
- o.stride(2),
- o.stride(3),
- q.shape[0],
- q.shape[1],
- q.shape[2],
- BLOCK_M=BLOCK,
- BLOCK_N=BLOCK,
- BLOCK_DMODEL=Lk,
- num_warps=num_warps,
- num_stages=1,
- )
- ctx.save_for_backward(q, k, v, o, L, m)
- ctx.BLOCK = BLOCK
- ctx.grid = grid
- ctx.sm_scale = sm_scale
- ctx.BLOCK_DMODEL = Lk
- return o
-
- @staticmethod
- def backward(ctx, do):
- q, k, v, o, l, m = ctx.saved_tensors
- do = do.contiguous()
- dq = torch.zeros_like(q, dtype=torch.float32)
- dk = torch.empty_like(k)
- dv = torch.empty_like(v)
- do_scaled = torch.empty_like(do)
- delta = torch.empty_like(l)
- _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
- o,
- do,
- l,
- do_scaled,
- delta,
- BLOCK_M=ctx.BLOCK,
- D_HEAD=ctx.BLOCK_DMODEL,
- )
-
- # NOTE: kernel currently buggy for other values of `num_warps`
- num_warps = 8
- _bwd_kernel[(ctx.grid[1],)](
- q,
- k,
- v,
- ctx.sm_scale,
- o,
- do_scaled,
- dq,
- dk,
- dv,
- l,
- m,
- delta,
- q.stride(0),
- q.stride(1),
- q.stride(2),
- q.stride(3),
- k.stride(0),
- k.stride(1),
- k.stride(2),
- k.stride(3),
- v.stride(0),
- v.stride(1),
- v.stride(2),
- v.stride(3),
- q.shape[0],
- q.shape[1],
- q.shape[2],
- ctx.grid[0],
- BLOCK_M=ctx.BLOCK,
- BLOCK_N=ctx.BLOCK,
- BLOCK_DMODEL=ctx.BLOCK_DMODEL,
- num_warps=num_warps,
- num_stages=1,
- )
- return dq, dk, dv, None
-
- def triton_flash_attention(q, k, v, sm_scale):
- """
- Arguments:
- q: (batch, nheads, seq, headdim)
- k: (batch, nheads, seq, headdim)
- v: (batch, nheads, seq, headdim)
- sm_scale: float. The scaling of QK^T before applying softmax.
- Return:
- out: (batch, nheads, seq, headdim)
- """
- if HAS_TRITON:
- return _TritonFlashAttention.apply(q, k, v, sm_scale)
- else:
- raise RuntimeError("Triton kernel requires CUDA 11.4+!")
-
-
-if HAS_FLASH_ATTN:
-
- def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
- """
- Arguments:
- qkv: (batch * seqlen, 3, nheads, headdim)
- batch_size: int.
- seq_len: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Default to 1 / sqrt(headdim).
- dropout_p: float.
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
- Return:
- out: (total, nheads, headdim).
- """
- max_s = seq_len
- cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
- out = flash_attn_unpadded_qkvpacked_func(qkv,
- cu_seqlens,
- max_s,
- dropout_p,
- softmax_scale=sm_scale,
- causal=causal)
- return out
-
- def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
- """
- Arguments:
- q: (batch * q_seqlen, nheads, headdim)
- kv: (batch * kv_seqlen, 2, nheads, headdim)
- batch_size: int.
- seq_len: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Default to 1 / sqrt(headdim).
- dropout_p: float.
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
- Return:
- out: (total, nheads, headdim).
- """
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen,
- step=kv_seqlen,
- dtype=torch.int32,
- device=kv.device)
- out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
- sm_scale, causal)
- return out
-
- def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
- """
- Arguments:
- q: (batch * q_seqlen, nheads, headdim)
- k: (batch * kv_seqlen, nheads, headdim)
- v: (batch * kv_seqlen, nheads, headdim)
- batch_size: int.
- seq_len: int.
- dropout_p: float. Dropout probability.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Default to 1 / sqrt(headdim).
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
- Return:
- out: (total, nheads, headdim).
- """
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
- cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen,
- step=kv_seqlen,
- dtype=torch.int32,
- device=k.device)
- return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
- causal)
-
-
-##########################################################################
diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
new file mode 100644
index 000000000000..6a8d74f70c1d
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
@@ -0,0 +1,68 @@
+import warnings
+from typing import Optional
+
+import torch
+
+
+def is_ampere_or_better_gpu():
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ properties = torch.cuda.get_device_properties(device)
+ if properties.major >= 8: # Ampere GPUs or newer
+ return True
+ return False
+
+
+# "Check Ampere GPUs or newer"
+HAS_FLASH_ATTN = False
+if is_ampere_or_better_gpu():
+ HAS_FLASH_ATTN = True
+else:
+ warnings.warn('FlashAttention only supports Ampere GPUs or newer.')
+ HAS_FLASH_ATTN = False
+try:
+ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+ HAS_FLASH_ATTN = True
+except ImportError:
+ warnings.warn('please install flash_attn from https://github.com/HazyResearch/flash-attention')
+ HAS_FLASH_ATTN = False
+
+if HAS_FLASH_ATTN:
+ from einops import rearrange
+
+ from .utils import SeqLenInfo
+
+ def flash_attention(q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ seq_len_info_q: SeqLenInfo,
+ seq_len_info_kv: SeqLenInfo,
+ bias: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.,
+ scale: float = None,
+ causal: bool = False,
+ padded: bool = False):
+ """
+ Arguments:
+ q: (batch, q_seqlen, nheads, headdim)
+ k: (batch, kv_seqlen, nheads, headdim)
+ v: (batch, kv_seqlen, nheads, headdim)
+ batch_size: int.
+ seq_len: int.
+ dropout_p: float. Dropout probability.
+ sm_scale: float. The scaling of QK^T before applying softmax.
+ Default to 1 / sqrt(headdim).
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+ Return:
+ attn_out: (batch, q_seqlen, nheads, headdim).
+ """
+ if padded:
+        if seq_len_info_kv is None:
+ seq_len_info_kv = seq_len_info_q
+
+ attn_out = flash_attn_varlen_func(q, k, v, seq_len_info_q.cu_seqlens, seq_len_info_kv.cu_seqlens,
+ seq_len_info_q.max_seqlen, seq_len_info_kv.max_seqlen, dropout_p, scale,
+ causal)
+ else:
+ attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
+ return attn_out
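For reference, a minimal usage sketch of the dense (padded=False) path of the wrapper above. It assumes flash-attn is installed, an Ampere-or-newer CUDA GPU is available, and uses the module path colossalai.kernel.cuda_native.mha.flash_attn_2 introduced by this patch; the shapes are illustrative only.

    import math
    import torch
    from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN

    if HAS_FLASH_ATTN:
        from colossalai.kernel.cuda_native.mha.flash_attn_2 import flash_attention

        B, S, H, D_HEAD = 2, 128, 4, 64
        q = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        # dense path: the SeqLenInfo arguments are unused and flash_attn_func is called directly
        out = flash_attention(q, k, v, None, None,
                              scale=1.0 / math.sqrt(D_HEAD), causal=True, padded=False)
        assert out.shape == (B, S, H, D_HEAD)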
diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py
new file mode 100644
index 000000000000..e83beb8b2429
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py
@@ -0,0 +1,70 @@
+import warnings
+
+HAS_MEM_EFF_ATTN = False
+try:
+ from xformers.ops.fmha import memory_efficient_attention
+ HAS_MEM_EFF_ATTN = True
+except ImportError:
+ warnings.warn('please install xformers from https://github.com/facebookresearch/xformers')
+ HAS_MEM_EFF_ATTN = False
+
+if HAS_MEM_EFF_ATTN:
+ """
+ A general attention module using the flash attention kernels from xformers:
+ https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
+ """
+ from typing import Optional
+
+ import torch
+ from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
+ from xformers.ops.fmha.attn_bias import (
+ BlockDiagonalCausalMask,
+ BlockDiagonalMask,
+ LowerTriangularMask,
+ LowerTriangularMaskWithTensorBias,
+ )
+
+ from .utils import SeqLenInfo
+
+ allow_alibi = True
+ for op in MemoryEfficientAttentionCutlassOp:
+ allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
+
+ def mem_eff_attention(q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ seq_len_info_q: SeqLenInfo,
+ seq_len_info_kv: SeqLenInfo,
+ bias: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.,
+ scale: float = None,
+ causal: bool = False,
+ padded: bool = False):
+
+ attn_bias = None
+ if padded: # bert style
+ if not causal:
+ attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
+ else:
+ attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
+ elif causal: # gpt style
+ attn_bias = LowerTriangularMask()
+
+ if bias is not None: # alibi / relative position embedding
+ assert allow_alibi, "flash attention with bias is not supported in this system."
+ assert causal, \
+ "attention with bias is only supported for causal attention so far."
+ attn_bias = attn_bias.add_bias(bias)
+
+ if padded:
+ q = q.unsqueeze(0)
+ k = k.unsqueeze(0)
+ v = v.unsqueeze(0)
+
+ out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
+
+ # shape: (b*s, n, d)
+ if padded:
+ out = out.squeeze(0)
+
+ return out
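Similarly, a minimal sketch of the xformers-backed path for GPT-style inputs (causal, fixed length, no padding). It assumes xformers is installed and a CUDA GPU is available; causal=True with padded=False makes the wrapper select LowerTriangularMask internally.

    import math
    import torch
    from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN

    if HAS_MEM_EFF_ATTN:
        from colossalai.kernel.cuda_native.mha.mem_eff_attn import mem_eff_attention

        B, S, H, D_HEAD = 2, 128, 4, 64
        q = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        out = mem_eff_attention(q, k, v, None, None,
                                dropout_p=0.0, scale=1.0 / math.sqrt(D_HEAD),
                                causal=True, padded=False)
        assert out.shape == (B, S, H, D_HEAD)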
diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py
new file mode 100644
index 000000000000..8f449a138c51
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/mha.py
@@ -0,0 +1,107 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+from ..scaled_softmax import AttnMaskType
+from .flash_attn_2 import HAS_FLASH_ATTN
+from .mem_eff_attn import HAS_MEM_EFF_ATTN
+from .utils import Repad, SeqLenInfo, Unpad
+
+if HAS_FLASH_ATTN:
+ from .flash_attn_2 import flash_attention
+if HAS_MEM_EFF_ATTN:
+ from .mem_eff_attn import mem_eff_attention
+
+
+class ColoAttention(torch.nn.Module):
+
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
+ super().__init__()
+ assert embed_dim % num_heads == 0, \
+ f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
+ if scale is not None:
+ self.scale = scale
+ else:
+ self.scale = 1 / math.sqrt(embed_dim // num_heads)
+ self.dropout = dropout
+
+ if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
+            raise RuntimeError("ColoAttention requires either flash-attn or xformers to be installed, but neither is available.")
+
+ @staticmethod
+ def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+ return Unpad.apply(tensor, indices)
+
+ @staticmethod
+ def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
+ return Repad.apply(tensor, indices, batch_size, seq_len)
+
+ def forward(self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ attn_mask_type: Optional[AttnMaskType] = None,
+ bias: Optional[torch.Tensor] = None):
+
+ attn = None
+        if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias is None:
+ attn = flash_attention
+ else:
+ attn = mem_eff_attention
+
+ padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
+ causal = attn_mask_type is not None and attn_mask_type.value > 1
+
+ batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
+ # unpad
+ seq_len_info_q = None
+ seq_len_info_kv = None
+ if padded:
+ # bert style, unpad process
+ assert attn_mask is not None, \
+ f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
+ assert attn_mask.dim() == 2, \
+ "attention mask is supposed to have shape (batch_size, seq_len), " + \
+ f"but got {attn_mask.dim()} dimensions."
+
+ # bert style
+ if tgt_len == src_len:
+ seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
+ if batch_size > 1:
+ query, key, value = self.unpad(torch.stack([query, key, value], dim=2),
+ seq_len_info_q.indices).unbind(dim=1)
+ else:
+ query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
+ seq_len_info_kv = seq_len_info_q
+ else:
+ seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
+ seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
+ if batch_size > 1:
+                    # queries are dense here, so only flatten them; unpad key/value with the kv mask
+                    query = rearrange(query, "b s ... -> (b s) ...")
+                    key, value = self.unpad(torch.stack([key, value], dim=2),
+                                            seq_len_info_kv.indices).unbind(dim=1)
+                else:
+                    query = query.squeeze(0)
+                    key, value = torch.stack([key, value], dim=2).squeeze(0).unbind(dim=1)
+
+ out = attn(query,
+ key,
+ value,
+ seq_len_info_q,
+ seq_len_info_kv,
+ dropout_p=self.dropout,
+ scale=self.scale,
+ causal=causal,
+ padded=padded)
+
+ # repad
+ if padded:
+ if batch_size > 1:
+ out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
+ out = rearrange(out, '(b s) h d -> b s h d', b=batch_size)
+
+ out = rearrange(out, 'b s h d -> b s (h d)')
+ return out
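A minimal end-to-end sketch of ColoAttention for causal (GPT-style) self-attention without padding, mirroring the updated unit tests below. It assumes at least one of flash-attn or xformers is installed and a CUDA GPU is available; shapes are illustrative.

    import torch
    from colossalai.kernel.cuda_native.mha.mha import ColoAttention
    from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

    B, S, H, D_HEAD = 2, 16, 4, 16
    D = H * D_HEAD
    q = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    attn = ColoAttention(embed_dim=D, num_heads=H, dropout=0.0)
    # causal attention without padding: no attn_mask is required
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
    assert y.shape == (B, S, D)  # heads are merged back into the embedding dimension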
diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py
new file mode 100644
index 000000000000..e3e431fa7e99
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/utils.py
@@ -0,0 +1,82 @@
+from dataclasses import dataclass
+from typing import Iterable, Tuple
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+from colossalai.utils.cuda import get_current_device
+
+
+class Unpad(torch.autograd.Function):
+ """
+ Adapted from
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+ """
+
+ @staticmethod
+ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
+ ctx.save_for_backward(indices)
+ # [b, s, ...]
+ assert tensor.ndim >= 3
+ ctx.bsz = tensor.shape[0]
+ out = rearrange(tensor, 'b s ... -> (b s) ...')
+ ctx.shape = out.shape
+ # [ntokens, ...]
+ return out[indices]
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ # [ntokens, ...]
+ grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
+ grad[indices] = grad_output
+ grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
+ # [b, s, ...]
+ return grad, None
+
+
+class Repad(torch.autograd.Function):
+ """
+ Adapted from
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+ """
+
+ @staticmethod
+ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
+ ctx.save_for_backward(indices)
+        # [ntokens, ...]
+ out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
+ # [b*s, ...]
+ out[indices] = tensor
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ # [b*s, ...]
+ grad = grad_output[indices]
+ # [ntokens, ...]
+ return grad, None, None, None
+
+
+@dataclass
+class SeqLenInfo:
+ seqlens: Iterable[int] = None
+ indices: torch.Tensor = None
+ max_seqlen: int = None
+ cu_seqlens: torch.Tensor = None
+
+ @staticmethod
+ def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
+ if attn_mask is not None:
+ indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
+ seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
+ else:
+ batch_size, tgt_len = size[0], size[1]
+ indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
+            seqlens = torch.full((batch_size,), tgt_len, dtype=torch.int32, device=device)
+        max_seqlen = int(seqlens.max())
+ cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
+ return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
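To make the helper above concrete, a small sketch of SeqLenInfo.materialize on a padding mask (assuming a CUDA device, since the default device is the current CUDA device); the expected values are written as comments.

    import torch
    from colossalai.kernel.cuda_native.mha.utils import SeqLenInfo

    # two sequences with 3 and 1 valid tokens, right-padded to length 4
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 0, 0, 0]], device="cuda")
    info = SeqLenInfo.materialize(attn_mask=mask)
    # info.seqlens    -> [3, 1]
    # info.max_seqlen -> 3
    # info.cu_seqlens -> tensor([0, 3, 4], dtype=torch.int32)  (prefix sums used by varlen kernels)
    # info.indices    -> flat positions of non-padding tokens: [0, 1, 2, 4]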
diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py
index 24e458bb3ea5..41cd4b20faa1 100644
--- a/colossalai/kernel/cuda_native/scaled_softmax.py
+++ b/colossalai/kernel/cuda_native/scaled_softmax.py
@@ -19,6 +19,7 @@
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
+ paddedcausal = 3
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@@ -139,7 +140,7 @@ def is_kernel_available(self, mask, b, np, sq, sk):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
- if self.attn_mask_type == AttnMaskType.causal:
+ if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
@@ -151,7 +152,7 @@ def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
- if self.attn_mask_type == AttnMaskType.causal:
+ if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
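The new paddedcausal member turns the enum values into a small two-bit encoding: odd values carry padding, values greater than 1 carry causality, which is exactly what the checks above and ColoAttention rely on. A short sketch of that decoding (assuming colossalai is installed):

    from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

    for mask_type in AttnMaskType:
        padded = mask_type.value % 2 == 1    # padding(1), paddedcausal(3)
        causal = mask_type.value > 1         # causal(2), paddedcausal(3)
        print(f"{mask_type.name}: padded={padded}, causal={causal}")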
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 7a28b0157384..d41ccd8321a8 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -4,11 +4,15 @@
import torch
from einops import rearrange
-from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
+from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize
-if HAS_MEM_EFF_ATTN:
- from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
+ from colossalai.kernel.cuda_native.mha.mha import ColoAttention
+ from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+
+DTYPE = [torch.float16, torch.bfloat16, torch.float32]
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@@ -22,10 +26,13 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
return ref_out
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers nor flash-attn is available")
@clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(1, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_gpt(proj_shape, dtype):
+ # TODO check output value
+ (B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -35,7 +42,11 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
- y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
+
+ mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
+ mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
+
+ y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
assert list(y.shape) == [B, S, D]
@@ -43,10 +54,12 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
y.backward(dy)
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers nor flash-attn is available")
@clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_bert(proj_shape, dtype):
+ (B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -67,10 +80,12 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
y.backward(dy)
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers nor flash-attn is available")
@clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_no_mask(proj_shape, dtype):
+ (B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -87,10 +102,12 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
y.backward(dy)
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers nor flash-attn is available")
@clear_cache_before_run()
-@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
-def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_cross_attention(proj_shape, dtype):
+ (B, S, T, H, D_HEAD) = proj_shape
D = H * D_HEAD
q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
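The updated GPT test builds its padding mask by right-padding per-sample runs of ones; a standalone sketch of that construction (the sizes are illustrative, matching the (B, S) mask layout the tests use):

    import torch

    B, S = 3, 8
    # sample i keeps S - i valid tokens; pad_sequence right-pads with zeros up to the longest sample
    mask = [torch.ones(S - i, dtype=torch.float16) for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
    assert mask.shape == (B, S)  # 1 marks a valid position, 0 marks padding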
From 38b792aab2cf6e33f1693489eecbff622dff2c35 Mon Sep 17 00:00:00 2001
From: flybird1111 <1829166702@qq.com>
Date: Fri, 4 Aug 2023 16:28:41 +0800
Subject: [PATCH 48/64] [coloattention] fix import error (#4380)
fixed an import error
---
colossalai/kernel/cuda_native/mha/__init__.py | 3 +++
tests/test_utils/test_flash_attention.py | 2 +-
2 files changed, 4 insertions(+), 1 deletion(-)
create mode 100644 colossalai/kernel/cuda_native/mha/__init__.py
diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py
new file mode 100644
index 000000000000..21fddd512957
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/__init__.py
@@ -0,0 +1,3 @@
+from .mha import ColoAttention
+
+__all__ = ['ColoAttention']
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index d41ccd8321a8..fbcc452650cf 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -9,7 +9,7 @@
from colossalai.testing import clear_cache_before_run, parameterize
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
- from colossalai.kernel.cuda_native.mha.mha import ColoAttention
+ from colossalai.kernel.cuda_native import ColoAttention
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
DTYPE = [torch.float16, torch.bfloat16, torch.float32]
From f40b718959b496c797da8dfa17194b63858fc2b1 Mon Sep 17 00:00:00 2001
From: flybird1111 <1829166702@qq.com>
Date: Fri, 4 Aug 2023 17:24:35 +0800
Subject: [PATCH 49/64] [doc] Fix gradient accumulation doc. (#4349)
* [doc] fix gradient accumulation doc
* [doc] fix gradient accumulation doc
---
docs/source/en/features/gradient_accumulation_with_booster.md | 2 ++
.../zh-Hans/features/gradient_accumulation_with_booster.md | 2 ++
2 files changed, 4 insertions(+)
diff --git a/docs/source/en/features/gradient_accumulation_with_booster.md b/docs/source/en/features/gradient_accumulation_with_booster.md
index 201e3bc2b643..7bc4eb47bcd7 100644
--- a/docs/source/en/features/gradient_accumulation_with_booster.md
+++ b/docs/source/en/features/gradient_accumulation_with_booster.md
@@ -103,10 +103,12 @@ for idx, (img, label) in enumerate(train_dataloader):
with sync_context:
output = model(img)
train_loss = criterion(output, label)
+ train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
else:
output = model(img)
train_loss = criterion(output, label)
+ train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
optimizer.step()
optimizer.zero_grad()
diff --git a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
index a8422060f0ea..d121b161b9ff 100644
--- a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
+++ b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
@@ -106,10 +106,12 @@ for idx, (img, label) in enumerate(train_dataloader):
with sync_context:
output = model(img)
train_loss = criterion(output, label)
+ train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
else:
output = model(img)
train_loss = criterion(output, label)
+ train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
optimizer.step()
optimizer.zero_grad()
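The documentation fix above divides the loss by the accumulation factor so that the gradients accumulated over several micro-batches match a single large-batch step. A minimal framework-agnostic sketch of the same pattern in plain PyTorch (GRADIENT_ACCUMULATION mirrors the docs; the model and data are placeholders):

    import torch

    GRADIENT_ACCUMULATION = 4
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()

    for step in range(8):
        img = torch.randn(16, 10)
        label = torch.randint(0, 2, (16,))
        train_loss = criterion(model(img), label) / GRADIENT_ACCUMULATION  # scale the loss
        train_loss.backward()                      # gradients accumulate across micro-batches
        if (step + 1) % GRADIENT_ACCUMULATION == 0:
            optimizer.step()
            optimizer.zero_grad()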
From 089c365fa0690485acb8e8335392095ca426633d Mon Sep 17 00:00:00 2001
From: binmakeswell
Date: Fri, 4 Aug 2023 17:42:07 +0800
Subject: [PATCH 50/64] [doc] add Series A Funding and NeurIPS news (#4377)
* [doc] add Series A Funding and NeurIPS news
* [kernel] fix mha kernel
* [CI] skip moe
* [CI] fix requirements
---
README.md | 5 +++--
colossalai/nn/optimizer/README.md | 3 ++-
docs/README-zh-Hans.md | 6 +++---
examples/tutorial/README.md | 3 ++-
pytest.ini | 2 +-
requirements/requirements.txt | 1 +
6 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/README.md b/README.md
index 21670e1e59fb..44e4f97f1f4e 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
## Latest News
+* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
@@ -33,7 +34,6 @@
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
-* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
## Table of Contents
@@ -463,6 +463,7 @@ To cite this project, you can use the following BibTeX citation.
}
```
-Colossal-AI has been accepted as official tutorial by top conferences [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
+Colossal-AI has been accepted as an official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
(back to top)
diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md
index 09395d08b93e..d839753d6c44 100644
--- a/colossalai/nn/optimizer/README.md
+++ b/colossalai/nn/optimizer/README.md
@@ -3,7 +3,8 @@
## Introduction
Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI),
-which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
+which has been accepted as official tutorials by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index e229c65d890c..945ca4080413 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,7 @@
## 新闻
+* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
@@ -32,8 +33,6 @@
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
-* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
-
## 目录
@@ -444,6 +443,7 @@ Colossal-AI项目受一些相关的项目启发而成立,一些项目是我们
}
```
-Colossal-AI 已被 [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/)等顶级会议录取为官方教程。
+Colossal-AI 已被[NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) 等顶级会议录取为官方教程。
(返回顶端)
diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md
index 0664d41fd359..7b5668612818 100644
--- a/examples/tutorial/README.md
+++ b/examples/tutorial/README.md
@@ -4,7 +4,8 @@
## Introduction
-Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
+Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as an official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates
diff --git a/pytest.ini b/pytest.ini
index e99fe3f086c6..e8a60c85336b 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,4 +4,4 @@ markers =
gpu: tests which requires a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features
-addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk
+addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index b34dc2e223ae..f6be6a624c70 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,3 +10,4 @@ contexttimer
ninja
torch>=1.11
safetensors
+einops
From 7c84f5105dbe14bc9b0d646b8817f6a2a2f47ba6 Mon Sep 17 00:00:00 2001
From: flybird1111 <1829166702@qq.com>
Date: Mon, 7 Aug 2023 16:41:07 +0800
Subject: [PATCH 51/64] [Shardformer] Merge flash attention branch to pipeline
branch (#4362)
* [shardformer] supported flash attention test dependency (#4158)
* [shardformer] fix flash attention utils test (#4180)
* [shardformer] opt support flash attention (#4163)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* [shardformer] add performance benchmark of shardformer (#4175)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] benchmark fix
* [shardformer] benchmark fix
* [shardformer] llama support flash attention (#4185)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* [shardformer] llama support flash attention
* [shardformer] llama support flash attention
* [shardformer] Move the import statement for xformer outside the forward function.
* [shardformer] gpt2 support flash attention. (#4191)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* [shardformer] gpt2 support flash attention
* [shardformer] gpt2 support flash attention
* [shardformer] gpt2 support flash attention
* [shardformer] bloom support flash attention (#4188)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* [shardformer] bloom suport flash attention
* [shardformer] add assert to sequence length
* [shardformer] fix
* [shardformer] fix
* [shardformer] fix
* [shardformer] bert support flash attention. (#4206)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* [shardformer] bert support flash attention
* [shardformer] t5 support flash attention. (#4216)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* [shardformer] t5 support flash attention
* [shardformer] t5 support flash attention
* fix typo
* fix typo
* fix typo
* fix typo
* fix typo
* fix typo
* [shardformer] support 'paddedcausal' type of attention mask in Coloattention. (#4215)
* added padded causal attn mask type for ColoAttention
* [shardformer]t5 flash attention fix (#4239)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* [shardformer] t5 flash attention fix
* [shardformer] update gpt2 to use coloattention. (#4234)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* [shardformer] update gpt2 to use coloattention
* [shardformer] update gpt2 to use coloattention
* [shardformer] update gpt2 to use coloattention
* [shardformer] update gpt2 to use coloattention
* [shardformer] update gpt2
* [shardformer] update opt and llama to use coloattention. (#4226)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* update opt to use coloattention
* [shardformer]update opt to use coloattention
* [shardformer]update opt to use coloattention
* [shardformer]update opt to use coloattention
* [shardformer]update opt to use coloattention
* [shardformer]update opt to use coloattention
* [shardformer]update opt to use coloattention
* [shardformer]update opt
* [shardformer] shardformer support jit fused operator. (#4236)
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] opt support flash attention
* [shardformer] move to modeling
* [shardformer] move to modeling
* [shardformer] bloom support jit fused operator
* [shardformer] bloom support jit fused operator
* [shardformer] bloom support jit fused operator
* [shardformer] t5 support jit fused operator
* [shardformer] t5 support jit fused operator
* [shardformer] t5 support jit fused operator
* [shardformer] add roadmap of flash attention
* [shardformer] add roadmap of flash attention
* [shardformer] add roadmap of flash attention
* [shardformer] add type hint to 'self' param of forward
* [shardformer] merge feature/shardformer-models branch to feature/flash-attention-shardformer branch. (#4290)
* Feature/vit support (#4182)
* [shardformer] added tests
* [shardformer] vit test finish and support
* fix attention dropout
* [shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear
* update utils support set element in list
* overtwrite SamVisionAttention foward to use DropoutForParallelInput
* remove unused code
* [shardformer] support whisper (#4212)
* support whisper
* fix bug in vocabembedding
* support downstream model of whisper
* update readme
* Feature/chatglm (#4240)
* [shardformer] added tests
* [shardformer] vit test finish and support
* [shardformer] chatglm ready
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] chatglm shard without mlp sharding
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] fix chatglm configuration with pre-commit
---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
* [shardformer] whisper support flash attention (#4301)
* Feature/vit support (#4182)
* [shardformer] added tests
* [shardformer] vit test finish and support
* fix attention dropout
* [shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear
* update utils support set element in list
* overtwrite SamVisionAttention foward to use DropoutForParallelInput
* remove unused code
* [shardformer] support whisper (#4212)
* support whisper
* fix bug in vocabembedding
* support downstream model of whisper
* update readme
* Feature/chatglm (#4240)
* [shardformer] added tests
* [shardformer] vit test finish and support
* [shardformer] chatglm ready
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] chatglm shard without mlp sharding
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] fix chatglm configuration with pre-commit
* [shardformer] whisper support flash attention
* [shardformer] whisper support flash attention
* [shardformer]whisper support jit operator
---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
* [shardformer] sam support flash attention (#4316)
* Feature/vit support (#4182)
* [shardformer] added tests
* [shardformer] vit test finish and support
* fix attention dropout
* [shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear
* update utils support set element in list
* overtwrite SamVisionAttention foward to use DropoutForParallelInput
* remove unused code
* [shardformer] support whisper (#4212)
* support whisper
* fix bug in vocabembedding
* support downstream model of whisper
* update readme
* Feature/chatglm (#4240)
* [shardformer] added tests
* [shardformer] vit test finish and support
* [shardformer] chatglm ready
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] chatglm shard without mlp sharding
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] fix chatglm configuration with pre-commit
* [shardformer] sam support flash attention
---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
* [shardformer] merge blip2/chatglm (#4321)
* Feature/vit support (#4182)
* [shardformer] added tests
* [shardformer] vit test finish and support
* fix attention dropout
* [shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear
* update utils support set element in list
* overtwrite SamVisionAttention foward to use DropoutForParallelInput
* remove unused code
* [shardformer] support whisper (#4212)
* support whisper
* fix bug in vocabembedding
* support downstream model of whisper
* update readme
* Feature/chatglm (#4240)
* [shardformer] added tests
* [shardformer] vit test finish and support
* [shardformer] chatglm ready
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] chatglm shard without mlp sharding
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] fix chatglm configuration with pre-commit
* [shardformer] added tests
* [shardformer] vit test finish and support
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit
* [shardformer] support Blip2 (#4243)
* support base blip2
* add support for downstream blip2 model
* update readme
* add forward injection
* skip not compatible models test
* fix test for gemini and low_level_zero_pugin
---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>
* [shardformer] blip2 support flash attention and jit operator (#4325)
* Feature/vit support (#4182)
* [shardformer] added tests
* [shardformer] vit test finish and support
* fix attention dropout
* [shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear
* update utils support set element in list
* overtwrite SamVisionAttention foward to use DropoutForParallelInput
* remove unused code
* [shardformer] support whisper (#4212)
* support whisper
* fix bug in vocabembedding
* support downstream model of whisper
* update readme
* Feature/chatglm (#4240)
* [shardformer] added tests
* [shardformer] vit test finish and support
* [shardformer] chatglm ready
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] chatglm shard without mlp sharding
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] fix chatglm configuration with pre-commit
* [shardformer] added tests
* [shardformer] vit test finish and support
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit
* [shardformer] support Blip2 (#4243)
* support base blip2
* add support for downstream blip2 model
* update readme
* add forward injection
* skip not compatible models test
* fix test for gemini and low_level_zero_pugin
* [shardformer] blip2 support flash attention and jit operator
* [shardformer] blip2 support flash attention and jit operator
* [shardformer] blip2 support flash attention and jit operator
---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>
* [shardformer] chatglm support flash attention and jit operator (#4330)
* Feature/vit support (#4182)
* [shardformer] added tests
* [shardformer] vit test finish and support
* fix attention dropout
* [shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear
* update utils support set element in list
* overtwrite SamVisionAttention foward to use DropoutForParallelInput
* remove unused code
* [shardformer] support whisper (#4212)
* support whisper
* fix bug in vocabembedding
* support downstream model of whisper
* update readme
* Feature/chatglm (#4240)
* [shardformer] added tests
* [shardformer] vit test finish and support
* [shardformer] chatglm ready
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] chatglm shard without mlp sharding
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] fix chatglm configuration with pre-commit
* [shardformer] added tests
* [shardformer] vit test finish and support
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit
* [shardformer] support Blip2 (#4243)
* support base blip2
* add support for downstream blip2 model
* update readme
* add forward injection
* skip not compatible models test
* fix test for gemini and low_level_zero_pugin
* [shardformer] chatglm support flash attention and jit operator
* [shardformer] chatglm support flash attention and jit operator
* [shardformer] chatglm support flash attention and jit operator
* [shardformer] chatglm support flash attention and jit operator
---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>
* [shardformer] vit support flash attention and jit operator (#4334)
* Feature/vit support (#4182)
* [shardformer] added tests
* [shardformer] vit test finish and support
* fix attention dropout
* [shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear
* update utils support set element in list
* overtwrite SamVisionAttention foward to use DropoutForParallelInput
* remove unused code
* [shardformer] support whisper (#4212)
* support whisper
* fix bug in vocabembedding
* support downstream model of whisper
* update readme
* Feature/chatglm (#4240)
* [shardformer] added tests
* [shardformer] vit test finish and support
* [shardformer] chatglm ready
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] chatglm shard without mlp sharding
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] fix chatglm configuration with pre-commit
* [shardformer] added tests
* [shardformer] vit test finish and support
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [sharformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit
* [shardformer] support Blip2 (#4243)
* support base blip2
* add support for downstream blip2 model
* update readme
* add forward injection
* skip not compatible models test
* fix test for gemini and low_level_zero_pugin
* [shardformer] vit support flash attention and jit operator
* [shardformer] vit support flash attention and jit operator
---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>
* [pipeline] merge flash attention branch
* [pipeline] merge flash attention branch
* [pipeline] merge flash attention branch
* [pipeline] fix conflict
* Merge branch 'feature/pipeline' into feature/pipeline
* activate checks
* fix flash attention tests
* gemini ignore whisper
* fix vit
* fix xformers import handle
---------
Co-authored-by: Frank Lee
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>
---
.../kernel/cuda_native/flash_attention.py | 26 +-
.../kernel/cuda_native/scaled_softmax.py | 5 +-
colossalai/shardformer/README.md | 58 +-
..._benchmark.py => convergence_benchmark.py} | 0
..._benchmark.sh => convergence_benchmark.sh} | 2 +-
.../examples/performance_benchmark.py | 86 ++
colossalai/shardformer/modeling/bert.py | 138 +-
colossalai/shardformer/modeling/blip2.py | 60 +
colossalai/shardformer/modeling/bloom.py | 221 +++
colossalai/shardformer/modeling/chatglm.py | 110 ++
colossalai/shardformer/modeling/gpt2.py | 85 +
colossalai/shardformer/modeling/jit.py | 34 +
colossalai/shardformer/modeling/llama.py | 66 +-
colossalai/shardformer/modeling/opt.py | 174 +++
colossalai/shardformer/modeling/sam.py | 164 ++
colossalai/shardformer/modeling/t5.py | 206 +++
colossalai/shardformer/modeling/vit.py | 49 +
colossalai/shardformer/modeling/whisper.py | 249 +++
colossalai/shardformer/policies/bert.py | 34 +-
colossalai/shardformer/policies/blip2.py | 28 +-
colossalai/shardformer/policies/bloom.py | 34 +-
colossalai/shardformer/policies/chatglm.py | 20 +-
colossalai/shardformer/policies/gpt2.py | 90 +-
colossalai/shardformer/policies/llama.py | 9 +-
colossalai/shardformer/policies/opt.py | 17 +-
colossalai/shardformer/policies/sam.py | 12 +-
colossalai/shardformer/policies/t5.py | 30 +-
colossalai/shardformer/policies/vit.py | 48 +-
colossalai/shardformer/policies/whisper.py | 25 +
colossalai/shardformer/shard/shard_config.py | 5 +-
pytest.ini | 1 +
requirements/requirements-test.txt | 4 +-
requirements/requirements.txt | 1 +
tests/kit/model_zoo/transformers/bert.py | 16 +-
tests/kit/model_zoo/transformers/blip2.py | 1 +
tests/kit/model_zoo/transformers/bloom.py | 10 +-
tests/kit/model_zoo/transformers/chatglm.py | 1 -
.../chatglm2_6b/configuration_chatglm.py | 58 -
.../chatglm2_6b/modeling_chatglm.py | 1372 -----------------
tests/kit/model_zoo/transformers/gpt.py | 6 +-
tests/kit/model_zoo/transformers/t5.py | 10 +-
tests/kit/model_zoo/transformers/whisper.py | 4 +-
.../test_plugin/test_gemini_plugin.py | 2 +-
.../test_plugin/test_low_level_zero_plugin.py | 1 +
tests/test_shardformer/test_model/_utils.py | 13 +-
.../test_model/test_shard_bert.py | 11 +-
.../test_model/test_shard_blip2.py | 7 +-
.../test_model/test_shard_bloom.py | 8 +-
.../test_model/test_shard_chatglm.py | 8 +-
.../test_model/test_shard_gpt2.py | 1 -
.../test_model/test_shard_llama.py | 5 +-
.../test_model/test_shard_opt.py | 15 +-
.../test_model/test_shard_sam.py | 6 +-
.../test_model/test_shard_t5.py | 11 +-
.../test_model/test_shard_vit.py | 9 +-
.../test_model/test_shard_whisper.py | 8 +-
tests/test_utils/test_flash_attention.py | 26 +-
57 files changed, 2118 insertions(+), 1582 deletions(-)
rename colossalai/shardformer/examples/{shardformer_benchmark.py => convergence_benchmark.py} (100%)
rename colossalai/shardformer/examples/{shardformer_benchmark.sh => convergence_benchmark.sh} (76%)
create mode 100644 colossalai/shardformer/examples/performance_benchmark.py
create mode 100644 colossalai/shardformer/modeling/jit.py
create mode 100644 colossalai/shardformer/modeling/opt.py
create mode 100644 colossalai/shardformer/modeling/whisper.py
delete mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py
delete mode 100644 tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py
diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
index 3db7374509a0..91bef0908dbb 100644
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ b/colossalai/kernel/cuda_native/flash_attention.py
@@ -6,6 +6,7 @@
import math
import os
import subprocess
+import warnings
import torch
@@ -14,7 +15,7 @@
HAS_MEM_EFF_ATTN = True
except ImportError:
HAS_MEM_EFF_ATTN = False
- print('please install xformers from https://github.com/facebookresearch/xformers')
+ warnings.warn('please install xformers from https://github.com/facebookresearch/xformers')
if HAS_MEM_EFF_ATTN:
@@ -22,7 +23,12 @@
from einops import rearrange
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias
+ from xformers.ops.fmha.attn_bias import (
+ BlockDiagonalCausalMask,
+ BlockDiagonalMask,
+ LowerTriangularMask,
+ LowerTriangularMaskWithTensorBias,
+ )
from .scaled_softmax import AttnMaskType
@@ -86,11 +92,14 @@ def backward(ctx, grad_output):
class ColoAttention(torch.nn.Module):
- def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
super().__init__()
assert embed_dim % num_heads == 0, \
f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
- self.scale = 1 / math.sqrt(embed_dim // num_heads)
+ if scale is not None:
+ self.scale = scale
+ else:
+ self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
@staticmethod
@@ -116,7 +125,7 @@ def forward(self,
bias: Optional[torch.Tensor] = None):
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
attn_bias = None
- if attn_mask_type == AttnMaskType.padding: # bert style
+ if attn_mask_type and attn_mask_type.value % 2 == 1: # bert style
assert attn_mask is not None, \
f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
assert attn_mask.dim() == 2, \
@@ -134,7 +143,10 @@ def forward(self,
if batch_size > 1:
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
- attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+ if attn_mask_type == AttnMaskType.padding:
+ attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+ elif attn_mask_type == AttnMaskType.paddedcausal:
+ attn_bias = BlockDiagonalCausalMask.from_seqlens(q_seqlen, kv_seqlen)
elif attn_mask_type == AttnMaskType.causal: # gpt style
attn_bias = LowerTriangularMask()
@@ -146,7 +158,7 @@ def forward(self,
out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)
- if attn_mask_type == AttnMaskType.padding and batch_size > 1:
+ if attn_mask_type and attn_mask_type.value % 2 == 1 and batch_size > 1:
out = self.repad(out, q_indices, batch_size, tgt_len)
out = rearrange(out, 'b s h d -> b s (h d)')
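Both hunks above replace the exact comparison against `AttnMaskType.padding` with a parity check on the enum value, so that the new `paddedcausal` type introduced in the `scaled_softmax.py` hunk below also takes the unpad/repad path. A minimal, standalone sketch of that check, assuming only the enum values added in this patch (everything else is illustrative):

```python
import enum


# Values mirror AttnMaskType in scaled_softmax.py after this patch.
class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3


def uses_padding_mask(attn_mask_type) -> bool:
    # padding (1) and paddedcausal (3) have odd values, causal (2) does not,
    # so `value % 2 == 1` selects every mask type that carries a padding mask.
    return attn_mask_type is not None and attn_mask_type.value % 2 == 1


assert uses_padding_mask(AttnMaskType.padding)
assert uses_padding_mask(AttnMaskType.paddedcausal)
assert not uses_padding_mask(AttnMaskType.causal)
```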
diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py
index 24e458bb3ea5..41cd4b20faa1 100644
--- a/colossalai/kernel/cuda_native/scaled_softmax.py
+++ b/colossalai/kernel/cuda_native/scaled_softmax.py
@@ -19,6 +19,7 @@
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
+ paddedcausal = 3
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@@ -139,7 +140,7 @@ def is_kernel_available(self, mask, b, np, sq, sk):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
- if self.attn_mask_type == AttnMaskType.causal:
+ if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
@@ -151,7 +152,7 @@ def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
- if self.attn_mask_type == AttnMaskType.causal:
+ if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 357e8ac3397e..1c11b4b85444 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -31,7 +31,7 @@
### Quick Start
-The sample API usage is given below:
+The sample API usage is given below (if you want to enable flash attention, please install xformers first):
``` python
from colossalai.shardformer import ShardConfig, Shard
@@ -106,6 +106,20 @@ We will follow this roadmap to develop Shardformer:
- [ ] Multi-modal
- [x] SAM
- [x] BLIP-2
+- [ ] Flash Attention Support
+ - [ ] NLP
+ - [x] BERT
+ - [x] T5
+ - [x] LLaMA
+ - [x] GPT2
+ - [x] OPT
+ - [x] BLOOM
+ - [ ] GLM
+ - [ ] RoBERTa
+ - [ ] ALBERT
+ - [ ] ERNIE
+ - [ ] GPT Neo
+ - [ ] GPT-J
## 💡 API Design
@@ -378,11 +392,49 @@ pytest tests/test_shardformer
### System Performance
-To be added.
+We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate the performance improvement of Shardformer. We compared the training time between the original model and the shard model.
+
+We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
+
+In the case of using 2 GPUs, the training times are as follows.
+| N_CTX | org_model | shard_model |
+| :------: | :-----: | :-----: |
+| 256 | 11.2ms | 17.2ms |
+| 512 | 9.8ms | 19.5ms |
+| 1024 | 19.6ms | 18.9ms |
+| 2048 | 46.6ms | 30.8ms |
+| 4096 | 160.5ms | 90.4ms |
+
+
+
+
+
+
+
+In the case of using 4 GPUs, the training times are as follows.
+
+| N_CTX | org_model | shard_model |
+| :------: | :-----: | :-----: |
+| 256 | 10.0ms | 21.1ms |
+| 512 | 11.5ms | 20.2ms |
+| 1024 | 22.1ms | 20.6ms |
+| 2048 | 46.9ms | 24.8ms |
+| 4096 | 160.4ms | 68.0ms |
+
+
+
+
+
+
+
+
+
+As shown above, once the sequence length reaches roughly 1000 tokens or more, the benefit of Shardformer's parallel optimization for long sequences becomes evident.
### Convergence
-To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results.
+
+To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) with and without Shardformer and compared the accuracy, loss, and F1 score of the training results.
| accuracy | f1 | loss | GPU number | model shard |
| :------: | :-----: | :-----: | :--------: | :---------: |
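As a quick way to read the tables above, the speedup of `shard_model` over `org_model` can be computed directly from the reported timings. The sketch below uses the 2-GPU numbers quoted in this README; it is post-processing for illustration only, not part of the benchmark script:

```python
# Speedup of shard_model over org_model for the 2-GPU run (timings in ms, from the table above).
timings_2gpu = {
    256: (11.2, 17.2),
    512: (9.8, 19.5),
    1024: (19.6, 18.9),
    2048: (46.6, 30.8),
    4096: (160.5, 90.4),
}

for n_ctx, (org_ms, shard_ms) in timings_2gpu.items():
    # A ratio above 1.0 means the sharded model trains faster at that sequence length.
    print(f"N_CTX={n_ctx}: speedup = {org_ms / shard_ms:.2f}x")
```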
diff --git a/colossalai/shardformer/examples/shardformer_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py
similarity index 100%
rename from colossalai/shardformer/examples/shardformer_benchmark.py
rename to colossalai/shardformer/examples/convergence_benchmark.py
diff --git a/colossalai/shardformer/examples/shardformer_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh
similarity index 76%
rename from colossalai/shardformer/examples/shardformer_benchmark.sh
rename to colossalai/shardformer/examples/convergence_benchmark.sh
index f42b19a32d35..1c281abcda6d 100644
--- a/colossalai/shardformer/examples/shardformer_benchmark.sh
+++ b/colossalai/shardformer/examples/convergence_benchmark.sh
@@ -1,4 +1,4 @@
-torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
+torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
--max_epochs 1 \
diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py
new file mode 100644
index 000000000000..9c7b76bcf0a6
--- /dev/null
+++ b/colossalai/shardformer/examples/performance_benchmark.py
@@ -0,0 +1,86 @@
+"""
+Shardformer Benchmark
+"""
+import torch
+import torch.distributed as dist
+import transformers
+import triton
+
+import colossalai
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+
+def data_gen(batch_size, seq_length):
+ input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
+ attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_sequence_classification(batch_size, seq_length):
+ # LM data gen
+ # for an LM the `labels` are the output tokens; since there is no padding, `input_ids` can be reused as `labels`
+ data = data_gen(batch_size, seq_length)
+ data['labels'] = torch.ones((batch_size), dtype=torch.long)
+ return data
+
+
+MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
+ hidden_size=128,
+ intermediate_size=256,
+ num_attention_heads=4,
+ max_position_embeddings=128,
+ num_labels=16)
+BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
+model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
+
+# vary seq length for fixed head and batch=4
+configs = [
+ triton.testing.Benchmark(x_names=['N_CTX'],
+ x_vals=[2**i for i in range(8, 13)],
+ line_arg='provider',
+ line_vals=['org_model', 'shard_model'],
+ line_names=['org_model', 'shard_model'],
+ styles=[('red', '-'), ('blue', '-')],
+ ylabel='ms',
+ plot_name=f'llama_for_sequence_classification-batch-{BATCH}',
+ args={
+ 'BATCH': BATCH,
+ 'dtype': torch.float16,
+ 'model_func': model_func
+ })
+]
+
+
+def train(model, data):
+ output = model(**data)
+ loss = output.logits.mean()
+ loss.backward()
+
+
+@triton.testing.perf_report(configs)
+def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
+ warmup = 10
+ rep = 100
+ # prepare data
+ data = data_gen_for_sequence_classification(BATCH, N_CTX)
+ data = {k: v.cuda() for k, v in data.items()}
+ model = model_func().to(device)
+ model.train()
+ if provider == "org_model":
+ fn = lambda: train(model, data)
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ return ms
+ if provider == "shard_model":
+ shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
+ shard_former = ShardFormer(shard_config=shard_config)
+ sharded_model = shard_former.optimize(model).cuda()
+ fn = lambda: train(sharded_model, data)
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ return ms
+
+
+# start benchmark, command:
+# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
+if __name__ == "__main__":
+ colossalai.launch_from_torch({})
+ bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 1b3c14d9d1c9..b9d4b5fda7af 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1,5 +1,6 @@
+import math
import warnings
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -962,3 +963,138 @@ def bert_for_question_answering_forward(
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
+
+
+def get_bert_flash_attention_forward():
+
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+ except ImportError:
+ raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
+ from transformers.models.bert.modeling_bert import BertAttention
+
+ def forward(
+ self: BertAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ use_cache = past_key_value is not None
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ final_attention_mask = None
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+ if use_cache:
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ else:
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ final_attention_mask = relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ final_attention_mask = relative_position_scores_query + relative_position_scores_key
+
+ scale = 1 / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ if final_attention_mask is not None:
+ final_attention_mask = final_attention_mask * scale + attention_mask
+ else:
+ final_attention_mask = attention_mask
+ batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
+ tgt_len = key_layer.size()[2]
+ final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)
+
+ query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
+ key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
+ value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
+
+ context_layer = me_attention(query_layer,
+ key_layer,
+ value_layer,
+ attn_bias=final_attention_mask,
+ p=self.dropout.p,
+ scale=scale)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, None)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_bert_self_output_forward():
+
+ from transformers.models.bert.modeling_bert import BertSelfOutput
+
+ def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+ return forward
+
+
+def get_jit_fused_bert_output_forward():
+
+ from transformers.models.bert.modeling_bert import BertOutput
+
+ def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+ return forward
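The `get_*_forward()` factories added above return plain functions that the shard policies later bind onto existing HuggingFace modules. The sketch below illustrates that binding pattern with a hypothetical stand-in module (`ToyOutput`) and an unfused `dropout_add` helper instead of the real `BertSelfOutput` and jit kernels; the exact binding mechanism used by the policies may differ:

```python
from types import MethodType

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyOutput(nn.Module):
    # Stand-in with the same attributes the replacement forward expects.
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)


def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # Unfused reference for the jit-fused dropout + residual add used in the patch.
    return F.dropout(x, p=prob, training=training) + residual


def get_toy_output_forward():
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
        return self.LayerNorm(hidden_states)

    return forward


module = ToyOutput()
module.forward = MethodType(get_toy_output_forward(), module)  # bind the factory-produced forward
x = torch.randn(2, 4, 8)
print(module(x, x).shape)  # torch.Size([2, 4, 8])
```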
diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
index b7945423ae83..c5c6b14ba993 100644
--- a/colossalai/shardformer/modeling/blip2.py
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -1,3 +1,4 @@
+import math
from typing import Optional, Tuple, Union
import torch
@@ -58,3 +59,62 @@ def forward(
return outputs
return forward
+
+
+def get_blip2_flash_attention_forward():
+
+ from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
+
+ from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+
+ def forward(
+ self: Blip2Attention,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+ mixed_qkv = self.qkv(hidden_states)
+ mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
+ query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
+
+ attention = ColoAttention(embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout.p,
+ scale=self.scale)
+ context_layer = attention(query_states, key_states, value_states)
+
+ output = self.projection(context_layer)
+ outputs = (output, None)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_blip2_QFormer_self_output_forward():
+
+ from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
+
+ def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+ return forward
+
+
+def get_jit_fused_blip2_QFormer_output_forward():
+
+ from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
+
+ def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+ return forward
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 76948fc70439..57c45bc6adfa 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -5,6 +5,7 @@
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import functional as F
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -675,3 +676,223 @@ def bloom_for_question_answering_forward(
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
+
+
+def get_bloom_flash_attention_forward(enable_jit_fused=False):
+
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+ except ImportError:
+ raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
+ from transformers.models.bloom.modeling_bloom import BloomAttention
+
+ def forward(
+ self: BloomAttention,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+
+ fused_qkv = self.query_key_value(hidden_states)
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+ batch_size, tgt_len, _ = hidden_states.size()
+ assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
+
+ _, kv_length, _, _ = key_layer.size()
+
+ proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
+ query_layer = query_layer.contiguous().view(*proj_shape)
+ key_layer = key_layer.contiguous().view(*proj_shape)
+ value_layer = value_layer.contiguous().view(*proj_shape)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ # concatenate along seq_length dimension:
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
+ key_layer = torch.cat((past_key, key_layer), dim=1)
+ value_layer = torch.cat((past_value, value_layer), dim=1)
+
+ if use_cache is True:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+
+ tgt_len = key_layer.size()[1]
+
+ attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
+ dtype=torch.float32,
+ device=query_layer.device,
+ requires_grad=True)
+ attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
+ kv_length) * self.beta
+ attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
+ torch.finfo(torch.float32).min)
+
+ context_layer = me_attention(query_layer,
+ key_layer,
+ value_layer,
+ attn_bias=attention_numerical_mask,
+ scale=self.inv_norm_factor,
+ p=self.attention_dropout.p)
+ context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices):int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ # TODO to replace with the bias_dropout_add function in jit
+ output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+ outputs = (output_tensor, present, None)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_bloom_attention_forward():
+
+ from transformers.models.bloom.modeling_bloom import BloomAttention
+
+ def forward(
+ self: BloomAttention,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+ batch_size, q_length, _, _ = query_layer.shape
+
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+ key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ # concatenate along seq_length dimension:
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
+ key_layer = torch.cat((past_key, key_layer), dim=2)
+ value_layer = torch.cat((past_value, value_layer), dim=1)
+
+ _, _, kv_length = key_layer.shape
+
+ if use_cache is True:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+
+ # [batch_size * num_heads, q_length, kv_length]
+ # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+ matmul_result = alibi.baddbmm(
+ batch1=query_layer,
+ batch2=key_layer,
+ beta=self.beta,
+ alpha=self.inv_norm_factor,
+ )
+
+ # change view to [batch_size, num_heads, q_length, kv_length]
+ attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
+
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+ input_dtype = attention_scores.dtype
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+ if input_dtype == torch.float16:
+ attention_scores = attention_scores.to(torch.float)
+ attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
+ attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
+
+ # [batch_size, num_heads, q_length, kv_length]
+ attention_probs = self.attention_dropout(attention_probs)
+
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ # change view [batch_size x num_heads, q_length, kv_length]
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
+
+ # matmul: [batch_size * num_heads, q_length, head_dim]
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer)
+
+ # change view [batch_size, num_heads, q_length, head_dim]
+ context_layer = self._merge_heads(context_layer)
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices):int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+
+ outputs = (output_tensor, present)
+ if output_attentions:
+ outputs += (attention_probs,)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_bloom_mlp_forward():
+
+ from transformers.models.bloom.modeling_bloom import BloomMLP
+
+ def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ intermediate_output = torch.zeros_like(residual)
+ slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
+ for i in range(self.pretraining_tp):
+ intermediate_output = intermediate_output + F.linear(
+ hidden_states[:, :, int(i * slices):int((i + 1) * slices)],
+ self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)],
+ )
+ else:
+ intermediate_output = self.dense_4h_to_h(hidden_states)
+ output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+ return output
+
+ return forward
+
+
+def get_jit_fused_bloom_gelu_forward():
+
+ from transformers.models.bloom.modeling_bloom import BloomGelu
+
+ from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+ def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor:
+ bias = torch.zeros_like(x)
+ if self.training:
+ return JitGeLUFunction.apply(x, bias)
+ else:
+ return self.bloom_gelu_forward(x, bias)
+
+ return forward
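The Bloom flash-attention path above folds the ALiBi term and the boolean padding mask into a single additive bias before calling `memory_efficient_attention`. The CPU-only sketch below shows that construction with toy shapes and values; nothing here comes from a real model:

```python
import torch

batch, heads, q_len, kv_len = 1, 2, 1, 4

# Start from a zero bias, add a toy ALiBi-like term, then push padded key
# positions to the dtype minimum so softmax assigns them ~zero weight.
bias = torch.zeros(batch, heads, q_len, kv_len)
alibi_like = torch.arange(kv_len, dtype=torch.float32).view(1, 1, 1, kv_len)
padding_mask = torch.tensor([[[[False, False, False, True]]]])  # True = padded key position

bias = bias + alibi_like
bias = torch.masked_fill(bias, padding_mask.expand_as(bias), torch.finfo(torch.float32).min)
print(bias)
```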
diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py
index 0bb8bdc58218..3d453c3bd6db 100644
--- a/colossalai/shardformer/modeling/chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm.py
@@ -17,6 +17,116 @@
)
+def get_flash_core_attention_forward():
+
+ from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+
+ from .chatglm2_6b.modeling_chatglm import CoreAttention
+
+ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
+ pytorch_major_version = int(torch.__version__.split(".")[0])
+ if pytorch_major_version >= 2:
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+ key_layer,
+ value_layer,
+ is_causal=True)
+ else:
+ if attention_mask is not None:
+ attention_mask = ~attention_mask
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
+ attention_mask)
+ context_layer = context_layer.permute(2, 0, 1, 3)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.reshape(*new_context_layer_shape)
+ else:
+ # Raw attention scores
+ query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
+ key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
+ value_layer = value_layer.permute(1, 0, 2, 3).contiguous()
+
+ scale = 1.0 / self.norm_factor
+ if self.coeff is not None:
+ scale = scale * self.coeff
+
+ flash_attention_mask = None
+ attn_mask_type = None
+ if attention_mask is None:
+ attn_mask_type = AttnMaskType.causal
+ else:
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+ attn_mask_type = AttnMaskType.paddedcausal
+
+ attention = ColoAttention(embed_dim=self.hidden_size_per_partition,
+ num_heads=self.num_attention_heads_per_partition,
+ dropout=self.attention_dropout.p,
+ scale=scale)
+ context_layer = attention(query_layer,
+ key_layer,
+ value_layer,
+ attn_mask=flash_attention_mask,
+ attn_mask_type=attn_mask_type)
+
+ context_layer = context_layer.permute(1, 0, -1).contiguous()
+
+ return context_layer
+
+ return forward
+
+
+def get_jit_fused_glm_block_forward():
+
+ from .chatglm2_6b.modeling_chatglm import GLMBlock
+
+ def forward(
+ self: GLMBlock,
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=None,
+ use_cache=True,
+ ):
+ # hidden_states: [s, b, h]
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, kv_cache = self.self_attention(
+ layernorm_output,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=kv_cache,
+ use_cache=use_cache,
+ )
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ layernorm_input = self.dropout_add(attention_output, residual, self.hidden_dropout, self.training)
+
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+ # MLP.
+ mlp_output = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = self.dropout_add(mlp_output, residual, self.hidden_dropout, self.training)
+
+ return output, kv_cache
+
+ return forward
+
+
+
class ChatGLMPipelineForwards:
'''
This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
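For PyTorch >= 2.0, the ChatGLM `CoreAttention` replacement above dispatches to `torch.nn.functional.scaled_dot_product_attention` and passes `is_causal=True` when no attention mask is supplied. A small CPU sketch of that fast path with toy shapes:

```python
import torch
import torch.nn.functional as F

# [batch, heads, seq_len, head_dim] layout, as used after the permute in the patch.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Causal self-attention without materialising an explicit mask (requires PyTorch >= 2.0).
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```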
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index dc5a81dc912b..e02581fbaa9b 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -668,3 +668,88 @@ def gpt2_for_sequence_classification_forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
+
+
+def get_gpt2_flash_attention_forward():
+
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
+
+ from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+
+ def split_heads(tensor, num_heads, attn_head_size):
+ """
+ Splits hidden_size dim into attn_head_size and num_heads
+ """
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+ tensor = tensor.view(new_shape)
+ return tensor
+
+ def forward(
+ self: GPT2Attention,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+ _, tgt_len, _ = hidden_states.size()
+ assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
+
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
+
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ attention_mask = encoder_attention_mask
+ else:
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+ query = split_heads(query, self.num_heads, self.head_dim)
+ key = split_heads(key, self.num_heads, self.head_dim)
+ value = split_heads(value, self.num_heads, self.head_dim)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key = torch.cat((past_key, key), dim=1)
+ value = torch.cat((past_value, value), dim=1)
+
+ if use_cache is True:
+ present = (key, value)
+ else:
+ present = None
+
+ if not self.is_cross_attention:
+ attn_mask_type = AttnMaskType.causal
+ flash_attention_mask = None
+ if attention_mask is not None:
+ if attn_mask_type == AttnMaskType.causal:
+ attn_mask_type = AttnMaskType.paddedcausal
+ else:
+ attn_mask_type = AttnMaskType.padding
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+
+ scale = value.size(-1)**-0.5
+ if self.scale_attn_by_inverse_layer_idx:
+ scale = scale * (1 / float(self.layer_idx + 1))
+
+ # use coloattention
+ attention = ColoAttention(embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ dropout=self.attn_dropout.p,
+ scale=scale)
+
+ attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+ outputs = (attn_output, present, None)
+
+ return outputs
+
+ return forward
diff --git a/colossalai/shardformer/modeling/jit.py b/colossalai/shardformer/modeling/jit.py
new file mode 100644
index 000000000000..6434348ef823
--- /dev/null
+++ b/colossalai/shardformer/modeling/jit.py
@@ -0,0 +1,34 @@
+import torch
+
+
+def get_dropout_add_func():
+
+ from transformers.models.bloom.modeling_bloom import dropout_add
+
+ def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
+ return dropout_add(x, residual, prob, training)
+
+ return self_dropout_add
+
+
+def get_jit_fused_dropout_add_func():
+
+ from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train
+
+ def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
+ bias = torch.zeros_like(x)
+ if training:
+ return bias_dropout_add_fused_train(x, bias, residual, prob)
+ return bias_dropout_add_fused_inference(x, bias, residual, prob)
+
+ return self_dropout_add
+
+
+def get_jit_fused_gelu_forward_func():
+
+ from colossalai.kernel.jit.bias_gelu import bias_gelu
+
+ def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+ return bias_gelu(bias, x)
+
+ return bloom_gelu_forward
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index e1ed5f64665c..9d6335503b36 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -386,3 +386,67 @@ def llama_for_sequence_classification_forward(
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}
+
+
+def get_llama_flash_attention_forward():
+
+ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+
+ from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+
+ def forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+ assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
+ query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
+ key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
+ value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
+
+ flash_attention_mask = None
+ attn_mask_type = AttnMaskType.causal
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+ attn_mask_type = AttnMaskType.paddedcausal
+
+ attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+ attn_output = attention(query_states,
+ key_states,
+ value_states,
+ attn_mask=flash_attention_mask,
+ attn_mask_type=attn_mask_type)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+ return forward
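Several flash-attention forwards in this patch (GPT-2, LLaMA, OPT, ChatGLM) reduce the HuggingFace-style additive attention mask of shape `[bsz, 1, q_len, kv_len]` to a 2-D boolean mask for `ColoAttention` via `~(attention_mask[:, :, -1].squeeze(1).to(torch.bool))`. The standalone sketch below shows what that expression produces on toy tensors; no model is involved:

```python
import torch

bsz, q_len, kv_len = 2, 4, 4
neg_inf = torch.finfo(torch.float32).min

# Additive mask: 0.0 where attention is allowed, a large negative value where masked.
extended_mask = torch.zeros(bsz, 1, q_len, kv_len)
extended_mask[0, :, :, -1] = neg_inf  # mark the last key position of sample 0 as padding

# Same reduction as in the patch: take the last query row, squeeze, cast to bool, invert.
flash_attention_mask = ~(extended_mask[:, :, -1].squeeze(1).to(torch.bool))
print(flash_attention_mask)
# tensor([[ True,  True,  True, False],
#         [ True,  True,  True,  True]])  -> True marks kept (non-padded) key positions here
```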
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
new file mode 100644
index 000000000000..299dfb5562f3
--- /dev/null
+++ b/colossalai/shardformer/modeling/opt.py
@@ -0,0 +1,174 @@
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+def get_opt_flash_attention_forward():
+
+ from transformers.models.opt.modeling_opt import OPTAttention
+
+ from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+
+ def forward(
+ self: OPTAttention,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, _ = hidden_states.size()
+ assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
+
+ attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*attention_input_shape)
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k, v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
+ value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self.k_proj(key_value_states).view(*attention_input_shape)
+ value_states = self.v_proj(key_value_states).view(*attention_input_shape)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self.k_proj(hidden_states).view(*attention_input_shape)
+ value_states = self.v_proj(hidden_states).view(*attention_input_shape)
+ key_states = torch.cat([past_key_value[0], key_states], dim=1)
+ value_states = torch.cat([past_key_value[1], value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self.k_proj(hidden_states).view(*attention_input_shape)
+ value_states = self.v_proj(hidden_states).view(*attention_input_shape)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ src_len = key_states.size(1)
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}")
+
+ flash_attention_mask = None
+ attn_mask_type = AttnMaskType.causal
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+ attn_mask_type = AttnMaskType.paddedcausal
+
+ attention = ColoAttention(embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ scale=self.scaling)
+ attn_output = attention(query_states,
+ key_states,
+ value_states,
+ attn_mask=flash_attention_mask,
+ attn_mask_type=attn_mask_type)
+
+ attn_output = self.out_proj(attn_output)
+ return attn_output, None, past_key_value
+
+ return forward
+
+
+def get_jit_fused_opt_decoder_layer_forward():
+
+ from transformers.models.opt.modeling_opt import OPTDecoderLayer
+
+ def forward(
+ self: OPTDecoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ hidden_states_shape = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape)
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ return forward
diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py
index 63ebfe89d5fa..c40c02ec411a 100644
--- a/colossalai/shardformer/modeling/sam.py
+++ b/colossalai/shardformer/modeling/sam.py
@@ -1,4 +1,9 @@
+import math
+from typing import Tuple
+
import torch
+import torch.nn.functional as F
+from torch import Tensor
def forward_fn():
@@ -37,3 +42,162 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch
return outputs
return forward
+
+
+def get_sam_flash_attention_forward():
+
+ from transformers.models.sam.modeling_sam import SamAttention
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+ except ImportError:
+ raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
+
+ def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
+ batch, point_batch_size, n_tokens, channel = hidden_states.shape
+ c_per_head = channel // num_attention_heads
+ hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
+ return hidden_states
+
+ def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
+ batch, n_tokens, n_heads, c_per_head = hidden_states.shape
+ return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
+
+ def forward(self: SamAttention,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ attention_similarity: Tensor = None) -> Tensor:
+ # Input projections
+ query = self.q_proj(query)
+ key = self.k_proj(key)
+ value = self.v_proj(value)
+
+ point_batch_size = query.shape[1]
+ # Separate into heads
+ query = _separate_heads(query, self.num_attention_heads)
+ key = _separate_heads(key, self.num_attention_heads)
+ value = _separate_heads(value, self.num_attention_heads)
+
+ # SamAttention
+ _, _, _, c_per_head = query.shape
+ bias = None
+ if attention_similarity is not None:
+ bias = attention_similarity
+
+ scale = 1.0 / math.sqrt(c_per_head)
+ out = me_attention(query, key, value, attn_bias=bias, scale=scale)
+
+ out = _recombine_heads(out, point_batch_size)
+ out = self.out_proj(out)
+
+ return out
+
+ return forward
+
+
+def get_sam_vision_flash_attention_forward():
+
+ from transformers.models.sam.modeling_sam import SamVisionAttention
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+ except ImportError:
+ raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
+
+ def add_decomposed_rel_pos(
+ query: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+ ) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+ Args:
+ attn (`torch.Tensor`):
+ attention map.
+ query (`torch.Tensor`):
+ query q in the attention layer with shape (batch_size, query_height * query_width, channel).
+ rel_pos_h (`torch.Tensor`):
+ relative position embeddings (Lh, channel) for height axis.
+ rel_pos_w (`torch.Tensor`):
+ relative position embeddings (Lw, channel) for width axis.
+ q_size (tuple):
+ spatial sequence size of query q with (query_height, query_width).
+ k_size (tuple):
+ spatial sequence size of key k with (key_height, key_width).
+
+ Returns:
+ attn (`torch.Tensor`):
+ attention map with added relative positional embeddings.
+ """
+
+ query_height, query_width = q_size
+ key_height, key_width = k_size
+ relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
+ relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
+
+ batch_size, _, nHead, dim = query.shape
+ reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
+ rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+ rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
+ return rel_pos
+
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+
+ Args:
+ q_size (int):
+ size of the query.
+ k_size (int):
+ size of key k.
+ rel_pos (`torch.Tensor`):
+ relative position embeddings (L, channel).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+ def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+ batch_size, height, width, _ = hidden_states.shape
+ # qkv with shape (3, batch_size, nHead, height * width, channel)
+ qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
+ -1).permute(2, 0, 1, 3, 4))
+
+ query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
+
+ rel_pos = None
+ if self.use_rel_pos:
+ rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
+
+ attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
+
+ attn_output = attn_output.reshape(batch_size, height, width, -1)
+
+ attn_output = self.proj(attn_output)
+
+ outputs = (attn_output, None)
+
+ return outputs
+
+ return forward
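
All of the replacement forwards above funnel their attention computation into xformers' memory_efficient_attention, which expects query/key/value laid out as (batch, seq_len, num_heads, head_dim) and an optional additive attn_bias broadcastable to (batch, num_heads, q_len, kv_len). A minimal standalone sketch of that call, with arbitrary shapes and assuming a recent xformers release that supports the attn_bias/p/scale keywords used in this patch:

    import torch
    from xformers.ops import memory_efficient_attention

    batch, seq_len, num_heads, head_dim = 2, 128, 12, 64
    q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # additive bias, e.g. the decomposed relative position bias built above
    bias = torch.zeros(batch, num_heads, seq_len, seq_len, device="cuda", dtype=torch.float16)
    out = memory_efficient_attention(q, k, v, attn_bias=bias, p=0.0, scale=head_dim ** -0.5)
    print(out.shape)  # torch.Size([2, 128, 12, 64])
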
diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index 7eb4d17928d6..0b3486e87c7e 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -587,3 +587,209 @@ def t5_encoder_model_forward(
decoder_starting_stage=decoder_starting_stage)
return outputs
+
+
+def get_t5_flash_attention_forward():
+
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+    except ImportError:
+ raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
+ from transformers.models.t5.modeling_t5 import T5Attention
+
+ def forward(
+ self: T5Attention,
+ hidden_states: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ key_value_states: Optional[torch.Tensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ query_length: Optional[int] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+    # past_key_value[0] is (batch_size, q_len - 1, n_heads, dim_per_head)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ real_seq_length = seq_length
+
+ if past_key_value is not None:
+ if len(past_key_value) != 2:
+ raise ValueError(
+ f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+ )
+            real_seq_length += past_key_value[0].shape[1] if query_length is None else query_length
+
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
+ def shape(states):
+ """projection"""
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+
+ def unshape(states):
+ """reshape"""
+ return states.view(batch_size, -1, self.inner_dim)
+
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
+ """projects hidden states correctly to key/query states"""
+ if key_value_states is None:
+ # self-attn
+                # (batch_size, seq_length, n_heads, dim_per_head)
+ hidden_states = shape(proj_layer(hidden_states))
+ elif past_key_value is None:
+ # cross-attn
+                # (batch_size, seq_length, n_heads, dim_per_head)
+ hidden_states = shape(proj_layer(key_value_states))
+
+ if past_key_value is not None:
+ if key_value_states is None:
+ # self-attn
+                    # (batch_size, key_length, n_heads, dim_per_head)
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
+ elif past_key_value.shape[1] != key_value_states.shape[1]:
+ # checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ # cross-attn
+                    # (batch_size, seq_length, n_heads, dim_per_head)
+ hidden_states = shape(proj_layer(key_value_states))
+ else:
+ # cross-attn
+ hidden_states = past_key_value
+ return hidden_states
+
+ # get query states
+        query_states = shape(self.q(hidden_states))    # (batch_size, seq_length, n_heads, dim_per_head)
+
+ # get key/value states
+ key_states = project(hidden_states, self.k, key_value_states,
+ past_key_value[0] if past_key_value is not None else None)
+ value_states = project(hidden_states, self.v, key_value_states,
+ past_key_value[1] if past_key_value is not None else None)
+
+ if position_bias is None:
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length),
+ device=query_states.device,
+ dtype=query_states.dtype)
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device)
+
+ # if key and values are already calculated
+ # we want only the last query position bias
+ if past_key_value is not None:
+ position_bias = position_bias[:, :, -hidden_states.size(1):, :]
+
+ if mask is not None:
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
+
+ if self.pruned_heads:
+ mask = torch.ones(position_bias.shape[1])
+ mask[list(self.pruned_heads)] = 0
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+
+ position_bias_masked = position_bias_masked.contiguous()
+ attn_output = me_attention(query_states,
+ key_states,
+ value_states,
+ attn_bias=position_bias_masked,
+ p=self.dropout,
+ scale=1.0)
+ attn_output = unshape(attn_output)
+ attn_output = self.o(attn_output)
+
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_T5_layer_ff_forward():
+
+ from transformers.models.t5.modeling_t5 import T5LayerFF
+
+ def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor:
+ forwarded_states = self.layer_norm(hidden_states)
+ forwarded_states = self.DenseReluDense(forwarded_states)
+ hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training)
+ return hidden_states
+
+ return forward
+
+
+def get_T5_layer_self_attention_forward():
+
+ from transformers.models.t5.modeling_t5 import T5LayerSelfAttention
+
+ def forward(
+ self: T5LayerSelfAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.SelfAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+ return forward
+
+
+def get_T5_layer_cross_attention_forward():
+
+ from transformers.models.t5.modeling_t5 import T5LayerCrossAttention
+
+ def forward(
+ self: T5LayerCrossAttention,
+ hidden_states: torch.Tensor,
+ key_value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: bool = False,
+ query_length: Optional[int] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.EncDecAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ query_length=query_length,
+ output_attentions=output_attentions,
+ )
+ layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+ return forward
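
The get_*_forward factories in this file (like the SAM, ViT and Whisper ones elsewhere in this patch) return plain functions whose first parameter plays the role of self; the policies below register them via method_replacement so the sharder attaches them to the corresponding modules. Outside of the sharder, an equivalent manual binding could look roughly like this hypothetical snippet (the T5Config values are arbitrary, and xformers must be installed for the replaced forward to run):

    from types import MethodType

    from transformers import T5Config
    from transformers.models.t5.modeling_t5 import T5Attention

    from colossalai.shardformer.modeling.t5 import get_t5_flash_attention_forward

    config = T5Config(d_model=64, d_kv=16, num_heads=4, num_layers=1)
    attn = T5Attention(config, has_relative_attention_bias=True)
    # replace the forward on this instance only
    attn.forward = MethodType(get_t5_flash_attention_forward(), attn)
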
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index f28c13ad0aa2..22c4dd998cac 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -1,4 +1,5 @@
import logging
+import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
@@ -335,3 +336,51 @@ def pp_forward(
)
return pp_forward
+
+
+def get_vit_flash_self_attention_forward():
+
+ from transformers.models.vit.modeling_vit import ViTSelfAttention
+
+ from colossalai.kernel.cuda_native.flash_attention import ColoAttention
+
+    def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
+        # split the hidden dimension into (num_heads, head_size); unlike the original HF helper,
+        # no transpose is applied because ColoAttention consumes (batch, seq_len, num_heads, head_size)
+        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
+        x = x.view(new_x_shape)
+        return x
+
+ def forward(self: ViTSelfAttention,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
+ value_layer = transpose_for_scores(self.value(hidden_states), self.num_attention_heads,
+ self.attention_head_size)
+ query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
+
+ scale = 1.0 / math.sqrt(self.attention_head_size)
+ attention = ColoAttention(embed_dim=self.all_head_size,
+ num_heads=self.num_attention_heads,
+ dropout=self.dropout.p,
+ scale=scale)
+ context_layer = attention(query_layer, key_layer, value_layer)
+
+ outputs = (context_layer,)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_vit_output_forward():
+
+ from transformers.models.vit.modeling_vit import ViTOutput
+
+ def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ return hidden_states
+
+ return forward
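
The ViT forward above delegates to ColoAttention from colossalai.kernel.cuda_native.flash_attention, constructing it with the module's head geometry and calling it on (batch, seq_len, num_heads, head_dim) tensors. A sketch of that call pattern in isolation follows; the tensor layout and the merged-head output are inferred from the forwards in this patch rather than from ColoAttention's documentation, and a CUDA device with the flash-attention kernels is assumed:

    import math

    import torch

    from colossalai.kernel.cuda_native.flash_attention import ColoAttention

    batch, seq_len, num_heads, head_dim = 2, 128, 12, 64
    q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    attention = ColoAttention(embed_dim=num_heads * head_dim,
                              num_heads=num_heads,
                              dropout=0.0,
                              scale=1.0 / math.sqrt(head_dim))
    # the result feeds an output projection directly in the forwards above,
    # i.e. heads are merged back into the embedding dimension
    context = attention(q, k, v)
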
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
new file mode 100644
index 000000000000..6bc387ac8974
--- /dev/null
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -0,0 +1,249 @@
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+def get_whisper_flash_attention_forward():
+
+ from transformers.models.whisper.modeling_whisper import WhisperAttention
+
+ from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+
+ def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
+ return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
+
+ def forward(
+ self: WhisperAttention,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get key, value proj
+        # `past_key_value[0].shape[1] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (is_cross_attention and past_key_value is not None
+ and past_key_value[0].shape[1] == key_value_states.shape[1]):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
+ value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+ value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+ key_states = torch.cat([past_key_value[0], key_states], dim=1)
+ value_states = torch.cat([past_key_value[1], value_states], dim=1)
+ else:
+ # self_attention
+ key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+ value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ # get query proj
+ query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
+
+ src_len = key_states.size(1)
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}")
+
+ attn_type = None
+ flash_attention_mask = None
+
+ if self.is_decoder:
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
+ attn_type = AttnMaskType.paddedcausal
+
+ attention = ColoAttention(embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ scale=self.scaling)
+ attn_output = attention(query_states,
+ key_states,
+ value_states,
+ attn_mask=flash_attention_mask,
+ attn_mask_type=attn_type)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+ return forward
+
+
+def get_jit_fused_whisper_encoder_layer_forward():
+
+ from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
+
+ def forward(
+ self: WhisperEncoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_head_mask: torch.Tensor,
+ output_attentions: bool = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any()
+ or torch.isnan(hidden_states).any()):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_whisper_decoder_layer_forward():
+
+ from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer
+
+ def forward(
+ self: WhisperDecoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ return forward
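
Every jit-fused forward in this file, as well as those in the T5 and ViT files above, relies on a patched-in dropout_add method called as dropout_add(x, residual, prob, training); get_jit_fused_dropout_add_func presumably supplies a torch.jit.script'd version of it. A plain reference equivalent, under a hypothetical name and for illustration only:

    import torch
    import torch.nn.functional as F

    def dropout_add_reference(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
        # dropout on x followed by the residual addition, matching the
        # self.dropout_add(x, residual, p, training) call sites above
        return F.dropout(x, p=prob, training=training) + residual
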
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 6f86de232fad..ace9ada3904f 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -7,7 +7,14 @@
import colossalai.shardformer.layer as col_nn
-from ..modeling.bert import BertPipelineForwards
+from .._utils import getattr_, setattr_
+from ..modeling.bert import (
+ BertPipelineForwards,
+ get_bert_flash_attention_forward,
+ get_jit_fused_bert_output_forward,
+ get_jit_fused_bert_self_output_forward,
+)
+from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -37,7 +44,13 @@ def preprocess(self):
return self.model
def module_policy(self):
- from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
+ from transformers.models.bert.modeling_bert import (
+ BertEmbeddings,
+ BertLayer,
+ BertOutput,
+ BertSelfAttention,
+ BertSelfOutput,
+ )
policy = {}
@@ -126,6 +139,23 @@ def module_policy(self):
policy=policy,
target_key=BertEmbeddings)
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_bert_flash_attention_forward(),
+ })
+
+ # use jit operator
+ if self.shard_config.enable_jit_fused:
+ policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_bert_self_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+ policy[BertOutput] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_bert_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+
return policy
def add_lm_head_policy(self, base_policy):
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index a244d70b56f5..50356302e93e 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -3,7 +3,13 @@
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
-from ..modeling.blip2 import forward_fn
+from ..modeling.blip2 import (
+ forward_fn,
+ get_blip2_flash_attention_forward,
+ get_jit_fused_blip2_QFormer_output_forward,
+ get_jit_fused_blip2_QFormer_self_output_forward,
+)
+from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['BlipPolicy', 'BlipModelPolicy']
@@ -33,6 +39,8 @@ def module_policy(self):
Blip2EncoderLayer,
Blip2QFormerLayer,
Blip2QFormerModel,
+ Blip2QFormerOutput,
+ Blip2QFormerSelfOutput,
Blip2VisionModel,
)
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM
@@ -275,6 +283,24 @@ def module_policy(self):
policy=policy,
target_key=OPTDecoderLayer)
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_blip2_flash_attention_forward(),
+ })
+
+ # use jit operator
+ if self.shard_config.enable_jit_fused:
+ policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
+ method_replacement={
+ 'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+ policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_blip2_QFormer_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+
return policy
def postprocess(self):
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 15bae2f4a959..b35764db3870 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -7,7 +7,16 @@
import colossalai.shardformer.layer as col_nn
-from ..modeling.bloom import BloomPipelineForwards, build_bloom_alibi_tensor_fn
+from .._utils import getattr_, setattr_
+from ..modeling.bloom import (
+ BloomPipelineForwards,
+ build_bloom_alibi_tensor_fn,
+ get_bloom_flash_attention_forward,
+ get_jit_fused_bloom_attention_forward,
+ get_jit_fused_bloom_gelu_forward,
+ get_jit_fused_bloom_mlp_forward,
+)
+from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -30,7 +39,7 @@ def preprocess(self):
return self.model
def module_policy(self):
- from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
+ from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel
policy = {}
@@ -107,6 +116,27 @@ def module_policy(self):
policy=policy,
target_key=BloomBlock)
+ if self.shard_config.enable_flash_attention:
+ policy[BloomAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_bloom_flash_attention_forward(),
+ 'dropout_add': get_dropout_add_func()
+ })
+
+ # enable jit fused operator
+ if self.shard_config.enable_jit_fused:
+ policy[BloomAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_bloom_attention_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+ policy[BloomMLP] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_bloom_mlp_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+ policy[BloomGelu] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_bloom_gelu_forward(),
+ 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
+ })
+
return policy
def postprocess(self):
diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm.py
index 9cc651caddc1..e6b458936637 100644
--- a/colossalai/shardformer/policies/chatglm.py
+++ b/colossalai/shardformer/policies/chatglm.py
@@ -15,6 +15,8 @@
GLMBlock,
)
+from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
+from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
@@ -35,12 +37,11 @@ def preprocess(self):
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
-
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
+ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
policy = {}
@@ -121,6 +122,19 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy=policy,
target_key=ChatGLMModel)
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ policy[CoreAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_flash_core_attention_forward(),
+ })
+
+ # use jit fused operator
+ if self.shard_config.enable_jit_fused:
+ policy[GLMBlock] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_glm_block_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+
return policy
def postprocess(self):
@@ -192,7 +206,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
return []
-
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
def module_policy(self):
@@ -213,4 +226,3 @@ def get_held_layers(self) -> List[nn.Module]:
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in ChatGLMForConditionalGenerationModel."""
return []
-
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 6d734b063036..20e5fa372c8f 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -5,7 +5,8 @@
import colossalai.shardformer.layer as col_nn
-from ..modeling.gpt2 import GPT2PipelineForwards
+from .._utils import getattr_, setattr_
+from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -33,7 +34,7 @@ def preprocess(self):
return self.model
def module_policy(self):
- from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {}
@@ -53,42 +54,42 @@ def module_policy(self):
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="attn.c_attn",
- target_module=col_nn.GPT2FusedLinearConv1D_Col,
- kwargs={
- "n_fused": 3,
- },
- ),
- SubModuleReplacementDescription(
- suffix="attn.c_proj",
- target_module=col_nn.GPT2FusedLinearConv1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="mlp.c_fc",
- target_module=col_nn.GPT2FusedLinearConv1D_Col,
- kwargs={
- "n_fused": 1,
- },
- ),
- SubModuleReplacementDescription(
- suffix="mlp.c_proj",
- target_module=col_nn.GPT2FusedLinearConv1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="attn.attn_dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="attn.resid_dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="mlp.dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- ])
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn.c_attn",
+ target_module=col_nn.GPT2FusedLinearConv1D_Col,
+ kwargs={
+ "n_fused": 3,
+ },
+ ),
+ SubModuleReplacementDescription(
+ suffix="attn.c_proj",
+ target_module=col_nn.GPT2FusedLinearConv1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.c_fc",
+ target_module=col_nn.GPT2FusedLinearConv1D_Col,
+ kwargs={
+ "n_fused": 1,
+ },
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.c_proj",
+ target_module=col_nn.GPT2FusedLinearConv1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attn.attn_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attn.resid_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ ])
# optimization configuration
if self.shard_config.enable_fused_normalization:
@@ -96,8 +97,8 @@ def module_policy(self):
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
- policy=policy,
- target_key=GPT2Model)
+ policy=policy,
+ target_key=GPT2Model)
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
@@ -112,8 +113,13 @@ def module_policy(self):
target_module=col_nn.FusedLayerNorm,
ignore_if_not_exist=True)
],
- policy=policy,
- target_key=GPT2Block)
+ policy=policy,
+ target_key=GPT2Block)
+
+ if self.shard_config.enable_flash_attention:
+ policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_gpt2_flash_attention_forward(),
+ })
return policy
def postprocess(self):
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 5988366ed57b..5ee95f3be8fa 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -7,7 +7,7 @@
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
-from ..modeling.llama import LlamaPipelineForwards
+from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@@ -31,7 +31,7 @@ def preprocess(self):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
policy = {}
@@ -104,6 +104,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy=policy,
target_key=LlamaModel)
+ if self.shard_config.enable_flash_attention:
+ policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_llama_flash_attention_forward(),
+ })
+
return policy
def postprocess(self):
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 6fc3a2d31f4d..88ecd8565091 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -25,6 +25,8 @@
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_, setattr_
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -114,6 +116,19 @@ def module_policy(self):
policy=policy,
target_key=OPTDecoderLayer)
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ policy[OPTAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_opt_flash_attention_forward(),
+ })
+
+ # use jit fused operator
+ if self.shard_config.enable_jit_fused:
+ policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_opt_decoder_layer_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+
return policy
def postprocess(self):
@@ -189,13 +204,11 @@ def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForCausalLM
policy = super().module_policy()
-
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=OPTForCausalLM)
-
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=OPTForCausalLM,
new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index ca20fff715f2..b1eba0432b49 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -3,7 +3,7 @@
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
-from ..modeling.sam import forward_fn
+from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['SamPolicy', 'SamModelPolicy']
@@ -19,6 +19,7 @@ def preprocess(self):
def module_policy(self):
from transformers.models.sam.modeling_sam import (
+ SamAttention,
SamFeedForward,
SamTwoWayAttentionBlock,
SamTwoWayTransformer,
@@ -196,6 +197,15 @@ def module_policy(self):
policy=policy,
target_key=SamTwoWayTransformer)
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ policy[SamAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_sam_flash_attention_forward(),
+ })
+ policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_sam_vision_flash_attention_forward(),
+ })
+
return policy
def postprocess(self):
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 0ee18d6c4940..5e78ae9093fa 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -14,7 +14,14 @@
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
from .._utils import getattr_, setattr_
-from ..modeling.t5 import T5PipelineForwards
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.t5 import (
+ T5PipelineForwards,
+ get_jit_fused_T5_layer_ff_forward,
+ get_t5_flash_attention_forward,
+ get_T5_layer_cross_attention_forward,
+ get_T5_layer_self_attention_forward,
+)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
@@ -168,6 +175,27 @@ def module_policy(self):
suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5Stack)
+
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ policy[T5Attention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_t5_flash_attention_forward(),
+ })
+
+ # use jit operator
+ if self.shard_config.enable_jit_fused:
+ policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_T5_layer_ff_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+ policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_T5_layer_self_attention_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+ policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_T5_layer_cross_attention_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
return policy
def postprocess(self):
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 47f2c58fc436..07b1a9a2e7c7 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -3,12 +3,21 @@
import torch.nn as nn
-import colossalai.shardformer.layer as col_nn
+from colossalai.shardformer.layer import (
+ DropoutForParallelInput,
+ DropoutForReplicatedInput,
+ FusedLayerNorm,
+ Linear1D_Col,
+ Linear1D_Row,
+)
+from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.vit import (
ViTForImageClassification_pipeline_forward,
ViTForMaskedImageModeling_pipeline_forward,
ViTModel_pipeline_forward,
+ get_jit_fused_vit_output_forward,
+ get_vit_flash_self_attention_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -24,7 +33,8 @@ def preprocess(self):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
+
+ from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention
policy = {}
@@ -34,7 +44,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
- target_module=col_nn.DropoutForReplicatedInput,
+ target_module=DropoutForReplicatedInput,
)
])
@@ -48,42 +58,54 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
- target_module=col_nn.Linear1D_Col,
+ target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
- target_module=col_nn.Linear1D_Col,
+ target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
- target_module=col_nn.Linear1D_Col,
+ target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
- target_module=col_nn.DropoutForParallelInput,
+ target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
- target_module=col_nn.Linear1D_Row,
+ target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
- target_module=col_nn.DropoutForReplicatedInput,
+ target_module=DropoutForReplicatedInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
- target_module=col_nn.Linear1D_Col,
+ target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
- target_module=col_nn.Linear1D_Row,
+ target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
- target_module=col_nn.DropoutForReplicatedInput,
+ target_module=DropoutForReplicatedInput,
),
])
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_vit_flash_self_attention_forward(),
+ })
+
+ # use jit fused operator
+ if self.shard_config.enable_jit_fused:
+ policy[ViTOutput] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_vit_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
return policy
def new_model_class(self):
@@ -166,7 +188,7 @@ def module_policy(self):
ViTForImageClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
- suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
+ suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 2f3565bdaa96..2ac7a49fd27b 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -3,6 +3,12 @@
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.whisper import (
+ get_jit_fused_whisper_decoder_layer_forward,
+ get_jit_fused_whisper_encoder_layer_forward,
+ get_whisper_flash_attention_forward,
+)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -30,6 +36,7 @@ def preprocess(self):
def module_policy(self):
from transformers.models.whisper.modeling_whisper import (
+ WhisperAttention,
WhisperDecoder,
WhisperDecoderLayer,
WhisperEncoder,
@@ -181,6 +188,24 @@ def module_policy(self):
],
policy=policy,
target_key=WhisperDecoder)
+
+ # enable flash attention
+ if self.shard_config.enable_flash_attention:
+ policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
+ 'forward': get_whisper_flash_attention_forward(),
+ })
+
+ # use jit fused operator
+ if self.shard_config.enable_jit_fused:
+ policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_whisper_encoder_layer_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+ policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
+ 'forward': get_jit_fused_whisper_decoder_layer_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ })
+
return policy
def add_lm_head_policy(self, base_policy):
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 75fad4eb7431..ec6e0cd0d4be 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -26,6 +26,8 @@ class ShardConfig:
enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
+ enable_flash_attention: bool = False
+ enable_jit_fused: bool = False
# TODO: add support for tensor parallel
# pipeline_parallel_size: int
@@ -44,7 +46,6 @@ def __post_init__(self):
else:
# get the parallel size
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
-
# turn on all optimization if all_optimization is set to True
if self.enable_all_optimization:
self._turn_on_all_optimization()
@@ -55,3 +56,5 @@ def _turn_on_all_optimization(self):
"""
# you can add all the optimization flag here
self.enable_fused_normalization = True
+ self.enable_flash_attention = True
+ self.enable_jit_fused = True
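
With the two new flags, flash attention and the jit-fused operators are opt-in per ShardConfig, and enable_all_optimization now switches them on together with fused normalization. A small configuration sketch, assuming a ShardConfig can be constructed without an initialized process group when tensor parallelism is disabled:

    from colossalai.shardformer.shard.shard_config import ShardConfig

    # enable only the new kernels
    config = ShardConfig(enable_tensor_parallelism=False,
                         enable_flash_attention=True,
                         enable_jit_fused=True)

    # or flip every optimization on at once; __post_init__ then also sets
    # enable_fused_normalization, enable_flash_attention and enable_jit_fused to True
    config_all = ShardConfig(enable_tensor_parallelism=False,
                             enable_all_optimization=True)
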
diff --git a/pytest.ini b/pytest.ini
index 01e5cd217c5d..e8a60c85336b 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,3 +4,4 @@ markers =
gpu: tests which requires a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features
+addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 6f8a72e3962f..fa797f26a4ca 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -13,7 +13,9 @@ torchrec==0.2.0
contexttimer
einops
triton==2.0.0.dev20221202
-git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
+# git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece
datasets
+ninja
+flash-attn
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index b34dc2e223ae..3ee1567db7fa 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,3 +10,4 @@ contexttimer
ninja
torch>=1.11
safetensors
+flash-attn
diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py
index d17b8fda425a..9834f5425027 100644
--- a/tests/kit/model_zoo/transformers/bert.py
+++ b/tests/kit/model_zoo/transformers/bert.py
@@ -20,7 +20,7 @@ def data_gen():
# token_type_ids = tokenized_input['token_type_ids']
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
- attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
@@ -69,19 +69,21 @@ def data_gen_for_mcq():
# data['labels'] = torch.tensor([0], dtype=torch.int64)
input_ids = torch.tensor([[[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
- 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
+ 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
],
[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
- 2218, 1999, 1996, 2192, 1012, 102, 0
+ 2218, 1999, 1996, 2192, 1012, 102, 0, 0
]]])
token_type_ids = torch.tensor(
- [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+ [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
+ 0]]])
attention_mask = torch.tensor(
- [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+ [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
+ 0]]])
labels = torch.tensor([0], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py
index 7338f740be7f..984a6ffa920d 100644
--- a/tests/kit/model_zoo/transformers/blip2.py
+++ b/tests/kit/model_zoo/transformers/blip2.py
@@ -38,6 +38,7 @@ def data_gen():
loss_fn_blip2_model = lambda x: x.loss
config = transformers.Blip2Config()
+config.vision_config.patch_size = 14
config.text_config.num_hidden_layers = 1
config.qformer_config.num_hidden_layers = 1
config.vision_config.num_hidden_layers = 1
diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py
index 5d195db2c68d..177edbef8935 100644
--- a/tests/kit/model_zoo/transformers/bloom.py
+++ b/tests/kit/model_zoo/transformers/bloom.py
@@ -16,8 +16,8 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
- input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64)
- attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -33,7 +33,7 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
- data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+ data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
@@ -53,8 +53,8 @@ def data_gen_for_question_answering():
# inputs = tokenizer(question, text, return_tensors="pt")
input_ids = torch.tensor(
- [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
- attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
start_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([10], dtype=torch.int64)
return dict(input_ids=input_ids,
diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py
index 056c910a8dfe..90bb70bc7f79 100644
--- a/tests/kit/model_zoo/transformers/chatglm.py
+++ b/tests/kit/model_zoo/transformers/chatglm.py
@@ -6,7 +6,6 @@
from ..registry import ModelAttribute, model_zoo
-
# ================================
# Register single-sentence ChatGLM
# ================================
diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py
deleted file mode 100644
index 3e78732be2da..000000000000
--- a/tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from transformers import PretrainedConfig
-
-
-class ChatGLMConfig(PretrainedConfig):
- model_type = "chatglm"
-
- def __init__(self,
- num_layers=28,
- padded_vocab_size=65024,
- hidden_size=4096,
- ffn_hidden_size=13696,
- kv_channels=128,
- num_attention_heads=32,
- seq_length=2048,
- hidden_dropout=0.0,
- attention_dropout=0.0,
- layernorm_epsilon=1e-5,
- rmsnorm=True,
- apply_residual_connection_post_layernorm=False,
- post_layer_norm=True,
- add_bias_linear=False,
- add_qkv_bias=False,
- bias_dropout_fusion=True,
- multi_query_attention=False,
- multi_query_group_num=1,
- apply_query_key_layer_scaling=True,
- attention_softmax_in_fp32=True,
- fp32_residual_connection=False,
- quantization_bit=0,
- pre_seq_len=None,
- prefix_projection=False,
- **kwargs):
- self.num_layers = num_layers
- self.vocab_size = padded_vocab_size
- self.padded_vocab_size = padded_vocab_size
- self.hidden_size = hidden_size
- self.ffn_hidden_size = ffn_hidden_size
- self.kv_channels = kv_channels
- self.num_attention_heads = num_attention_heads
- self.seq_length = seq_length
- self.hidden_dropout = hidden_dropout
- self.attention_dropout = attention_dropout
- self.layernorm_epsilon = layernorm_epsilon
- self.rmsnorm = rmsnorm
- self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
- self.post_layer_norm = post_layer_norm
- self.add_bias_linear = add_bias_linear
- self.add_qkv_bias = add_qkv_bias
- self.bias_dropout_fusion = bias_dropout_fusion
- self.multi_query_attention = multi_query_attention
- self.multi_query_group_num = multi_query_group_num
- self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
- self.attention_softmax_in_fp32 = attention_softmax_in_fp32
- self.fp32_residual_connection = fp32_residual_connection
- self.quantization_bit = quantization_bit
- self.pre_seq_len = pre_seq_len
- self.prefix_projection = prefix_projection
- super().__init__(**kwargs)
diff --git a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py b/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py
deleted file mode 100644
index bae6d425878d..000000000000
--- a/tests/kit/model_zoo/transformers/chatglm2_6b/modeling_chatglm.py
+++ /dev/null
@@ -1,1372 +0,0 @@
-"""
-The ChatGLM2-6B License
-
-1. Definitions
-
-“Licensor” means the ChatGLM2-6B Model Team that distributes its Software.
-
-“Software” means the ChatGLM2-6B model parameters made available under this license.
-
-2. License Grant
-
-Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
-
-The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
-
-3. Restriction
-
-You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
-
-You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
-
-4. Disclaimer
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
-5. Limitation of Liability
-
-EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
-
-6. Dispute Resolution
-
-This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
-
-Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
-"""
-""" PyTorch ChatGLM model. """
-
-import copy
-import math
-import re
-import sys
-import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import CrossEntropyLoss, LayerNorm
-from torch.nn.utils import skip_init
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
-
-from .configuration_chatglm import ChatGLMConfig
-
-# flags required to enable jit fusion kernels
-
-if sys.platform != "darwin":
- torch._C._jit_set_profiling_mode(False)
- torch._C._jit_set_profiling_executor(False)
- torch._C._jit_override_can_fuse_on_cpu(True)
- torch._C._jit_override_can_fuse_on_gpu(True)
-
-logger = logging.get_logger(__name__)
-
-_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
-_CONFIG_FOR_DOC = "ChatGLM6BConfig"
-
-CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
- "THUDM/chatglm2-6b",
- # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
-]
-
-
-def default_init(cls, *args, **kwargs):
- return cls(*args, **kwargs)
-
-
-class InvalidScoreLogitsProcessor(LogitsProcessor):
-
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
- if torch.isnan(scores).any() or torch.isinf(scores).any():
- scores.zero_()
- scores[..., 5] = 5e4
- return scores
-
-
-class PrefixEncoder(torch.nn.Module):
- """
- The torch.nn model to encode the prefix
- Input shape: (batch-size, prefix-length)
- Output shape: (batch-size, prefix-length, 2*layers*hidden)
- """
-
- def __init__(self, config: ChatGLMConfig):
- super().__init__()
- self.prefix_projection = config.prefix_projection
- if self.prefix_projection:
- # Use a two-layer MLP to encode the prefix
- kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
- self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
- self.trans = torch.nn.Sequential(
- torch.nn.Linear(kv_size, config.hidden_size),
- torch.nn.Tanh(),
- torch.nn.Linear(config.hidden_size, kv_size),
- )
- else:
- self.embedding = torch.nn.Embedding(
- config.pre_seq_len,
- config.num_layers * config.kv_channels * config.multi_query_group_num * 2,
- )
-
- def forward(self, prefix: torch.Tensor):
- if self.prefix_projection:
- prefix_tokens = self.embedding(prefix)
- past_key_values = self.trans(prefix_tokens)
- else:
- past_key_values = self.embedding(prefix)
- return past_key_values
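
# A minimal shape sketch of the prefix encoder above, assuming toy sizes and a
# SimpleNamespace standing in for ChatGLMConfig (duck-typed, illustrative only):
# (batch, prefix_len) ids map to (batch, prefix_len, 2 * layers * kv_channels * groups).
from types import SimpleNamespace

_toy_cfg = SimpleNamespace(prefix_projection=False, pre_seq_len=4,
                           num_layers=2, kv_channels=8, multi_query_group_num=2)
_toy_encoder = PrefixEncoder(_toy_cfg)
_toy_past = _toy_encoder(torch.arange(4).unsqueeze(0))    # (batch=1, prefix_len=4) token ids
assert _toy_past.shape == (1, 4, 2 * 2 * 8 * 2)           # 2 * num_layers * kv_channels * groups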
-
-
-def split_tensor_along_last_dim(
- tensor: torch.Tensor,
- num_partitions: int,
- contiguous_split_chunks: bool = False,
-) -> List[torch.Tensor]:
- """Split a tensor along its last dimension.
-
- Arguments:
- tensor: input tensor.
- num_partitions: number of partitions to split the tensor
- contiguous_split_chunks: If True, make each chunk contiguous
- in memory.
-
- Returns:
-        A list of tensors (a tuple when contiguous_split_chunks is True).
- """
- # Get the size and dimension.
- last_dim = tensor.dim() - 1
- last_dim_size = tensor.size()[last_dim] // num_partitions
- # Split.
- tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
- # Note: torch.split does not create contiguous tensors by default.
- if contiguous_split_chunks:
- return tuple(chunk.contiguous() for chunk in tensor_list)
-
- return tensor_list
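
# A small usage sketch of the helper above, with assumed toy shapes: a fused
# [sq, b, np, 3 * hn] projection is split into equal query/key/value chunks.
_mixed = torch.randn(4, 2, 8, 3 * 16)                     # [sq, b, np, 3 * hn]
_q, _k, _v = split_tensor_along_last_dim(_mixed, 3)
assert _q.shape == (4, 2, 8, 16)                          # each chunk is [sq, b, np, hn]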
-
-
-class RotaryEmbedding(nn.Module):
-
- def __init__(self, dim, original_impl=False, device=None, dtype=None):
- super().__init__()
- inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
- self.register_buffer("inv_freq", inv_freq)
- self.dim = dim
- self.original_impl = original_impl
-
- def forward_impl(
- self,
- seq_len: int,
- n_elem: int,
- dtype: torch.dtype,
- device: torch.device,
- base: int = 10000,
- ):
- """Enhanced Transformer with Rotary Position Embedding.
-
- Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
- transformers/rope/__init__.py. MIT License:
- https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
- """
- # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
- theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
-
- # Create position indexes `[0, 1, ..., seq_len - 1]`
- seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
-
- # Calculate the product of position index and $\theta_i$
- idx_theta = torch.outer(seq_idx, theta).float()
-
- cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
-
- # this is to mimic the behaviour of complex32, else we will get different results
- if dtype in (torch.float16, torch.bfloat16, torch.int8):
- cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
- return cache
-
- def forward(self, max_seq_len, offset=0):
- return self.forward_impl(
- max_seq_len,
- self.dim,
- dtype=self.inv_freq.dtype,
- device=self.inv_freq.device,
- )
-
-
-@torch.jit.script
-def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
- # x: [sq, b, np, hn]
- sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
- rot_dim = rope_cache.shape[-2] * 2
- x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
- # truncate to support variable sizes
- rope_cache = rope_cache[:sq]
- xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
- rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
- x_out2 = torch.stack(
- [
- xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
- xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
- ],
- -1,
- )
- x_out2 = x_out2.flatten(3)
- return torch.cat((x_out2, x_pass), dim=-1)
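
# A quick sketch of how the rope cache and apply_rotary_pos_emb fit together,
# with assumed toy sizes: RotaryEmbedding(dim) yields a [seq_len, dim // 2, 2]
# (cos, sin) cache, and only the first `dim` channels of each head are rotated.
_rotary = RotaryEmbedding(8, original_impl=True, dtype=torch.float32)
_cache = _rotary(4)                                       # [4, 4, 2] cos/sin pairs
_x = torch.randn(4, 1, 2, 16)                             # [sq, b, np, hn]
_y = apply_rotary_pos_emb(_x, _cache)
assert _y.shape == _x.shape
assert torch.equal(_y[..., 8:], _x[..., 8:])              # channels past rot_dim pass through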
-
-
-class RMSNorm(torch.nn.Module):
-
- def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
- super().__init__()
- self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
- self.eps = eps
-
- def forward(self, hidden_states: torch.Tensor):
- input_dtype = hidden_states.dtype
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-
- return (self.weight * hidden_states).to(input_dtype)
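
# A numeric check of the RMSNorm definition above, with an assumed toy width:
# y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps), computed in fp32 and cast back.
_norm = RMSNorm(16, eps=1e-5, dtype=torch.float32)
with torch.no_grad():
    _norm.weight.fill_(1.0)                               # weight starts uninitialized (torch.empty)
    _h = torch.randn(2, 3, 16)
    _ref = _h / torch.sqrt(_h.pow(2).mean(-1, keepdim=True) + 1e-5)
    assert torch.allclose(_norm(_h), _ref, atol=1e-6)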
-
-
-class CoreAttention(torch.nn.Module):
-
- def __init__(self, config: ChatGLMConfig, layer_number):
- super(CoreAttention, self).__init__()
-
- self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
- self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
- if self.apply_query_key_layer_scaling:
- self.attention_softmax_in_fp32 = True
- self.layer_number = max(1, layer_number)
-
- projection_size = config.kv_channels * config.num_attention_heads
-
- # Per attention head and per partition values.
- self.hidden_size_per_partition = projection_size
- self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads)
- self.num_attention_heads_per_partition = config.num_attention_heads
-
- coeff = None
- self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
- if self.apply_query_key_layer_scaling:
- coeff = self.layer_number
- self.norm_factor *= coeff
- self.coeff = coeff
-
- self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
-
- def forward(self, query_layer, key_layer, value_layer, attention_mask):
- pytorch_major_version = int(torch.__version__.split(".")[0])
- if pytorch_major_version >= 2:
- query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
- if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
- key_layer,
- value_layer,
- is_causal=True)
- else:
- if attention_mask is not None:
- attention_mask = ~attention_mask
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
- attention_mask)
- context_layer = context_layer.permute(2, 0, 1, 3)
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
- context_layer = context_layer.reshape(*new_context_layer_shape)
- else:
- # Raw attention scores
-
- # [b, np, sq, sk]
- output_size = (
- query_layer.size(1),
- query_layer.size(2),
- query_layer.size(0),
- key_layer.size(0),
- )
-
- # [sq, b, np, hn] -> [sq, b * np, hn]
- query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
- # [sk, b, np, hn] -> [sk, b * np, hn]
- key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
-
-            # preallocating input tensor: [b * np, sq, sk]
- matmul_input_buffer = torch.empty(
- output_size[0] * output_size[1],
- output_size[2],
- output_size[3],
- dtype=query_layer.dtype,
- device=query_layer.device,
- )
-
- # Raw attention scores. [b * np, sq, sk]
- matmul_result = torch.baddbmm(
- matmul_input_buffer,
- query_layer.transpose(0, 1), # [b * np, sq, hn]
- key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
- beta=0.0,
- alpha=(1.0 / self.norm_factor),
- )
-
- # change view to [b, np, sq, sk]
- attention_scores = matmul_result.view(*output_size)
-
- # ===========================
- # Attention probs and dropout
- # ===========================
-
- # attention scores and attention mask [b, np, sq, sk]
- if self.attention_softmax_in_fp32:
- attention_scores = attention_scores.float()
- if self.coeff is not None:
- attention_scores = attention_scores * self.coeff
- if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]):
- attention_mask = torch.ones(
- output_size[0],
- 1,
- output_size[2],
- output_size[3],
- device=attention_scores.device,
- dtype=torch.bool,
- )
- attention_mask.tril_()
- attention_mask = ~attention_mask
- if attention_mask is not None:
- attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
- attention_probs = F.softmax(attention_scores, dim=-1)
- attention_probs = attention_probs.type_as(value_layer)
-
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.attention_dropout(attention_probs)
- # =========================
- # Context layer. [sq, b, hp]
- # =========================
-
- # value_layer -> context layer.
- # [sk, b, np, hn] --> [b, np, sq, hn]
-
- # context layer shape: [b, np, sq, hn]
- output_size = (
- value_layer.size(1),
- value_layer.size(2),
- query_layer.size(0),
- value_layer.size(3),
- )
- # change view [sk, b * np, hn]
- value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
- # change view [b * np, sq, sk]
- attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
- # matmul: [b * np, sq, hn]
- context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
- # change view [b, np, sq, hn]
- context_layer = context_layer.view(*output_size)
- # [b, np, sq, hn] --> [sq, b, np, hn]
- context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
- # [sq, b, np, hn] --> [sq, b, hp]
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
- context_layer = context_layer.view(*new_context_layer_shape)
-
- return context_layer
-
-
-class SelfAttention(torch.nn.Module):
- """Parallel self-attention layer abstract class.
-
- Self-attention layer takes input with size [s, b, h]
- and returns output of the same size.
- """
-
- def __init__(self, config: ChatGLMConfig, layer_number, device=None):
- super(SelfAttention, self).__init__()
- self.layer_number = max(1, layer_number)
-
- self.projection_size = config.kv_channels * config.num_attention_heads
- # Per attention head and per partition values.
- self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads)
- self.num_attention_heads_per_partition = config.num_attention_heads
-
- self.multi_query_attention = config.multi_query_attention
- self.qkv_hidden_size = 3 * self.projection_size
- if self.multi_query_attention:
- self.num_multi_query_groups_per_partition = config.multi_query_group_num
- self.qkv_hidden_size = (self.projection_size +
- 2 * self.hidden_size_per_attention_head * config.multi_query_group_num)
- self.query_key_value = nn.Linear(
- config.hidden_size,
- self.qkv_hidden_size,
- bias=config.add_bias_linear or config.add_qkv_bias,
- device=device,
- **_config_to_kwargs(config),
- )
-
- self.core_attention = CoreAttention(config, self.layer_number)
-
- # Output.
- self.dense = nn.Linear(
- self.projection_size,
- config.hidden_size,
- bias=config.add_bias_linear,
- device=device,
- **_config_to_kwargs(config),
- )
-
- def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
- if self.multi_query_attention:
- num_attention_heads = self.num_multi_query_groups_per_partition
- else:
- num_attention_heads = self.num_attention_heads_per_partition
- return torch.empty(
- inference_max_sequence_len,
- batch_size,
- num_attention_heads,
- self.hidden_size_per_attention_head,
- dtype=dtype,
- device=device,
- )
-
- def forward(
- self,
- hidden_states,
- attention_mask,
- rotary_pos_emb,
- kv_cache=None,
- use_cache=True,
- ):
- # hidden_states: [sq, b, h]
-
- # =================================================
- # Pre-allocate memory for key-values for inference.
- # =================================================
- # =====================
- # Query, Key, and Value
- # =====================
-
- # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
- mixed_x_layer = self.query_key_value(hidden_states)
-
- if self.multi_query_attention:
- (query_layer, key_layer, value_layer) = mixed_x_layer.split(
- [
- self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
- self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
- self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
- ],
- dim=-1,
- )
- query_layer = query_layer.view(query_layer.size()[:-1] + (
- self.num_attention_heads_per_partition,
- self.hidden_size_per_attention_head,
- ))
- key_layer = key_layer.view(key_layer.size()[:-1] + (
- self.num_multi_query_groups_per_partition,
- self.hidden_size_per_attention_head,
- ))
- value_layer = value_layer.view(value_layer.size()[:-1] + (
- self.num_multi_query_groups_per_partition,
- self.hidden_size_per_attention_head,
- ))
- else:
- new_tensor_shape = mixed_x_layer.size()[:-1] + (
- self.num_attention_heads_per_partition,
- 3 * self.hidden_size_per_attention_head,
- )
- mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
- # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
- (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
-
- # apply relative positional encoding (rotary embedding)
- if rotary_pos_emb is not None:
- query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
- key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
-
- # adjust key and value for inference
- if kv_cache is not None:
- cache_k, cache_v = kv_cache
- key_layer = torch.cat((cache_k, key_layer), dim=0)
- value_layer = torch.cat((cache_v, value_layer), dim=0)
- if use_cache:
- kv_cache = (key_layer, value_layer)
- else:
- kv_cache = None
-
- if self.multi_query_attention:
- key_layer = key_layer.unsqueeze(-2)
- key_layer = key_layer.expand(
- -1,
- -1,
- -1,
- self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
- -1,
- )
- key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (
- self.num_attention_heads_per_partition,
- self.hidden_size_per_attention_head,
- ))
- value_layer = value_layer.unsqueeze(-2)
- value_layer = value_layer.expand(
- -1,
- -1,
- -1,
- self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
- -1,
- )
- value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (
- self.num_attention_heads_per_partition,
- self.hidden_size_per_attention_head,
- ))
-
- # ==================================
- # core attention computation
- # ==================================
-
- context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
-
- # =================
- # Output. [sq, b, h]
- # =================
-
- output = self.dense(context_layer)
-
- return output, kv_cache
-
-
-def _config_to_kwargs(args):
- common_kwargs = {
- "dtype": args.torch_dtype,
- }
- return common_kwargs
-
-
-class MLP(torch.nn.Module):
- """MLP.
-
- MLP will take the input with h hidden state, project it to 4*h
- hidden dimension, perform nonlinear transformation, and project the
- state back into h hidden dimension.
- """
-
- def __init__(self, config: ChatGLMConfig, device=None):
- super(MLP, self).__init__()
-
- self.add_bias = config.add_bias_linear
-
-        # Project to 4h. If using SwiGLU, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
- self.dense_h_to_4h = nn.Linear(
- config.hidden_size,
- config.ffn_hidden_size * 2,
- bias=self.add_bias,
- device=device,
- **_config_to_kwargs(config),
- )
-
- def swiglu(x):
- x = torch.chunk(x, 2, dim=-1)
- return F.silu(x[0]) * x[1]
-
- self.activation_func = swiglu
-
- # Project back to h.
- self.dense_4h_to_h = nn.Linear(
- config.ffn_hidden_size,
- config.hidden_size,
- bias=self.add_bias,
- device=device,
- **_config_to_kwargs(config),
- )
-
- def forward(self, hidden_states):
- # [s, b, 4hp]
- intermediate_parallel = self.dense_h_to_4h(hidden_states)
- intermediate_parallel = self.activation_func(intermediate_parallel)
- # [s, b, h]
- output = self.dense_4h_to_h(intermediate_parallel)
- return output
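
# A shape sketch of the SwiGLU MLP above, with assumed toy sizes and plain nn.Linear
# stand-ins for dense_h_to_4h / dense_4h_to_h: the first projection doubles the ffn
# width so it can be chunked into a gate branch and a value branch.
_h = torch.randn(5, 2, 32)                                # [s, b, h]
_up = nn.Linear(32, 2 * 128)(_h)                          # [s, b, 2 * ffn_hidden_size]
_gate, _value = torch.chunk(_up, 2, dim=-1)
_out = nn.Linear(128, 32)(F.silu(_gate) * _value)         # back to [s, b, h]
assert _out.shape == _h.shape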
-
-
-class GLMBlock(torch.nn.Module):
- """A single transformer layer.
-
- Transformer layer takes input with size [s, b, h] and returns an
- output of the same size.
- """
-
- def __init__(self, config: ChatGLMConfig, layer_number, device=None):
- super(GLMBlock, self).__init__()
- self.layer_number = layer_number
-
- self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm)
-
- self.fp32_residual_connection = config.fp32_residual_connection
-
- LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
- # Layernorm on the input data.
- self.input_layernorm = LayerNormFunc(
- config.hidden_size,
- eps=config.layernorm_epsilon,
- device=device,
- dtype=config.torch_dtype,
- )
-
- # Self attention.
- self.self_attention = SelfAttention(config, layer_number, device=device)
- self.hidden_dropout = config.hidden_dropout
-
- # Layernorm on the attention output
- self.post_attention_layernorm = LayerNormFunc(
- config.hidden_size,
- eps=config.layernorm_epsilon,
- device=device,
- dtype=config.torch_dtype,
- )
-
- # MLP
- self.mlp = MLP(config, device=device)
-
- def forward(
- self,
- hidden_states,
- attention_mask,
- rotary_pos_emb,
- kv_cache=None,
- use_cache=True,
- ):
- # hidden_states: [s, b, h]
-
- # Layer norm at the beginning of the transformer layer.
- layernorm_output = self.input_layernorm(hidden_states)
- # Self attention.
- attention_output, kv_cache = self.self_attention(
- layernorm_output,
- attention_mask,
- rotary_pos_emb,
- kv_cache=kv_cache,
- use_cache=use_cache,
- )
-
- # Residual connection.
- if self.apply_residual_connection_post_layernorm:
- residual = layernorm_output
- else:
- residual = hidden_states
-
- layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
- layernorm_input = residual + layernorm_input
-
- # Layer norm post the self attention.
- layernorm_output = self.post_attention_layernorm(layernorm_input)
-
- # MLP.
- mlp_output = self.mlp(layernorm_output)
-
- # Second residual connection.
- if self.apply_residual_connection_post_layernorm:
- residual = layernorm_output
- else:
- residual = layernorm_input
-
- output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
- output = residual + output
-
- return output, kv_cache
-
-
-class GLMTransformer(torch.nn.Module):
- """Transformer class."""
-
- def __init__(self, config: ChatGLMConfig, device=None):
- super(GLMTransformer, self).__init__()
-
- self.fp32_residual_connection = config.fp32_residual_connection
- self.post_layer_norm = config.post_layer_norm
-
- # Number of layers.
- self.num_layers = config.num_layers
-
- # Transformer layers.
- def build_layer(layer_number):
- return GLMBlock(config, layer_number, device=device)
-
- self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
-
- if self.post_layer_norm:
- LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
- # Final layer norm before output.
- self.final_layernorm = LayerNormFunc(
- config.hidden_size,
- eps=config.layernorm_epsilon,
- device=device,
- dtype=config.torch_dtype,
- )
-
- self.gradient_checkpointing = False
-
- def _get_layer(self, layer_number):
- return self.layers[layer_number]
-
- def forward(
- self,
- hidden_states,
- attention_mask,
- rotary_pos_emb,
- kv_caches=None,
- use_cache: Optional[bool] = True,
- output_hidden_states: Optional[bool] = False,
- ):
- if not kv_caches:
- kv_caches = [None for _ in range(self.num_layers)]
- presents = () if use_cache else None
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
- use_cache = False
-
- all_self_attentions = None
- all_hidden_states = () if output_hidden_states else None
- for index in range(self.num_layers):
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- layer = self._get_layer(index)
- if self.gradient_checkpointing and self.training:
- layer_ret = torch.utils.checkpoint.checkpoint(
- layer,
- hidden_states,
- attention_mask,
- rotary_pos_emb,
- kv_caches[index],
- use_cache,
- )
- else:
- layer_ret = layer(
- hidden_states,
- attention_mask,
- rotary_pos_emb,
- kv_cache=kv_caches[index],
- use_cache=use_cache,
- )
- hidden_states, kv_cache = layer_ret
- if use_cache:
- presents = presents + (kv_cache,)
-
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- # Final layer norm.
- if self.post_layer_norm:
- hidden_states = self.final_layernorm(hidden_states)
-
- return hidden_states, presents, all_hidden_states, all_self_attentions
-
-
-class ChatGLMPreTrainedModel(PreTrainedModel):
- """
- An abstract class to handle weights initialization and
- a simple interface for downloading and loading pretrained models.
- """
-
- is_parallelizable = False
- supports_gradient_checkpointing = True
- config_class = ChatGLMConfig
- base_model_prefix = "transformer"
- _no_split_modules = ["GLMBlock"]
-
- def _init_weights(self, module: nn.Module):
- """Initialize the weights."""
- return
-
- def get_masks(self, input_ids, past_key_values, padding_mask=None):
- batch_size, seq_length = input_ids.shape
- full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
- full_attention_mask.tril_()
- past_length = 0
- if past_key_values:
- past_length = past_key_values[0][0].shape[0]
- if past_length:
- full_attention_mask = torch.cat(
- (
- torch.ones(batch_size, seq_length, past_length, device=input_ids.device),
- full_attention_mask,
- ),
- dim=-1,
- )
- if padding_mask is not None:
- full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
- if not past_length and padding_mask is not None:
- full_attention_mask -= padding_mask.unsqueeze(-1) - 1
- full_attention_mask = (full_attention_mask < 0.5).bool()
- full_attention_mask.unsqueeze_(1)
- return full_attention_mask
-
- def get_position_ids(self, input_ids, device):
- batch_size, seq_length = input_ids.shape
- position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1))
- return position_ids
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, GLMTransformer):
- module.gradient_checkpointing = value
-
-
-class Embedding(torch.nn.Module):
- """Language model embeddings."""
-
- def __init__(self, config: ChatGLMConfig, device=None):
- super(Embedding, self).__init__()
-
- self.hidden_size = config.hidden_size
- # Word embeddings (parallel).
- self.word_embeddings = nn.Embedding(
- config.padded_vocab_size,
- self.hidden_size,
- dtype=config.torch_dtype,
- device=device,
- )
- self.fp32_residual_connection = config.fp32_residual_connection
-
- def forward(self, input_ids):
- # Embeddings.
- words_embeddings = self.word_embeddings(input_ids)
- embeddings = words_embeddings
-        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
- embeddings = embeddings.transpose(0, 1).contiguous()
-        # If the input flag for fp32 residual connection is set, convert to float.
- if self.fp32_residual_connection:
- embeddings = embeddings.float()
- return embeddings
-
-
-class ChatGLMModel(ChatGLMPreTrainedModel):
-
- def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
- super().__init__(config)
- if empty_init:
- init_method = skip_init
- else:
- init_method = default_init
- init_kwargs = {}
- if device is not None:
- init_kwargs["device"] = device
- self.embedding = init_method(Embedding, config, **init_kwargs)
- self.num_layers = config.num_layers
- self.multi_query_group_num = config.multi_query_group_num
- self.kv_channels = config.kv_channels
-
- # Rotary positional embeddings
- self.seq_length = config.seq_length
- rotary_dim = (config.hidden_size //
- config.num_attention_heads if config.kv_channels is None else config.kv_channels)
-
- self.rotary_pos_emb = RotaryEmbedding(
- rotary_dim // 2,
- original_impl=config.original_rope,
- device=device,
- dtype=config.torch_dtype,
- )
- self.encoder = init_method(GLMTransformer, config, **init_kwargs)
- self.output_layer = init_method(
- nn.Linear,
- config.hidden_size,
- config.padded_vocab_size,
- bias=False,
- dtype=config.torch_dtype,
- **init_kwargs,
- )
- self.pre_seq_len = config.pre_seq_len
- self.prefix_projection = config.prefix_projection
- if self.pre_seq_len is not None:
- for param in self.parameters():
- param.requires_grad = False
- self.prefix_tokens = torch.arange(self.pre_seq_len).long()
- self.prefix_encoder = PrefixEncoder(config)
- self.dropout = torch.nn.Dropout(0.1)
-
- def get_input_embeddings(self):
- return self.embedding.word_embeddings
-
- def get_prompt(self, batch_size, device, dtype=torch.half):
- prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device))
- past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
- past_key_values = past_key_values.view(
- batch_size,
- self.pre_seq_len,
- self.num_layers * 2,
- self.multi_query_group_num,
- self.kv_channels,
- )
- # seq_len, b, nh, hidden_size
- past_key_values = self.dropout(past_key_values)
- past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
- return past_key_values
-
- def forward(
- self,
- input_ids,
- position_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.BoolTensor] = None,
- full_attention_mask: Optional[torch.BoolTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ):
- output_hidden_states = (output_hidden_states
- if output_hidden_states is not None else self.config.output_hidden_states)
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
-
- batch_size, seq_length = input_ids.shape
-
- if inputs_embeds is None:
- inputs_embeds = self.embedding(input_ids)
-
- if self.pre_seq_len is not None:
- if past_key_values is None:
- past_key_values = self.get_prompt(
- batch_size=batch_size,
- device=input_ids.device,
- dtype=inputs_embeds.dtype,
- )
- if attention_mask is not None:
- attention_mask = torch.cat(
- [
- attention_mask.new_ones((batch_size, self.pre_seq_len)),
- attention_mask,
- ],
- dim=-1,
- )
-
- if full_attention_mask is None:
- if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
- full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
-
- # Rotary positional embeddings
- rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
- if position_ids is not None:
- rotary_pos_emb = rotary_pos_emb[position_ids]
- else:
- rotary_pos_emb = rotary_pos_emb[None, :seq_length]
- rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
-
- # Run encoder.
- hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
- inputs_embeds,
- full_attention_mask,
- rotary_pos_emb=rotary_pos_emb,
- kv_caches=past_key_values,
- use_cache=use_cache,
- output_hidden_states=output_hidden_states,
- )
-
- if not return_dict:
- return tuple(v for v in [
- hidden_states,
- presents,
- all_hidden_states,
- all_self_attentions,
- ] if v is not None)
-
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=presents,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- )
-
- def quantize(self, weight_bit_width: int):
- from .quantization import quantize
-
- quantize(self.encoder, weight_bit_width)
- return self
-
-
-class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
-
- def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
- super().__init__(config)
-
- self.max_sequence_length = config.max_length
- self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
- self.config = config
- self.quantized = False
-
- if self.config.quantization_bit:
- self.quantize(self.config.quantization_bit, empty_init=True)
-
- def _update_model_kwargs_for_generation(
- self,
- outputs: ModelOutput,
- model_kwargs: Dict[str, Any],
- is_encoder_decoder: bool = False,
- standardize_cache_format: bool = False,
- ) -> Dict[str, Any]:
- # update past_key_values
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
- outputs, standardize_cache_format=standardize_cache_format)
-
- # update attention mask
- if "attention_mask" in model_kwargs:
- attention_mask = model_kwargs["attention_mask"]
- model_kwargs["attention_mask"] = torch.cat(
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
- dim=-1,
- )
-
- # update position ids
- if "position_ids" in model_kwargs:
- position_ids = model_kwargs["position_ids"]
- new_position_id = position_ids[..., -1:].clone()
- new_position_id += 1
- model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
-
- model_kwargs["is_first_forward"] = False
- return model_kwargs
-
- def prepare_inputs_for_generation(
- self,
- input_ids: torch.LongTensor,
- past_key_values: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- is_first_forward: bool = True,
- **kwargs,
- ) -> dict:
- # only last token for input_ids if past is not None
- if position_ids is None:
- position_ids = self.get_position_ids(input_ids, device=input_ids.device)
- if not is_first_forward:
- position_ids = position_ids[..., -1:]
- input_ids = input_ids[:, -1:]
- return {
- "input_ids": input_ids,
- "past_key_values": past_key_values,
- "position_ids": position_ids,
- "attention_mask": attention_mask,
- "return_last_logit": True,
- }
-
- def forward(
- self,
- input_ids: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- return_last_logit: Optional[bool] = False,
- ):
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
-
- transformer_outputs = self.transformer(
- input_ids=input_ids,
- position_ids=position_ids,
- attention_mask=attention_mask,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- hidden_states = transformer_outputs[0]
- if return_last_logit:
- hidden_states = hidden_states[-1:]
- lm_logits = self.transformer.output_layer(hidden_states)
- lm_logits = lm_logits.transpose(0, 1).contiguous()
-
- loss = None
- if labels is not None:
- lm_logits = lm_logits.to(torch.float32)
-
- # Shift so that tokens < n predict n
- shift_logits = lm_logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss(ignore_index=-100)
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
- lm_logits = lm_logits.to(hidden_states.dtype)
- loss = loss.to(hidden_states.dtype)
-
- if not return_dict:
- output = (lm_logits,) + transformer_outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=lm_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
-
- @staticmethod
- def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...],
- beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
- """
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
- beam_idx at every generation step.
-
- Output shares the same memory storage as `past`.
- """
- return tuple((
- layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
- layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
- ) for layer_past in past)
-
- def process_response(self, response):
- response = response.strip()
- response = response.replace("[[训练时间]]", "2023年")
- return response
-
- def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
- prompt = tokenizer.build_prompt(query, history=history)
- inputs = tokenizer([prompt], return_tensors="pt")
- inputs = inputs.to(self.device)
- return inputs
-
- def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
- if history:
- prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
- input_ids = tokenizer.encode(prompt, add_special_tokens=False)
- input_ids = input_ids[1:]
- inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
- else:
- prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
- inputs = tokenizer([prompt], return_tensors="pt")
- inputs = inputs.to(self.device)
- return inputs
-
- @torch.no_grad()
- def chat(
- self,
- tokenizer,
- query: str,
- history: List[Tuple[str, str]] = None,
- max_length: int = 8192,
- num_beams=1,
- do_sample=True,
- top_p=0.8,
- temperature=0.8,
- logits_processor=None,
- **kwargs,
- ):
- if history is None:
- history = []
- if logits_processor is None:
- logits_processor = LogitsProcessorList()
- logits_processor.append(InvalidScoreLogitsProcessor())
- gen_kwargs = {
- "max_length": max_length,
- "num_beams": num_beams,
- "do_sample": do_sample,
- "top_p": top_p,
- "temperature": temperature,
- "logits_processor": logits_processor,
- **kwargs,
- }
- inputs = self.build_inputs(tokenizer, query, history=history)
- outputs = self.generate(**inputs, **gen_kwargs)
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
- response = tokenizer.decode(outputs)
- response = self.process_response(response)
- history = history + [(query, response)]
- return response, history
-
- @torch.no_grad()
- def stream_chat(
- self,
- tokenizer,
- query: str,
- history: List[Tuple[str, str]] = None,
- past_key_values=None,
- max_length: int = 8192,
- do_sample=True,
- top_p=0.8,
- temperature=0.8,
- logits_processor=None,
- return_past_key_values=False,
- **kwargs,
- ):
- if history is None:
- history = []
- if logits_processor is None:
- logits_processor = LogitsProcessorList()
- logits_processor.append(InvalidScoreLogitsProcessor())
- gen_kwargs = {
- "max_length": max_length,
- "do_sample": do_sample,
- "top_p": top_p,
- "temperature": temperature,
- "logits_processor": logits_processor,
- **kwargs,
- }
- if past_key_values is None and not return_past_key_values:
- inputs = self.build_inputs(tokenizer, query, history=history)
- else:
- inputs = self.build_stream_inputs(tokenizer, query, history=history)
- if past_key_values is not None:
- past_length = past_key_values[0][0].shape[0]
- if self.transformer.pre_seq_len is not None:
- past_length -= self.transformer.pre_seq_len
- inputs.position_ids += past_length
- attention_mask = inputs.attention_mask
- attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
- inputs["attention_mask"] = attention_mask
- for outputs in self.stream_generate(
- **inputs,
- past_key_values=past_key_values,
- return_past_key_values=return_past_key_values,
- **gen_kwargs,
- ):
- if return_past_key_values:
- outputs, past_key_values = outputs
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
- response = tokenizer.decode(outputs)
- if response and response[-1] != "�":
- response = self.process_response(response)
- new_history = history + [(query, response)]
- if return_past_key_values:
- yield response, new_history, past_key_values
- else:
- yield response, new_history
-
- @torch.no_grad()
- def stream_generate(
- self,
- input_ids,
- generation_config: Optional[GenerationConfig] = None,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
- return_past_key_values=False,
- **kwargs,
- ):
- batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
-
- if generation_config is None:
- generation_config = self.generation_config
- generation_config = copy.deepcopy(generation_config)
- model_kwargs = generation_config.update(**kwargs)
- bos_token_id, eos_token_id = (
- generation_config.bos_token_id,
- generation_config.eos_token_id,
- )
-
- if isinstance(eos_token_id, int):
- eos_token_id = [eos_token_id]
-
- has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None)
- if has_default_max_length and generation_config.max_new_tokens is None:
- warnings.warn(
- f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
- "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
- " recommend using `max_new_tokens` to control the maximum length of the generation.",
- UserWarning,
- )
- elif generation_config.max_new_tokens is not None:
- generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length)
- if not has_default_max_length:
- logger.warn(
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
- "Please refer to the documentation for more information. "
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
- UserWarning,
- )
-
- if input_ids_seq_length >= generation_config.max_length:
- input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids")
- logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
- " increasing `max_new_tokens`.")
-
- # 2. Set generation parameters if not already defined
- logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList())
- stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList())
-
- logits_processor = self._get_logits_processor(
- generation_config=generation_config,
- input_ids_seq_length=input_ids_seq_length,
- encoder_input_ids=input_ids,
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
- logits_processor=logits_processor,
- )
-
- stopping_criteria = self._get_stopping_criteria(generation_config=generation_config,
- stopping_criteria=stopping_criteria)
- logits_warper = self._get_logits_warper(generation_config)
-
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
- scores = None
- while True:
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
- # forward pass to get next token
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=False,
- output_hidden_states=False,
- )
-
- next_token_logits = outputs.logits[:, -1, :]
-
- # pre-process distribution
- next_token_scores = logits_processor(input_ids, next_token_logits)
- next_token_scores = logits_warper(input_ids, next_token_scores)
-
- # sample
- probs = nn.functional.softmax(next_token_scores, dim=-1)
- if generation_config.do_sample:
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
- else:
- next_tokens = torch.argmax(probs, dim=-1)
-
- # update generated ids, model inputs, and length for next step
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
- model_kwargs = self._update_model_kwargs_for_generation(outputs,
- model_kwargs,
- is_encoder_decoder=self.config.is_encoder_decoder)
- unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
- if return_past_key_values:
- yield input_ids, outputs.past_key_values
- else:
- yield input_ids
- # stop when each sentence is finished, or if we exceed the maximum length
- if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
- break
-
- def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
- if bits == 0:
- return
-
- from .quantization import quantize
-
- if self.quantized:
- logger.info("Already quantized.")
- return self
-
- self.quantized = True
-
- self.config.quantization_bit = bits
-
- self.transformer.encoder = quantize(
- self.transformer.encoder,
- bits,
- empty_init=empty_init,
- device=device,
- **kwargs,
- )
- return self
diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py
index 73c210221e61..5c3eb4438bc8 100644
--- a/tests/kit/model_zoo/transformers/gpt.py
+++ b/tests/kit/model_zoo/transformers/gpt.py
@@ -18,8 +18,8 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
- input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
- attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -46,7 +46,7 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
- data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.int64)
+ data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
return data
diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py
index 689db2c40abb..435cb6f46937 100644
--- a/tests/kit/model_zoo/transformers/t5.py
+++ b/tests/kit/model_zoo/transformers/t5.py
@@ -16,8 +16,9 @@ def data_gen_for_encoder_only():
# config = T5Config(decoder_start_token_id=0)
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
# input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
- input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
- return dict(input_ids=input_ids)
+ input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
+ attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_conditional_generation():
@@ -25,17 +26,16 @@ def data_gen_for_conditional_generation():
#
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
data = data_gen_for_encoder_only()
- labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
+ labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
data['labels'] = labels
return data
def data_gen_for_t5_model():
# decoder_inputs_ids is obtained with the following code
- #
# decoder_input_ids = model._shift_right(input_ids)
data = data_gen_for_encoder_only()
- decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
+ decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
data['decoder_input_ids'] = decoder_input_ids
return data
diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py
index 40c96a5777ab..f7cdc052aaf0 100644
--- a/tests/kit/model_zoo/transformers/whisper.py
+++ b/tests/kit/model_zoo/transformers/whisper.py
@@ -76,14 +76,14 @@ def data_gen_for_audio_classification():
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='transformers_whisperForConditionalGeneration',
+model_zoo.register(name='transformers_whisper_for_conditional_generation',
model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_attr,
model_attribute=ModelAttribute(has_control_flow=True))
-model_zoo.register(name='transformers_whisperWhisperForAudioClassification',
+model_zoo.register(name='transformers_whisper_for_audio_classification',
model_fn=lambda: transformers.WhisperForAudioClassification(config),
data_gen_fn=data_gen_for_audio_classification,
output_transform_fn=output_transform_fn,
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index a06b2c963bfe..fee153baf1ac 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -93,7 +93,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
'transformers_vit_for_image_classification', 'transformers_chatglm',
'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
- 'transformers_whisperForConditionalGeneration', 'transformers_whisperWhisperForAudioClassification'
+ 'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
]:
continue
diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
index 7181e6c2b31b..97ee22730ea8 100644
--- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
+++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -21,6 +21,7 @@
_STUCK_MODELS = [
'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
'transformers_bert_for_pretraining', 'transformers_gpt_double_heads', 'transformers_vit',
+ 'transformers_bert_lm_head_model', 'transformers_bert_for_masked_lm',
'transformers_vit_for_masked_image_modeling', 'transformers_vit_for_image_classification', 'transformers_sam',
'transformers_chatglm', 'transformers_chatglm_for_conditional_generation'
]
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 0e5cb8144ef3..98cdc5a4b95b 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -21,7 +21,13 @@
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
+def build_model(model_fn,
+ enable_fused_normalization=True,
+ enable_tensor_parallelism=True,
+ enable_flash_attention=False,
+ enable_jit_fused=False,
+ use_lazy_init: bool = False):
+ # create new model
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
# create new model
@@ -31,7 +37,10 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
ctx.materialize(org_model)
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
- enable_tensor_parallelism=enable_tensor_parallelism)
+ enable_tensor_parallelism=enable_tensor_parallelism,
+ enable_flash_attention=enable_flash_attention,
+ enable_jit_fused=enable_jit_fused)
+ model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 1d42f1c4703e..afc1507e8b24 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -46,14 +46,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
-@parameterize('enable_fused_normalization', [False, True])
-@parameterize('enable_tensor_parallelism', [False, True])
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
+ use_lazy_init):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- use_lazy_init)
+ enable_flash_attention, enable_jit_fused, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py
index cb9725f4de7f..cd034d0c139a 100644
--- a/tests/test_shardformer/test_model/test_shard_blip2.py
+++ b/tests/test_shardformer/test_model/test_shard_blip2.py
@@ -47,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
-def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+ org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+ enable_flash_attention, enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index c13596fe8db3..e11bcf92ea3c 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -44,13 +44,15 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
+ use_lazy_init):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- use_lazy_init)
- check_state_dict(org_model, sharded_model, name=name)
+ enable_flash_attention, enable_jit_fused, use_lazy_init)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index 005223fb8ae4..c455a99d26ce 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -72,7 +72,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
-def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create new model
@@ -80,7 +82,9 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
- enable_tensor_parallelism=enable_tensor_parallelism)
+ enable_tensor_parallelism=enable_tensor_parallelism,
+ enable_flash_attention=enable_flash_attention,
+ enable_jit_fused=enable_jit_fused)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
if name == "transformers_chatglm":
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index cebb40bd16fe..f7213d8c50b4 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -68,7 +68,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
torch.cuda.empty_cache()
-
@parameterize('test_config', [{
'tp_size': 1,
'pp_size': 2,
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 2cfc172c8df6..ead14ab111e6 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -49,12 +49,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
+@parameterize('enable_flash_attention', [True, False])
@parameterize('use_lazy_init', [False, True])
-def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- use_lazy_init)
+ enable_flash_attention, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 4684bacb4788..99a278d4303a 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -42,18 +42,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
# check grad
col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
row_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
- check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
- check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
+ check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
+ check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
+@parameterize('use_lazy_init', [False, True])
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention,
+ enable_jit_fused):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- use_lazy_init)
+ enable_flash_attention, enable_jit_fused, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
@@ -62,7 +65,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
def check_OPTModel(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_t5_test()
+ run_opt_test()
@pytest.mark.dist
diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py
index e7748cfd189d..616104cd7828 100644
--- a/tests/test_shardformer/test_model/test_shard_sam.py
+++ b/tests/test_shardformer/test_model/test_shard_sam.py
@@ -41,10 +41,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
-def run_sam_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
sub_model_zoo = model_zoo.get_sub_registry('transformers_sam')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+ org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+ enable_flash_attention)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 024c5016b0c1..22f04c879879 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -33,8 +33,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
# check grad
col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared']
row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias']
- check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-7, rtol=1e-5, dim=0, verbose=False)
- check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-7, rtol=1e-5, dim=1, verbose=False)
+ check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
+ check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
# check weights are tied
if hasattr(org_model, 'lm_head'):
@@ -45,11 +45,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('use_lazy_init', [False, True])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention,
+ enable_jit_fused):
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- use_lazy_init)
+ enable_flash_attention, enable_jit_fused, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index 7833ab70275d..d179c8a8ee32 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -20,7 +20,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
+
assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
+
# do backward
org_loss.backward()
shard_loss.backward()
@@ -45,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
-def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+ org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+ enable_flash_attention, enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index a271bbdf1223..9b38ae07b1d6 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -48,12 +48,16 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
-def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn,
enable_fused_normalization=enable_fused_normalization,
- enable_tensor_parallelism=enable_tensor_parallelism)
+ enable_tensor_parallelism=enable_tensor_parallelism,
+ enable_flash_attention=enable_flash_attention,
+ enable_jit_fused=enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 7a28b0157384..938f85b410e1 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -24,8 +24,9 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(1, 128, 4, 16)])
+def test_attention_gpt(proj_shape, dtype=torch.float16):
+ (B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -35,7 +36,11 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
- y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
+
+ mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
+ mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
+
+ y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
assert list(y.shape) == [B, S, D]
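The padded-causal mask above is built per sample and padded to the full sequence length. A small, self-contained sketch of the same construction (boolean dtype used here for readability):

    import torch

    B, S = 4, 6
    # sample i attends to its first S - i positions; pad_sequence fills the rest with 0
    mask = [torch.ones(S - i, dtype=torch.bool) for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
    print(mask.int())
    # tensor([[1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 0],
    #         [1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 0, 0, 0]])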
@@ -45,8 +50,9 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(1, 128, 4, 16)])
+def test_attention_bert(proj_shape, dtype=torch.float16):
+ (B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -69,8 +75,9 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 128, 4, 16)])
+def test_attention_no_mask(proj_shape, dtype=torch.float16):
+ (B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -89,8 +96,9 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@clear_cache_before_run()
-@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
-def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 128, 256, 4, 16)])
+def test_cross_attention(proj_shape, dtype=torch.float16):
+ (B, S, T, H, D_HEAD) = proj_shape
D = H * D_HEAD
q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
From 2e77e57e408a1d387b5eab7264912c1bf394895d Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Tue, 8 Aug 2023 17:46:44 +0800
Subject: [PATCH 52/64] [pipeline] rewrite t5 tests & support multi-tensor
transmitting in pipeline (#4388)
* fix remaining t5 bugs/rewrite t5 tests
* fix multi-tensor communication in pipeline
* rearrange test_config
* fix KeyError in sync_shared_params
* fix get_held_layers & Randomizer, complete t5 tests
* erase printing
* fix get_held_layers through modifying _release_unheld_layers
* fix _get_recursive_held_layers bug
---
.../booster/plugin/hybrid_parallel_plugin.py | 6 +-
colossalai/pipeline/p2p.py | 6 +-
colossalai/pipeline/schedule/_utils.py | 2 +-
colossalai/pipeline/schedule/one_f_one_b.py | 11 +-
colossalai/shardformer/layer/utils.py | 7 +
colossalai/shardformer/modeling/t5.py | 95 +++++------
colossalai/shardformer/policies/t5.py | 51 ++----
colossalai/shardformer/shard/sharder.py | 16 +-
.../test_model/test_shard_gpt2.py | 7 +-
.../test_model/test_shard_t5.py | 150 ++++++++++++------
.../test_model/test_shard_t5_pipeline.py | 101 ------------
11 files changed, 201 insertions(+), 251 deletions(-)
delete mode 100644 tests/test_shardformer/test_model/test_shard_t5_pipeline.py
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index a22bdb7199bb..42942aaeb89d 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -50,8 +50,10 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp
def sync_shared_params(self):
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
- param = shared_param[self.stage_manager.stage]
- dist.all_reduce(param.grad, group=group)
+ if self.stage_manager.stage in shared_param:
+ param = shared_param[self.stage_manager.stage]
+ dist.all_reduce(param.grad, group=group)
+ dist.barrier()
def no_sync(self) -> Iterator[None]:
# no sync grads across data parallel
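The guarded all-reduce keeps ranks whose pipeline stage does not hold a copy of the tied weight from indexing into the mapping (the KeyError mentioned in the commit message), while the barrier keeps all ranks in step. A condensed sketch of the same logic, with names taken from the patch:

    import torch.distributed as dist

    def sync_shared_params(shared_params, shared_param_process_groups, stage):
        # each entry maps pipeline stage -> locally held copy of a tied weight,
        # e.g. {0: shared.weight, 3: lm_head.weight}
        for shared_param, group in zip(shared_params, shared_param_process_groups):
            if stage in shared_param:
                param = shared_param[stage]
                dist.all_reduce(param.grad, group=group)
            dist.barrier()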
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index f741b8363f13..af7a00b5c720 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -3,6 +3,7 @@
import io
import pickle
+import re
from typing import Any, List, Optional, Union
import torch
@@ -31,7 +32,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
if b'cuda' in buf:
buf_array = bytearray(buf)
device_index = torch.cuda.current_device()
- buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
+ # There might be more than one output tensors during forward
+ for cuda_str in re.finditer(b'cuda', buf_array):
+ pos = cuda_str.start()
+ buf_array[pos + 5] = 48 + device_index
buf = bytes(buf_array)
io_bytes = io.BytesIO(buf)
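With multi-tensor transmission, a pickled payload can mention cuda:<idx> more than once, so every occurrence has to be remapped to the receiving device rather than only the first. A sketch of that remapping as a hypothetical helper (single-digit device indices assumed, as in the patch; the literal bytes b'cuda:0' appearing in the pickled buffer is likewise assumed, matching the original find-based code):

    import pickle
    import re

    import torch

    def remap_cuda_device(buf: bytes, target_index: int) -> bytes:
        buf_array = bytearray(buf)
        for match in re.finditer(b'cuda', buf_array):
            # the byte right after 'cuda:' is the ASCII digit of the device index
            buf_array[match.start() + 5] = 48 + target_index
        return bytes(buf_array)

    if torch.cuda.is_available():
        payload = pickle.dumps((torch.zeros(2, device='cuda:0'), torch.ones(2, device='cuda:0')))
        obj = pickle.loads(remap_cuda_device(payload, torch.cuda.current_device()))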
diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index 045c86e40e63..3ed9239272f1 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -86,7 +86,7 @@ def retain_grad(x: Any) -> None:
Args:
x (Any): Object to be called.
"""
- if isinstance(x, torch.Tensor):
+ if isinstance(x, torch.Tensor) and x.requires_grad:
x.retain_grad()
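The extra requires_grad check matters because retain_grad() raises a RuntimeError on tensors that do not require grad, and with multi-tensor transmission some forwarded objects (masks, biases detached from the graph) may not. For example:

    import torch

    x = torch.zeros(3)                    # requires_grad=False
    if isinstance(x, torch.Tensor) and x.requires_grad:
        x.retain_grad()                   # skipped, so no RuntimeError is raised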
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index d907d53edcde..ade3cf456fe3 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -107,8 +107,15 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
- for k, grad in output_obj_grad.items():
- optimizer.backward_by_grad(output_obj[k], grad)
+ if "backward_tensor_keys" not in output_obj:
+ for k, grad in output_obj_grad.items():
+ optimizer.backward_by_grad(output_obj[k], grad)
+ else:
+ for k, grad in output_obj_grad.items():
+ output_obj[k].grad = grad
+ for k in output_obj["backward_tensor_keys"]:
+ tensor_to_backward = output_obj[k]
+ optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
# Collect the grad of the input_obj.
input_obj_grad = None
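With several tensors flowing between stages, only some of them should drive autograd; the rest (position biases, carried hidden states) just need their gradients attached. A minimal sketch of the branch added above, using plain backward() in place of optimizer.backward_by_grad:

    import torch

    def backward_step(output_obj: dict, output_obj_grad: dict):
        if "backward_tensor_keys" not in output_obj:
            for k, grad in output_obj_grad.items():
                output_obj[k].backward(grad)
        else:
            # attach every received grad, but only backpropagate the listed tensors
            for k, grad in output_obj_grad.items():
                output_obj[k].grad = grad
            for k in output_obj["backward_tensor_keys"]:
                tensor_to_backward = output_obj[k]
                tensor_to_backward.backward(tensor_to_backward.grad)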
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index f2ac6563c46f..09cb7bfe1407 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -122,6 +122,13 @@ def increment_index():
"""
Randomizer._INDEX += 1
+ @staticmethod
+ def reset_index():
+ """
+ Reset the index to zero.
+ """
+ Randomizer._INDEX = 0
+
@staticmethod
def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
"""
diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index 0b3486e87c7e..d622da452366 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -238,7 +238,8 @@ def custom_forward(*inputs):
return {
'hidden_states': hidden_states,
'position_bias': position_bias,
- 'encoder_decoder_position_bias': encoder_decoder_position_bias
+ 'encoder_decoder_position_bias': encoder_decoder_position_bias,
+ 'backward_tensor_keys': ['hidden_states']
}
@staticmethod
@@ -261,8 +262,10 @@ def t5_model_forward(
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+ backward_tensor_keys: Optional[List[str]] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
@@ -303,7 +306,6 @@ def t5_model_forward(
decoder_head_mask = head_mask
in_decoder = stage_manager.stage >= decoder_starting_stage
-
# Stage is in encoder, directly return the output of t5_stack_forward
if not in_decoder:
encoder_outputs = T5PipelineForwards.t5_stack_forward(
@@ -323,25 +325,18 @@ def t5_model_forward(
decoder_starting_stage=decoder_starting_stage)
if stage_manager.stage == decoder_starting_stage - 1:
# last stage of encoder
- return {'encoder_outputs': encoder_outputs}
+ return {'encoder_hidden_states': encoder_outputs[0]}
else:
return encoder_outputs
at_last_decoder_stage = stage_manager.is_last_stage()
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
- if encoder_outputs is None:
- raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
-
- encoder_hidden_states = encoder_outputs[0]
- if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
- encoder_outputs = BaseModelOutput(
- last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
- )
+ if encoder_outputs is not None:
+ encoder_hidden_states = encoder_outputs[0]
+ elif encoder_hidden_states is None:
+ raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
- # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
if not at_first_decoder_stage and hidden_states is None:
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
@@ -360,6 +355,7 @@ def t5_model_forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ stage_manager=stage_manager,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -368,22 +364,19 @@ def t5_model_forward(
# Directly return outputs of overloaded T5Stack forward if not at last stage.
if not at_last_decoder_stage:
- decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage
+ # encoder_hidden_states should be passed to the next stage
+ decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
return decoder_outputs
if not return_dict:
- return decoder_outputs + encoder_outputs
-
- return Seq2SeqModelOutput(
- last_hidden_state=decoder_outputs.last_hidden_state,
- past_key_values=decoder_outputs.past_key_values,
- decoder_hidden_states=decoder_outputs.hidden_states,
- decoder_attentions=decoder_outputs.attentions,
- cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
- encoder_hidden_states=encoder_outputs.hidden_states,
- encoder_attentions=encoder_outputs.attentions,
- )
+ return decoder_outputs + encoder_hidden_states
+ else:
+ return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_hidden_states)
@staticmethod
def t5_for_conditional_generation_forward(
@@ -406,8 +399,10 @@ def t5_for_conditional_generation_forward(
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+ backward_tensor_keys: Optional[List[str]] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
@@ -468,28 +463,25 @@ def t5_for_conditional_generation_forward(
decoder_starting_stage=decoder_starting_stage)
if stage_manager.stage == decoder_starting_stage - 1:
# last stage of encoder
- return {'encoder_outputs': encoder_outputs}
+ return {'encoder_hidden_states': encoder_outputs[0]}
else:
return encoder_outputs
at_last_decoder_stage = stage_manager.is_last_stage()
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
- if encoder_outputs is None:
- raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
+ if encoder_outputs is not None:
+ encoder_hidden_states = encoder_outputs[0]
+ elif encoder_hidden_states is None:
+ raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
- encoder_hidden_states = encoder_outputs[0]
- if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
- encoder_outputs = BaseModelOutput(
- last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
- )
-
- # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
if not at_first_decoder_stage and hidden_states is None:
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+ # get decoder inputs from shifting lm labels to the right
+ decoder_input_ids = self._shift_right(labels)
+
# Decode
decoder_outputs = T5PipelineForwards.t5_stack_forward(
self.decoder,
@@ -505,6 +497,7 @@ def t5_for_conditional_generation_forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ stage_manager=stage_manager,
hidden_states=hidden_states,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -513,7 +506,8 @@ def t5_for_conditional_generation_forward(
# Directly return outputs of overloaded T5Stack forward if not at last stage.
if not at_last_decoder_stage:
- decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage
+ # encoder_hidden_states should be passed to the next stage
+ decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
return decoder_outputs
sequence_output = decoder_outputs[0]
@@ -533,20 +527,16 @@ def t5_for_conditional_generation_forward(
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
if not return_dict:
- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states
return ((loss,) + output) if loss is not None else output
- return Seq2SeqLMOutput(
- loss=loss,
- logits=lm_logits,
- past_key_values=decoder_outputs.past_key_values,
- decoder_hidden_states=decoder_outputs.hidden_states,
- decoder_attentions=decoder_outputs.attentions,
- cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
- encoder_hidden_states=encoder_outputs.hidden_states,
- encoder_attentions=encoder_outputs.attentions,
- )
+ return Seq2SeqLMOutput(loss=loss,
+ logits=lm_logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_hidden_states)
@staticmethod
def t5_encoder_model_forward(
@@ -562,6 +552,7 @@ def t5_encoder_model_forward(
hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None,
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+ backward_tensor_keys: Optional[List[str]] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
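Taken together, the rewritten forwards give each pipeline stage a uniform contract for what it hands to the next stage. A condensed, illustrative sketch of that contract (stage_output is not a function in the patch; the real forwards also thread attention masks, head masks and caches):

    def stage_output(stage_manager, decoder_starting_stage, outputs):
        """outputs: whatever the current stage produced, keyed by name."""
        in_decoder = stage_manager.stage >= decoder_starting_stage
        if not in_decoder:
            if stage_manager.stage == decoder_starting_stage - 1:
                # last encoder stage: only the final hidden state crosses to the decoder
                return {'encoder_hidden_states': outputs['hidden_states']}
            return {'hidden_states': outputs['hidden_states'],
                    'position_bias': outputs['position_bias'],
                    'backward_tensor_keys': ['hidden_states']}
        if not stage_manager.is_last_stage():
            # intermediate decoder stage: keep forwarding encoder_hidden_states
            return {'hidden_states': outputs['hidden_states'],
                    'position_bias': outputs['position_bias'],
                    'encoder_decoder_position_bias': outputs['encoder_decoder_position_bias'],
                    'encoder_hidden_states': outputs['encoder_hidden_states'],
                    'backward_tensor_keys': ['hidden_states']}
        return outputs['final_output']    # last decoder stage: Seq2SeqModelOutput / tuple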
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 5e78ae9093fa..2ef52c214c6b 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -260,7 +260,7 @@ def get_held_layers(self) -> List[nn.Module]:
model = self.model
encoder = self.model.encoder
- decoder = self.model.__dict__.get('decoder', None)
+ decoder = getattr(self.model, 'decoder', None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
@@ -300,7 +300,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
stage_manager = self.pipeline_stage_manager
encoder = self.model.encoder
- decoder = self.model.__dict__.get('decoder', None)
+ decoder = getattr(self.model, 'decoder', None)
num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0
@@ -355,15 +355,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
return []
- def postprocess(self):
- if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
- binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
- for k, v in binding_map.items():
- src = getattr_(self.model, k)
- for dst in v:
- setattr_(self.model, dst, src)
- return self.model
-
class T5ForConditionalGenerationPolicy(T5BasePolicy):
@@ -409,28 +400,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
stage_manager.num_stages)
shared_params = []
+ shared_embedding = {}
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
- shared_params.append({
- 0: module.shared.weight,
- decoder_starting_stage: module.decoder.embed_tokens.weight
- })
+ shared_embedding[0] = module.shared.weight
+ shared_embedding[decoder_starting_stage] = module.decoder.embed_tokens.weight
+
if id(module.lm_head.weight) == id(module.shared.weight):
- shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight})
- return shared_params
- return []
+ shared_embedding[0] = module.shared.weight
+ shared_embedding[stage_manager.num_stages - 1] = module.lm_head.weight
- def postprocess(self):
- super().postprocess()
- if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
- binding_map = {
- "shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
- }
- for k, v in binding_map.items():
- src = getattr_(self.model, k)
- for dst in v:
- setattr_(self.model, dst, src)
+ if len(shared_embedding) > 0:
+ shared_params.append(shared_embedding)
- return self.model
+ return shared_params
+
+ return []
class T5EncoderPolicy(T5BasePolicy):
@@ -462,12 +446,3 @@ def get_held_layers(self) -> List[nn.Module]:
def get_shared_params(self) -> List[Dict[int, Tensor]]:
return []
-
- def postprocess(self):
- if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
- binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
- for k, v in binding_map.items():
- src = getattr_(self.model, k)
- for dst in v:
- setattr_(self.model, dst, src)
- return self.model
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index ae8cd8c6e553..0ed745a1fc4a 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -198,6 +198,20 @@ def _replace_sub_module(
setattr_(org_layer, suffix, replace_layer)
+ def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]:
+
+ def collect_sub_modules(module: nn.Module):
+ if module is None:
+ return
+ recursive_held_layers.append(module)
+ for name, child in module.named_children():
+ collect_sub_modules(child)
+
+ recursive_held_layers = []
+ for module in held_layers:
+ collect_sub_modules(module)
+ return recursive_held_layers
+
def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
r"""
Release the unheld layers in the model
@@ -205,7 +219,7 @@ def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
if self.shard_config and self.shard_config.pipeline_stage_manager:
held_layers = self.policy.get_held_layers()
set_tensors_to_none(self.model, exclude=set(held_layers))
- return set(held_layers)
+ return set(self._get_recursive_held_layers(held_layers))
return None
def _materialize(self) -> None:
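The new helper expands each held layer into itself plus all of its descendants, so downstream checks see every submodule a stage actually keeps. A self-contained sketch of the recursion:

    import torch.nn as nn

    def get_recursive_held_layers(held_layers):
        recursive_held_layers = []

        def collect_sub_modules(module):
            if module is None:
                return
            recursive_held_layers.append(module)
            for _, child in module.named_children():
                collect_sub_modules(child)

        for module in held_layers:
            collect_sub_modules(module)
        return recursive_held_layers

    # a held block contributes itself and every nested child
    block = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
    print(len(get_recursive_held_layers([block])))    # 5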
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index f7213d8c50b4..1882bf7822cc 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -68,16 +68,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
torch.cuda.empty_cache()
+
@parameterize('test_config', [{
- 'tp_size': 1,
+ 'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
+ 'enable_fused_normalization': True,
'use_lazy_init': True
}, {
- 'tp_size': 2,
+ 'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 22f04c879879..d807ffa06296 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -1,60 +1,110 @@
-import os
-
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
-
-
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- # check forward
- # the value "past_key_values" is sharded, so we ignore
- org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
- output_transform_fn, loss_fn)
- assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], atol=1e-5)
-
- # do backward
- org_loss.backward()
- shard_loss.backward()
-
- assert torch.allclose(org_loss, shard_loss,
- atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
-
- # check grad
- col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared']
- row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias']
- check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
- check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
-
- # check weights are tied
- if hasattr(org_model, 'lm_head'):
- assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr()
- assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()
-
-
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('use_lazy_init', [False, True])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention,
- enable_jit_fused):
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_grad,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ run_forward_backward_with_hybrid_plugin,
+)
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+ org_loss, org_output, sharded_loss, sharded_output = \
+ run_forward_backward_with_hybrid_plugin(
+ org_model,
+ sharded_model,
+ sharded_optimizer,
+ data_gen_fn,
+ output_transform_fn,
+ criterion,
+ booster)
+
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
+
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
+
+ if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+ check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+
+ # unwrap model
+ t5 = org_model
+ sharded_t5 = sharded_model.unwrap()
+
+ row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
+
+ # check weights and gradients
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0)
+
+ # check weights after optimizer.step()
+ org_optimizer.step()
+ sharded_optimizer.step()
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+
+ torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [{
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': True
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 4,
+ 'pp_size': 1,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 1,
+ 'pp_size': 4,
+ 'num_microbatches': 4,
+ 'use_lazy_init': False
+}])
+@clear_cache_before_run()
+def run_t5_test(test_config):
+
+ # TODO: add plugin_config for TP+DP after supporting & debugging it
+ # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
+
+ # TODO: add test_config for flash attention & jit operator after supporting
+
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+ test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
+
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- enable_flash_attention, enable_jit_fused, use_lazy_init)
- check_state_dict(org_model, sharded_model, name=name)
- check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+ # skip 4-stage pp test for t5_encoder
+ if test_config['pp_size'] > 2 and name == 'transformers_t5_encoder_model':
+ continue
+
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -68,7 +118,7 @@ def check_t5(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_t5():
- spawn(check_t5, 2)
+ spawn(check_t5, 4)
if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py
deleted file mode 100644
index 7f3a5f2ea40b..000000000000
--- a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.t5 import T5BasePolicy
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_pipeline_model
-
-
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- # TODO: add tests for forward/backward later
- pass
-
-
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('enable_fused_normalization', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_t5.py
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
- DP_DIM, PP_DIM = 0, 1
- DP_SIZE, PP_SIZE = 2, 2
- pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
- stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
- sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
- for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-
- inputs = data_gen_fn()
- inputs = {k: v.cuda() for k, v in inputs.items()}
- input_ids = inputs['input_ids']
-
- _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
- enable_tensor_parallelism, use_lazy_init)
-
- batch_size, seq_len = input_ids.shape
- hidden_size = sharded_model.config.d_model
- num_heads = sharded_model.config.num_heads
- hidden_state_shape = (batch_size, seq_len, hidden_size)
- position_bias_shape = (batch_size, num_heads, seq_len, seq_len)
-
- num_encoder_layers = len(sharded_model.encoder.block)
- decoder = sharded_model.__dict__.get('decoder', None)
- num_decoder_layers = len(decoder.block) if decoder else 0
-
- _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE)
- stage = stage_manager.stage
- at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
- at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
- in_decoder = stage >= decoder_starting_stage
-
- if not at_first_stage:
- # change inputs if not the first stage
- hidden_states = torch.zeros(*hidden_state_shape).cuda()
- position_bias = torch.zeros(*position_bias_shape).cuda()
- encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda()
- inputs['input_ids'] = None
- inputs['hidden_states'] = hidden_states
- inputs['position_bias'] = position_bias
- inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
- if in_decoder:
- encoder_output_states = torch.zeros(*hidden_state_shape).cuda()
- inputs['encoder_outputs'] = (encoder_output_states,)
-
- sharded_model.train()
- output = sharded_model(**inputs)
- if at_last_stage:
- if name == 'transformers_t5_for_conditional_generation' and in_decoder:
- assert output.loss is not None
- else:
- if name != 'transformers_t5_encoder_model' and not in_decoder:
- output = output['encoder_outputs']
- assert output[0].shape == hidden_state_shape
- else:
- assert output['hidden_states'].shape == hidden_state_shape
- # position_bias information should be passed in T5
- assert output['position_bias'].shape == position_bias_shape
- if in_decoder:
- assert output['encoder_decoder_position_bias'].shape == position_bias_shape
-
- torch.cuda.empty_cache()
-
-
-def check_t5(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_t5_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_t5():
- spawn(check_t5, 4)
-
-
-if __name__ == "__main__":
- test_t5()
From 458ae331ad4db52243eebdda551cbc2b65a5c73f Mon Sep 17 00:00:00 2001
From: flybird1111 <1829166702@qq.com>
Date: Wed, 9 Aug 2023 14:24:45 +0800
Subject: [PATCH 53/64] [kernel] updated unittests for coloattention (#4389)
Updated ColoAttention unit tests to check outputs and gradients
---
requirements/requirements-test.txt | 3 +-
requirements/requirements.txt | 1 +
tests/test_utils/test_flash_attention.py | 142 +++++++++++++++--------
3 files changed, 94 insertions(+), 52 deletions(-)
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 9f6580c72d1b..e65271621ddd 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -13,6 +13,7 @@ torchrec==0.2.0
contexttimer
einops
triton==2.0.0.dev20221202
-git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece
+ninja
+flash_attn>=2.0
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index f6be6a624c70..65eecce2c34f 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,4 +10,5 @@ contexttimer
ninja
torch>=1.11
safetensors
+flash_attn>=2.0
einops
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index fbcc452650cf..e1c7446f40db 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -1,4 +1,4 @@
-import random
+import math
import pytest
import torch
@@ -13,118 +13,158 @@
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
DTYPE = [torch.float16, torch.bfloat16, torch.float32]
+FLASH_DTYPE = [torch.float16, torch.bfloat16]
-def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
- M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
- p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
- for z in range(Z):
- for h in range(H):
- p[:, :, M == 0] = float("-inf")
- p = torch.softmax(p.float(), dim=-1).half()
- ref_out = torch.matmul(p, v)
- return ref_out
+def attention_ref(q, k, v, attn_mask=None, causal=False):
+ """
+ attention output of the control group
+ """
+ dtype_og = q.dtype
+ seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+ d = q.shape[-1]
+ scale = 1.0 / math.sqrt(d)
+ scores = torch.einsum('bthd,bshd->bhts', q * scale, k)
+
+ if attn_mask is not None:
+ scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
+ if causal:
+ causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
+ scores.masked_fill_(causal_mask, float('-inf'))
+ attention = torch.softmax(scores, dim=-1)
+
+ output = torch.einsum('bhts,bshd->bthd', attention, v)
+ output = rearrange(output, "b s h d -> b s (h d)")
+
+ # Modify the data at the positions of the mask to 0
+ if attn_mask is not None:
+ output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1'), 0.0)
+
+ return output.to(dtype=dtype_og)
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
-@parameterize('proj_shape', [(1, 8, 4, 16)])
+@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
-def test_attention_gpt(proj_shape, dtype):
- # TODO check output value
+@parameterize('dropout', [0.0])
+def test_attention_gpt(proj_shape, dtype, dropout):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
- c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
- attn = ColoAttention(D, H, dropout=0.1)
+ q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
- x = torch.randn((B, S, D), dtype=dtype, device="cuda")
-
- qkv = c_attn(x)
- q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
-
- mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
+ mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
+ attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
assert list(y.shape) == [B, S, D]
+ out_ref = attention_ref(q, k, v, mask, causal=True)
+
+ # check gradients
dy = torch.rand_like(y)
- y.backward(dy)
+ grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
+ grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
+
+ torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
+ torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
+ torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
+ torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
-def test_attention_bert(proj_shape, dtype):
+@parameterize('dropout', [0.0])
+def test_attention_bert(proj_shape, dtype, dropout):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
- c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
- attn = ColoAttention(D, H, dropout=0.1)
+ q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
- x = torch.randn((B, S, D), dtype=dtype, device="cuda")
# attention mask of shape [B, S] with zero padding to max length S
- mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
- mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
+ mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")
- qkv = c_attn(x)
- q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
+ attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)
assert list(y.shape) == [B, S, D]
+ out_ref = attention_ref(q, k, v, mask, causal=False)
+
dy = torch.rand_like(y)
- y.backward(dy)
+ grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
+ grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
+
+ torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
+ torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
+ torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
+ torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
-def test_attention_no_mask(proj_shape, dtype):
+@parameterize('dropout', [0.0])
+def test_attention_no_mask(proj_shape, dtype, dropout):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
- c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
- attn = ColoAttention(D, H, dropout=0.1)
+ q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
- x = torch.randn((B, S, D), dtype=dtype, device="cuda")
- qkv = c_attn(x)
- q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
+ attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v)
assert list(y.shape) == [B, S, D]
+ out_ref = attention_ref(q, k, v, None, causal=False)
+
dy = torch.rand_like(y)
- y.backward(dy)
+ grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
+ grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
+
+ torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
+ torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
+ torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
+ torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
@parameterize('dtype', DTYPE)
-def test_cross_attention(proj_shape, dtype):
+@parameterize('dropout', [0.0])
+def test_cross_attention(proj_shape, dtype, dropout):
(B, S, T, H, D_HEAD) = proj_shape
D = H * D_HEAD
- q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
- kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")
+ q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+ v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
- attn = ColoAttention(D, H, dropout=0.1)
-
- src = torch.randn((B, S, D), dtype=dtype, device="cuda")
- tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")
-
- q = q_attn(tgt)
- kv = kv_attn(src)
- q = rearrange(q, 'b s (h d) -> b s h d', h=H)
- k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
+ attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
assert list(y.shape) == [B, T, D]
+ out_ref = attention_ref(q, k, v, None, causal=True)
+
dy = torch.rand_like(y)
- y.backward(dy)
+ grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
+ grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
+
+ torch.allclose(y, out_ref, atol=1e-18), f"{(y - out_ref).abs().max()}"
+ torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
+ torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
+ torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
\ No newline at end of file
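Note that the torch.allclose(...) lines above are bare expressions, so a mismatch never fails the test. A hypothetical helper (not part of the patch) that would turn them into real assertions, with tolerances loose enough for half precision:

    import torch

    def assert_close(out, out_ref):
        # hypothetical tolerance choice: fp16/bf16 kernels typically differ
        # from an eager reference by around 1e-3, fp32 by far less
        if out.dtype in (torch.float16, torch.bfloat16):
            atol, rtol = 1e-3, 1e-3
        else:
            atol, rtol = 1e-5, 1e-5
        assert torch.allclose(out, out_ref, atol=atol, rtol=rtol), \
            f"max abs diff: {(out - out_ref).abs().max().item()}"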
From c14920a075703d3bebb931a6af45c13071cfc481 Mon Sep 17 00:00:00 2001
From: flybird1111 <1829166702@qq.com>
Date: Wed, 9 Aug 2023 14:32:19 +0800
Subject: [PATCH 54/64] [shardformer] update shardformer to use flash attention
2 (#4392)
* cherry-pick flash attention 2
* [shardformer] update shardformer to use flash attention 2 and fix related issues
---
colossalai/kernel/cuda_native/__init__.py | 8 +-
.../kernel/cuda_native/flash_attention.py | 647 ------------------
colossalai/kernel/cuda_native/mha/__init__.py | 3 +
.../kernel/cuda_native/mha/flash_attn_2.py | 68 ++
.../kernel/cuda_native/mha/mem_eff_attn.py | 70 ++
colossalai/kernel/cuda_native/mha/mha.py | 107 +++
colossalai/kernel/cuda_native/mha/utils.py | 82 +++
colossalai/shardformer/modeling/blip2.py | 2 +-
colossalai/shardformer/modeling/chatglm.py | 3 +-
colossalai/shardformer/modeling/gpt2.py | 2 +-
colossalai/shardformer/modeling/llama.py | 2 +-
colossalai/shardformer/modeling/opt.py | 2 +-
colossalai/shardformer/modeling/vit.py | 2 +-
colossalai/shardformer/modeling/whisper.py | 2 +-
tests/test_utils/test_flash_attention.py | 39 +-
15 files changed, 367 insertions(+), 672 deletions(-)
delete mode 100644 colossalai/kernel/cuda_native/flash_attention.py
create mode 100644 colossalai/kernel/cuda_native/mha/__init__.py
create mode 100644 colossalai/kernel/cuda_native/mha/flash_attn_2.py
create mode 100644 colossalai/kernel/cuda_native/mha/mem_eff_attn.py
create mode 100644 colossalai/kernel/cuda_native/mha/mha.py
create mode 100644 colossalai/kernel/cuda_native/mha/utils.py
diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index 1d5a6ce495bd..e0136d86e561 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -1,5 +1,9 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
+from .mha.mha import ColoAttention
from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
-__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax']
+__all__ = [
+ 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention',
+ 'AttnMaskType'
+]
diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
deleted file mode 100644
index 91bef0908dbb..000000000000
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ /dev/null
@@ -1,647 +0,0 @@
-"""
-A general attention module using the flash attention kernels from xformers:
-https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
-"""
-
-import math
-import os
-import subprocess
-import warnings
-
-import torch
-
-try:
- from xformers.ops.fmha import memory_efficient_attention
- HAS_MEM_EFF_ATTN = True
-except ImportError:
- HAS_MEM_EFF_ATTN = False
- warnings.warn(f'please install xformers from https://github.com/facebookresearch/xformers')
-
-if HAS_MEM_EFF_ATTN:
-
- from typing import Optional
-
- from einops import rearrange
- from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
- from xformers.ops.fmha.attn_bias import (
- BlockDiagonalCausalMask,
- BlockDiagonalMask,
- LowerTriangularMask,
- LowerTriangularMaskWithTensorBias,
- )
-
- from .scaled_softmax import AttnMaskType
-
- allow_alibi = True
- for op in MemoryEfficientAttentionCutlassOp:
- allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
-
- class Unpad(torch.autograd.Function):
- """
- Adapted from
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
- """
-
- @staticmethod
- def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
- ctx.save_for_backward(indices)
- # [b, s, ...]
- assert tensor.ndim >= 3
- ctx.bsz = tensor.shape[0]
- out = rearrange(tensor, 'b s ... -> (b s) ...')
- ctx.shape = out.shape
- # [1, ntokens, ...]
- return out[indices].unsqueeze(0)
-
- @staticmethod
- def backward(ctx, grad_output):
- indices, = ctx.saved_tensors
- # [b*s, ...]
- grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
- grad[indices] = grad_output.squeeze(0)
- grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
- # [b, s, ...]
- return grad, None
-
- class Repad(torch.autograd.Function):
- """
- Adapted from
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
- """
-
- @staticmethod
- def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
- ctx.save_for_backward(indices)
- # [ntokens, ...]
- tensor = tensor.squeeze(0)
- out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
- # [b*s, ...]
- out[indices] = tensor
- # [b, s, ...]
- out = rearrange(out, '(b s) ... -> b s ...', b=batch_size)
- return out
-
- @staticmethod
- def backward(ctx, grad_output):
- indices, = ctx.saved_tensors
- # [b*s, ...]
- grad_output = rearrange(grad_output, 'b s ... -> (b s) ...')
- grad = grad_output[indices]
- # [1, ntokens, ...]
- return grad.unsqueeze(0), None, None, None
-
- class ColoAttention(torch.nn.Module):
-
- def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
- super().__init__()
- assert embed_dim % num_heads == 0, \
- f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
- if scale is not None:
- self.scale = scale
- else:
- self.scale = 1 / math.sqrt(embed_dim // num_heads)
- self.dropout = dropout
-
- @staticmethod
- def get_seq_info_from_mask(attn_mask: torch.Tensor):
- indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
- seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist()
- return indices, seqlens
-
- @staticmethod
- def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
- return Unpad.apply(tensor, indices)
-
- @staticmethod
- def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
- return Repad.apply(tensor, indices, batch_size, seq_len)
-
- def forward(self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- attn_mask_type: Optional[AttnMaskType] = None,
- bias: Optional[torch.Tensor] = None):
- batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
- attn_bias = None
- if attn_mask_type and attn_mask_type.value % 2 == 1: # bert style
- assert attn_mask is not None, \
- f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
- assert attn_mask.dim() == 2, \
- "attention mask is supposed to have shape (batch_size, seq_len), " + \
- f"but got {attn_mask.dim()} dimensions."
- if tgt_len == src_len:
- q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask)
- kv_seqlen = None
- if batch_size > 1:
- query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2)
- else:
- q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device)
- q_seqlen = torch.LongTensor([tgt_len] * batch_size, device=query.device)
- kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask)
- if batch_size > 1:
- query = rearrange(query, "b s ... -> c (b s) ...", c=1)
- key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
- if attn_mask_type == AttnMaskType.padding:
- attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
- elif attn_mask_type == AttnMaskType.paddedcausal:
- attn_bias = BlockDiagonalCausalMask.from_seqlens(q_seqlen, kv_seqlen)
- elif attn_mask_type == AttnMaskType.causal: # gpt style
- attn_bias = LowerTriangularMask()
-
- if bias is not None: # alibi / relative position embedding
- assert allow_alibi, "flash attention with bias is not supported in this system."
- assert attn_mask_type == AttnMaskType.causal, \
- "attention with bias is only supported for causal attention so far."
- attn_bias = attn_bias.add_bias(bias)
-
- out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)
-
- if attn_mask_type and attn_mask_type.value % 2 == 1 and batch_size > 1:
- out = self.repad(out, q_indices, batch_size, tgt_len)
-
- out = rearrange(out, 'b s h d -> b s (h d)')
- return out
-
-
-##########################################################################
-# the flash attention functions below that are copied
-# from the OpenAI/triton repository will be deprecated
-# You can find the repository in Triton https://github.com/openai/triton
-# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
-# Reference:
-# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
-# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
-
-
-def triton_cuda_check():
- cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
- cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
- cuda_version = cuda_version.split('release ')[1]
- cuda_version = cuda_version.split(',')[0]
- cuda_version = cuda_version.split('.')
- if len(cuda_version) == 2 and \
- (int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \
- int(cuda_version[0]) > 11:
- return True
- return False
-
-
-try:
- import triton
- import triton.language as tl
- if triton_cuda_check():
- HAS_TRITON = True
- else:
- print("triton requires cuda >= 11.4")
- HAS_TRITON = False
-except ImportError:
- print('please install triton from https://github.com/openai/triton')
- HAS_TRITON = False
-try:
- from flash_attn.flash_attention import FlashAttention
- from flash_attn.flash_attn_interface import (
- flash_attn_unpadded_func,
- flash_attn_unpadded_kvpacked_func,
- flash_attn_unpadded_qkvpacked_func,
- )
- HAS_FLASH_ATTN = True
-except ImportError:
- HAS_FLASH_ATTN = False
- print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-
-if HAS_TRITON:
- # the following functions are adapted from the OpenAI Triton tutorial
- # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
- @triton.jit
- def _fwd_kernel(
- Q,
- K,
- V,
- sm_scale,
- TMP,
- L,
- M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
- Out,
- stride_qz,
- stride_qh,
- stride_qm,
- stride_qk,
- stride_kz,
- stride_kh,
- stride_kn,
- stride_kk,
- stride_vz,
- stride_vh,
- stride_vk,
- stride_vn,
- stride_oz,
- stride_oh,
- stride_om,
- stride_on,
- Z,
- H,
- N_CTX,
- BLOCK_M: tl.constexpr,
- BLOCK_DMODEL: tl.constexpr,
- BLOCK_N: tl.constexpr,
- ):
- start_m = tl.program_id(0)
- off_hz = tl.program_id(1)
- # initialize offsets
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = tl.arange(0, BLOCK_N)
- offs_d = tl.arange(0, BLOCK_DMODEL)
- off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
- off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
- off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
- # Initialize pointers to Q, K, V
- q_ptrs = Q + off_q
- k_ptrs = K + off_k
- v_ptrs = V + off_v
- # initialize pointer to m and l
- t_ptrs = TMP + off_hz * N_CTX + offs_m
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
- l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
- acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
- # load q: it will stay in SRAM throughout
- q = tl.load(q_ptrs)
- # loop over k, v and update accumulator
- for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
- start_n = tl.multiple_of(start_n, BLOCK_N)
- # -- compute qk ----
- k = tl.load(k_ptrs + start_n * stride_kn)
- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
- qk += tl.dot(q, k, trans_b=True)
- qk *= sm_scale
- qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
- # -- compute m_ij, p, l_ij
- m_ij = tl.max(qk, 1)
- p = tl.exp(qk - m_ij[:, None])
- l_ij = tl.sum(p, 1)
- # -- update m_i and l_i
- m_i_new = tl.maximum(m_i, m_ij)
- alpha = tl.exp(m_i - m_i_new)
- beta = tl.exp(m_ij - m_i_new)
- l_i_new = alpha * l_i + beta * l_ij
- # -- update output accumulator --
- # scale p
- p_scale = beta / l_i_new
- p = p * p_scale[:, None]
- # scale acc
- acc_scale = l_i / l_i_new * alpha
- tl.store(t_ptrs, acc_scale)
- acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
- acc = acc * acc_scale[:, None]
- # update acc
- v = tl.load(v_ptrs + start_n * stride_vk)
- p = p.to(tl.float16)
- acc += tl.dot(p, v)
- # update m_i and l_i
- l_i = l_i_new
- m_i = m_i_new
- # rematerialize offsets to save registers
- start_m = tl.program_id(0)
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- # write back l and m
- l_ptrs = L + off_hz * N_CTX + offs_m
- m_ptrs = M + off_hz * N_CTX + offs_m
- tl.store(l_ptrs, l_i)
- tl.store(m_ptrs, m_i)
- # initialize pointers to output
- offs_n = tl.arange(0, BLOCK_DMODEL)
- off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
- out_ptrs = Out + off_o
- tl.store(out_ptrs, acc)
-
- @triton.jit
- def _bwd_preprocess(
- Out,
- DO,
- L,
- NewDO,
- Delta,
- BLOCK_M: tl.constexpr,
- D_HEAD: tl.constexpr,
- ):
- off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
- off_n = tl.arange(0, D_HEAD)
- # load
- o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
- do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
- denom = tl.load(L + off_m).to(tl.float32)
- # compute
- do = do / denom[:, None]
- delta = tl.sum(o * do, axis=1)
- # write-back
- tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
- tl.store(Delta + off_m, delta)
-
- @triton.jit
- def _bwd_kernel(
- Q,
- K,
- V,
- sm_scale,
- Out,
- DO,
- DQ,
- DK,
- DV,
- L,
- M,
- D,
- stride_qz,
- stride_qh,
- stride_qm,
- stride_qk,
- stride_kz,
- stride_kh,
- stride_kn,
- stride_kk,
- stride_vz,
- stride_vh,
- stride_vk,
- stride_vn,
- Z,
- H,
- N_CTX,
- num_block,
- BLOCK_M: tl.constexpr,
- BLOCK_DMODEL: tl.constexpr,
- BLOCK_N: tl.constexpr,
- ):
- off_hz = tl.program_id(0)
- off_z = off_hz // H
- off_h = off_hz % H
- # offset pointers for batch/head
- Q += off_z * stride_qz + off_h * stride_qh
- K += off_z * stride_qz + off_h * stride_qh
- V += off_z * stride_qz + off_h * stride_qh
- DO += off_z * stride_qz + off_h * stride_qh
- DQ += off_z * stride_qz + off_h * stride_qh
- DK += off_z * stride_qz + off_h * stride_qh
- DV += off_z * stride_qz + off_h * stride_qh
- for start_n in range(0, num_block):
- lo = start_n * BLOCK_M
- # initialize row/col offsets
- offs_qm = lo + tl.arange(0, BLOCK_M)
- offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_m = tl.arange(0, BLOCK_N)
- offs_k = tl.arange(0, BLOCK_DMODEL)
- # initialize pointers to value-like data
- q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
- v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- # pointer to row-wise quantities in value-like data
- D_ptrs = D + off_hz * N_CTX
- m_ptrs = M + off_hz * N_CTX
- # initialize dv amd dk
- dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
- dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
- # k and v stay in SRAM throughout
- k = tl.load(k_ptrs)
- v = tl.load(v_ptrs)
- # loop over rows
- for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
- offs_m_curr = start_m + offs_m
- # load q, k, v, do on-chip
- q = tl.load(q_ptrs)
- # recompute p = softmax(qk, dim=-1).T
- # NOTE: `do` is pre-divided by `l`; no normalization here
- qk = tl.dot(q, k, trans_b=True)
- qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
- m = tl.load(m_ptrs + offs_m_curr)
- p = tl.exp(qk * sm_scale - m[:, None])
- # compute dv
- do = tl.load(do_ptrs)
- dv += tl.dot(p.to(tl.float16), do, trans_a=True)
- # compute dp = dot(v, do)
- Di = tl.load(D_ptrs + offs_m_curr)
- dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
- dp += tl.dot(do, v, trans_b=True)
- # compute ds = p * (dp - delta[:, None])
- ds = p * dp * sm_scale
- # compute dk = dot(ds.T, q)
- dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
- # # compute dq
- dq = tl.load(dq_ptrs, eviction_policy="evict_last")
- dq += tl.dot(ds.to(tl.float16), k)
- tl.store(dq_ptrs, dq, eviction_policy="evict_last")
- # # increment pointers
- dq_ptrs += BLOCK_M * stride_qm
- q_ptrs += BLOCK_M * stride_qm
- do_ptrs += BLOCK_M * stride_qm
- # write-back
- dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
- tl.store(dv_ptrs, dv)
- tl.store(dk_ptrs, dk)
-
- class _TritonFlashAttention(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, q, k, v, sm_scale):
- BLOCK = 128
- # shape constraints
- Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
- assert Lq == Lk and Lk == Lv
- assert Lk in {16, 32, 64, 128}
- o = torch.empty_like(q)
- grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
- tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
- L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
- m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
- num_warps = 4 if Lk <= 64 else 8
-
- _fwd_kernel[grid](
- q,
- k,
- v,
- sm_scale,
- tmp,
- L,
- m,
- o,
- q.stride(0),
- q.stride(1),
- q.stride(2),
- q.stride(3),
- k.stride(0),
- k.stride(1),
- k.stride(2),
- k.stride(3),
- v.stride(0),
- v.stride(1),
- v.stride(2),
- v.stride(3),
- o.stride(0),
- o.stride(1),
- o.stride(2),
- o.stride(3),
- q.shape[0],
- q.shape[1],
- q.shape[2],
- BLOCK_M=BLOCK,
- BLOCK_N=BLOCK,
- BLOCK_DMODEL=Lk,
- num_warps=num_warps,
- num_stages=1,
- )
- ctx.save_for_backward(q, k, v, o, L, m)
- ctx.BLOCK = BLOCK
- ctx.grid = grid
- ctx.sm_scale = sm_scale
- ctx.BLOCK_DMODEL = Lk
- return o
-
- @staticmethod
- def backward(ctx, do):
- q, k, v, o, l, m = ctx.saved_tensors
- do = do.contiguous()
- dq = torch.zeros_like(q, dtype=torch.float32)
- dk = torch.empty_like(k)
- dv = torch.empty_like(v)
- do_scaled = torch.empty_like(do)
- delta = torch.empty_like(l)
- _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
- o,
- do,
- l,
- do_scaled,
- delta,
- BLOCK_M=ctx.BLOCK,
- D_HEAD=ctx.BLOCK_DMODEL,
- )
-
- # NOTE: kernel currently buggy for other values of `num_warps`
- num_warps = 8
- _bwd_kernel[(ctx.grid[1],)](
- q,
- k,
- v,
- ctx.sm_scale,
- o,
- do_scaled,
- dq,
- dk,
- dv,
- l,
- m,
- delta,
- q.stride(0),
- q.stride(1),
- q.stride(2),
- q.stride(3),
- k.stride(0),
- k.stride(1),
- k.stride(2),
- k.stride(3),
- v.stride(0),
- v.stride(1),
- v.stride(2),
- v.stride(3),
- q.shape[0],
- q.shape[1],
- q.shape[2],
- ctx.grid[0],
- BLOCK_M=ctx.BLOCK,
- BLOCK_N=ctx.BLOCK,
- BLOCK_DMODEL=ctx.BLOCK_DMODEL,
- num_warps=num_warps,
- num_stages=1,
- )
- return dq, dk, dv, None
-
- def triton_flash_attention(q, k, v, sm_scale):
- """
- Arguments:
- q: (batch, nheads, seq, headdim)
- k: (batch, nheads, seq, headdim)
- v: (batch, nheads, seq, headdim)
- sm_scale: float. The scaling of QK^T before applying softmax.
- Return:
- out: (batch, nheads, seq, headdim)
- """
- if HAS_TRITON:
- return _TritonFlashAttention.apply(q, k, v, sm_scale)
- else:
- raise RuntimeError("Triton kernel requires CUDA 11.4+!")
-
-
-if HAS_FLASH_ATTN:
-
- def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
- """
- Arguments:
- qkv: (batch * seqlen, 3, nheads, headdim)
- batch_size: int.
- seq_len: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Default to 1 / sqrt(headdim).
- dropout_p: float.
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
- Return:
- out: (total, nheads, headdim).
- """
- max_s = seq_len
- cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
- out = flash_attn_unpadded_qkvpacked_func(qkv,
- cu_seqlens,
- max_s,
- dropout_p,
- softmax_scale=sm_scale,
- causal=causal)
- return out
-
- def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
- """
- Arguments:
- q: (batch * q_seqlen, nheads, headdim)
- kv: (batch * kv_seqlen, 2, nheads, headdim)
- batch_size: int.
- seq_len: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Default to 1 / sqrt(headdim).
- dropout_p: float.
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
- Return:
- out: (total, nheads, headdim).
- """
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen,
- step=kv_seqlen,
- dtype=torch.int32,
- device=kv.device)
- out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
- sm_scale, causal)
- return out
-
- def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
- """
- Arguments:
- q: (batch * q_seqlen, nheads, headdim)
- k: (batch * kv_seqlen, nheads, headdim)
- v: (batch * kv_seqlen, nheads, headdim)
- batch_size: int.
- seq_len: int.
- dropout_p: float. Dropout probability.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Default to 1 / sqrt(headdim).
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
- Return:
- out: (total, nheads, headdim).
- """
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
- cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen,
- step=kv_seqlen,
- dtype=torch.int32,
- device=k.device)
- return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
- causal)
-
-
-##########################################################################
diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py
new file mode 100644
index 000000000000..21fddd512957
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/__init__.py
@@ -0,0 +1,3 @@
+from .mha import ColoAttention
+
+__all__ = ['ColoAttention']
diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
new file mode 100644
index 000000000000..6a8d74f70c1d
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
@@ -0,0 +1,68 @@
+import warnings
+from typing import Optional
+
+import torch
+
+
+def is_ampere_or_better_gpu():
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ properties = torch.cuda.get_device_properties(device)
+ if properties.major >= 8: # Ampere GPUs or newer
+ return True
+ return False
+
+
+# "Check Ampere GPUs or newer"
+HAS_FLASH_ATTN = False
+if is_ampere_or_better_gpu():
+ HAS_FLASH_ATTN = True
+else:
+ warnings.warn('FlashAttention only supports Ampere GPUs or newer.')
+ HAS_FLASH_ATTN = False
+try:
+ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+ HAS_FLASH_ATTN = True
+except ImportError:
+ warnings.warn('please install flash_attn from https://github.com/HazyResearch/flash-attention')
+ HAS_FLASH_ATTN = False
+
+if HAS_FLASH_ATTN:
+ from einops import rearrange
+
+ from .utils import SeqLenInfo
+
+ def flash_attention(q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ seq_len_info_q: SeqLenInfo,
+ seq_len_info_kv: SeqLenInfo,
+ bias: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.,
+ scale: float = None,
+ causal: bool = False,
+ padded: bool = False):
+ """
+ Arguments:
+ q: (batch, q_seqlen, nheads, headdim)
+ k: (batch, kv_seqlen, nheads, headdim)
+ v: (batch, kv_seqlen, nheads, headdim)
+ batch_size: int.
+ seq_len: int.
+ dropout_p: float. Dropout probability.
+ sm_scale: float. The scaling of QK^T before applying softmax.
+ Default to 1 / sqrt(headdim).
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+ Return:
+ attn_out: (batch, q_seqlen, nheads, headdim).
+ """
+ if padded:
+        if seq_len_info_kv is None:
+ seq_len_info_kv = seq_len_info_q
+
+ attn_out = flash_attn_varlen_func(q, k, v, seq_len_info_q.cu_seqlens, seq_len_info_kv.cu_seqlens,
+ seq_len_info_q.max_seqlen, seq_len_info_kv.max_seqlen, dropout_p, scale,
+ causal)
+ else:
+ attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
+ return attn_out
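
Editor's note: a minimal sketch of calling the flash_attention wrapper above directly on dense (non-padded) inputs; it assumes flash-attn >= 2.0 is installed and an fp16-capable Ampere (or newer) GPU is available. The padded path additionally needs the SeqLenInfo helper defined in mha/utils.py later in this patch.

    import torch

    from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN

    if HAS_FLASH_ATTN:
        from colossalai.kernel.cuda_native.mha.flash_attn_2 import flash_attention

        # dense path: (batch, seqlen, nheads, headdim), half precision as required by flash-attn
        q = torch.randn(2, 8, 4, 16, dtype=torch.float16, device='cuda')
        k = torch.randn(2, 8, 4, 16, dtype=torch.float16, device='cuda')
        v = torch.randn(2, 8, 4, 16, dtype=torch.float16, device='cuda')

        # seq_len_info_* are only consumed when padded=True
        out = flash_attention(q, k, v, seq_len_info_q=None, seq_len_info_kv=None, causal=True)
        print(out.shape)    # torch.Size([2, 8, 4, 16])
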
diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py
new file mode 100644
index 000000000000..e83beb8b2429
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py
@@ -0,0 +1,70 @@
+import warnings
+
+HAS_MEM_EFF_ATTN = False
+try:
+ from xformers.ops.fmha import memory_efficient_attention
+ HAS_MEM_EFF_ATTN = True
+except ImportError:
+ warnings.warn('please install xformers from https://github.com/facebookresearch/xformers')
+ HAS_MEM_EFF_ATTN = False
+
+if HAS_MEM_EFF_ATTN:
+ """
+ A general attention module using the flash attention kernels from xformers:
+ https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
+ """
+ from typing import Optional
+
+ import torch
+ from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
+ from xformers.ops.fmha.attn_bias import (
+ BlockDiagonalCausalMask,
+ BlockDiagonalMask,
+ LowerTriangularMask,
+ LowerTriangularMaskWithTensorBias,
+ )
+
+ from .utils import SeqLenInfo
+
+ allow_alibi = True
+ for op in MemoryEfficientAttentionCutlassOp:
+ allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
+
+ def mem_eff_attention(q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ seq_len_info_q: SeqLenInfo,
+ seq_len_info_kv: SeqLenInfo,
+ bias: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.,
+ scale: float = None,
+ causal: bool = False,
+ padded: bool = False):
+
+ attn_bias = None
+ if padded: # bert style
+ if not causal:
+ attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
+ else:
+ attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
+ elif causal: # gpt style
+ attn_bias = LowerTriangularMask()
+
+ if bias is not None: # alibi / relative position embedding
+ assert allow_alibi, "flash attention with bias is not supported in this system."
+ assert causal, \
+ "attention with bias is only supported for causal attention so far."
+ attn_bias = attn_bias.add_bias(bias)
+
+ if padded:
+ q = q.unsqueeze(0)
+ k = k.unsqueeze(0)
+ v = v.unsqueeze(0)
+
+ out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
+
+ # shape: (b*s, n, d)
+ if padded:
+ out = out.squeeze(0)
+
+ return out
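
Editor's note: the branch structure above maps the (padded, causal) flags onto xformers bias objects; the sketch below restates that mapping with a hypothetical helper name (select_attn_bias) and illustrative sequence lengths.

    from xformers.ops.fmha.attn_bias import (
        BlockDiagonalCausalMask,
        BlockDiagonalMask,
        LowerTriangularMask,
    )

    def select_attn_bias(padded: bool, causal: bool, q_seqlens=None, kv_seqlens=None):
        # mirrors the attn_bias selection in mem_eff_attention above
        if padded and causal:
            return BlockDiagonalCausalMask.from_seqlens(q_seqlens, kv_seqlens)
        if padded:
            return BlockDiagonalMask.from_seqlens(q_seqlens, kv_seqlens)
        if causal:
            return LowerTriangularMask()
        return None

    # e.g. two packed sequences of lengths 3 and 2, causal attention within each sequence
    bias = select_attn_bias(padded=True, causal=True, q_seqlens=[3, 2], kv_seqlens=[3, 2])
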
diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py
new file mode 100644
index 000000000000..8f449a138c51
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/mha.py
@@ -0,0 +1,107 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+from ..scaled_softmax import AttnMaskType
+from .flash_attn_2 import HAS_FLASH_ATTN
+from .mem_eff_attn import HAS_MEM_EFF_ATTN
+from .utils import Repad, SeqLenInfo, Unpad
+
+if HAS_FLASH_ATTN:
+ from .flash_attn_2 import flash_attention
+if HAS_MEM_EFF_ATTN:
+ from .mem_eff_attn import mem_eff_attention
+
+
+class ColoAttention(torch.nn.Module):
+
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
+ super().__init__()
+ assert embed_dim % num_heads == 0, \
+ f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
+ if scale is not None:
+ self.scale = scale
+ else:
+ self.scale = 1 / math.sqrt(embed_dim // num_heads)
+ self.dropout = dropout
+
+ if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
+ raise Exception("flash attention can not support!")
+
+ @staticmethod
+ def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+ return Unpad.apply(tensor, indices)
+
+ @staticmethod
+ def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
+ return Repad.apply(tensor, indices, batch_size, seq_len)
+
+ def forward(self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ attn_mask_type: Optional[AttnMaskType] = None,
+ bias: Optional[torch.Tensor] = None):
+
+ attn = None
+        if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias is None:
+ attn = flash_attention
+ else:
+ attn = mem_eff_attention
+
+ padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
+ causal = attn_mask_type is not None and attn_mask_type.value > 1
+
+ batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
+ # unpad
+ seq_len_info_q = None
+ seq_len_info_kv = None
+ if padded:
+ # bert style, unpad process
+ assert attn_mask is not None, \
+ f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
+ assert attn_mask.dim() == 2, \
+ "attention mask is supposed to have shape (batch_size, seq_len), " + \
+ f"but got {attn_mask.dim()} dimensions."
+
+ # bert style
+ if tgt_len == src_len:
+ seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
+ if batch_size > 1:
+ query, key, value = self.unpad(torch.stack([query, key, value], dim=2),
+ seq_len_info_q.indices).unbind(dim=1)
+ else:
+ query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
+ seq_len_info_kv = seq_len_info_q
+ else:
+ seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
+ seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
+ if batch_size > 1:
+ query = rearrange(query, "b s ... -> c (b s) ...", c=1)
+ key, value = self.unpad(torch.stack([query, key, value], dim=2),
+ seq_len_info_kv.indices).unbind(dim=1)
+ else:
+ query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
+
+ out = attn(query,
+ key,
+ value,
+ seq_len_info_q,
+ seq_len_info_kv,
+ dropout_p=self.dropout,
+ scale=self.scale,
+ causal=causal,
+ padded=padded)
+
+ # repad
+ if padded:
+ if batch_size > 1:
+ out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
+ out = rearrange(out, '(b s) h d -> b s h d', b=batch_size)
+
+ out = rearrange(out, 'b s h d -> b s (h d)')
+ return out
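
Editor's note: a minimal usage sketch of the ColoAttention module introduced above, assuming a CUDA device and at least one of the optional backends (flash-attn or xformers) is installed; the (batch, seq_len, num_heads, head_dim) layout follows the tests updated later in this patch.

    import torch

    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

    B, S, H, D_HEAD = 2, 8, 4, 16
    attn = ColoAttention(embed_dim=H * D_HEAD, num_heads=H, dropout=0.0)

    q = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device='cuda')
    k = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device='cuda')
    v = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device='cuda')

    # GPT-style causal attention: no explicit mask is required
    out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
    print(out.shape)    # torch.Size([2, 8, 64]); heads are merged back into the embedding dim
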
diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py
new file mode 100644
index 000000000000..e3e431fa7e99
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/utils.py
@@ -0,0 +1,82 @@
+from dataclasses import dataclass
+from typing import Iterable, Tuple
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+from colossalai.utils.cuda import get_current_device
+
+
+class Unpad(torch.autograd.Function):
+ """
+ Adapted from
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+ """
+
+ @staticmethod
+ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
+ ctx.save_for_backward(indices)
+ # [b, s, ...]
+ assert tensor.ndim >= 3
+ ctx.bsz = tensor.shape[0]
+ out = rearrange(tensor, 'b s ... -> (b s) ...')
+ ctx.shape = out.shape
+ # [ntokens, ...]
+ return out[indices]
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ # [ntokens, ...]
+ grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
+ grad[indices] = grad_output
+ grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
+ # [b, s, ...]
+ return grad, None
+
+
+class Repad(torch.autograd.Function):
+ """
+ Adapted from
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+ """
+
+ @staticmethod
+ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
+ ctx.save_for_backward(indices)
+        # tensor: [ntokens, ...]
+ out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
+ # [b*s, ...]
+ out[indices] = tensor
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ # [b*s, ...]
+ grad = grad_output[indices]
+ # [ntokens, ...]
+ return grad, None, None, None
+
+
+@dataclass
+class SeqLenInfo:
+ seqlens: Iterable[int] = None
+ indices: torch.Tensor = None
+ max_seqlen: int = None
+ cu_seqlens: torch.Tensor = None
+
+ @staticmethod
+ def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
+ if attn_mask is not None:
+ indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
+ seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
+ else:
+ batch_size, tgt_len = size[0], size[1]
+ indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
+            seqlens = torch.full((batch_size,), tgt_len, dtype=torch.int32, device=device)
+ max_seqlen = max(seqlens)
+ cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
+ return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
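
Editor's note: a small worked example of SeqLenInfo.materialize for a padded batch; the commented values follow directly from the implementation above and assume a CUDA device is the current device.

    import torch

    from colossalai.kernel.cuda_native.mha.utils import SeqLenInfo

    # two sequences with 3 and 2 valid tokens inside a (2, 4) padded batch
    attn_mask = torch.tensor([[1, 1, 1, 0],
                              [1, 1, 0, 0]], device='cuda')
    info = SeqLenInfo.materialize(attn_mask=attn_mask)

    # info.seqlens     -> [3, 2]                  valid tokens per sequence
    # info.indices     -> tensor([0, 1, 2, 4, 5]) positions of valid tokens in the flattened mask
    # info.max_seqlen  -> 3                       longest sequence in the batch
    # info.cu_seqlens  -> tensor([0, 3, 5])       cumulative lengths, as expected by flash_attn_varlen_func
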
diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
index c5c6b14ba993..69730fd3d254 100644
--- a/colossalai/shardformer/modeling/blip2.py
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -65,7 +65,7 @@ def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
- from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+ from colossalai.kernel.cuda_native import ColoAttention
def forward(
self: Blip2Attention,
diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py
index 3d453c3bd6db..a95966c3b99e 100644
--- a/colossalai/shardformer/modeling/chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm.py
@@ -19,7 +19,7 @@
def get_flash_core_attention_forward():
- from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
from .chatglm2_6b.modeling_chatglm import CoreAttention
@@ -126,7 +126,6 @@ def forward(
return forward
-
class ChatGLMPipelineForwards:
'''
This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index e02581fbaa9b..a12a9796fa8a 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -674,7 +674,7 @@ def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
- from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
def split_heads(tensor, num_heads, attn_head_size):
"""
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 9d6335503b36..2f54daac586a 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -392,7 +392,7 @@ def get_llama_flash_attention_forward():
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
- from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
def forward(
self: LlamaAttention,
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index 299dfb5562f3..bdf141816737 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -8,7 +8,7 @@ def get_opt_flash_attention_forward():
from transformers.models.opt.modeling_opt import OPTAttention
- from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
def forward(
self: OPTAttention,
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index 22c4dd998cac..eb0ea4c7502b 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -342,7 +342,7 @@ def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention
- from colossalai.kernel.cuda_native.flash_attention import ColoAttention
+ from colossalai.kernel.cuda_native import ColoAttention
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index 6bc387ac8974..0a16c6f788da 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -8,7 +8,7 @@ def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention
- from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 938f85b410e1..fbcc452650cf 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -4,11 +4,15 @@
import torch
from einops import rearrange
-from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
+from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize
-if HAS_MEM_EFF_ATTN:
- from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
+ from colossalai.kernel.cuda_native import ColoAttention
+ from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+
+DTYPE = [torch.float16, torch.bfloat16, torch.float32]
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@@ -22,10 +26,12 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
return ref_out
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
-@parameterize('proj_shape', [(1, 128, 4, 16)])
-def test_attention_gpt(proj_shape, dtype=torch.float16):
+@parameterize('proj_shape', [(1, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_gpt(proj_shape, dtype):
+ # TODO check output value
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
@@ -48,10 +54,11 @@ def test_attention_gpt(proj_shape, dtype=torch.float16):
y.backward(dy)
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
-@parameterize('proj_shape', [(1, 128, 4, 16)])
-def test_attention_bert(proj_shape, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_bert(proj_shape, dtype):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
@@ -73,10 +80,11 @@ def test_attention_bert(proj_shape, dtype=torch.float16):
y.backward(dy)
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
-@parameterize('proj_shape', [(6, 128, 4, 16)])
-def test_attention_no_mask(proj_shape, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_no_mask(proj_shape, dtype):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
@@ -94,10 +102,11 @@ def test_attention_no_mask(proj_shape, dtype=torch.float16):
y.backward(dy)
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
-@parameterize('proj_shape', [(6, 128, 256, 4, 16)])
-def test_cross_attention(proj_shape, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_cross_attention(proj_shape, dtype):
(B, S, T, H, D_HEAD) = proj_shape
D = H * D_HEAD
From ed2c2297464e603ee9770be911b483c4deb9b7d0 Mon Sep 17 00:00:00 2001
From: flybird1111 <1829166702@qq.com>
Date: Thu, 10 Aug 2023 13:59:30 +0800
Subject: [PATCH 55/64] [shardformer] test all optimizations (#4399)
[shardformer] test all optimizations
---
.../booster/plugin/hybrid_parallel_plugin.py | 11 +++-
requirements/requirements-test.txt | 2 +-
requirements/requirements.txt | 2 +-
tests/test_shardformer/test_model/_utils.py | 16 ++---
.../test_model/test_shard_gpt2.py | 59 ++++++++++++-------
5 files changed, 60 insertions(+), 30 deletions(-)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 42942aaeb89d..28a19af0ce91 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -148,7 +148,10 @@ def __init__(
precision: str = 'fp16',
zero_stage: int = 0,
cpu_offload: bool = False,
+ enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
+ enable_flash_attention: bool = False,
+ enable_jit_fused: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
@@ -171,7 +174,10 @@ def __init__(
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
+ self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
+ self.enable_flash_attention = enable_flash_attention
+ self.enable_jit_fused = enable_jit_fused
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
@@ -186,7 +192,10 @@ def __init__(
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
- enable_fused_normalization=self.enable_fused_normalization)
+ enable_all_optimization=self.enable_all_optimization,
+ enable_fused_normalization=self.enable_fused_normalization,
+ enable_flash_attention=self.enable_flash_attention,
+ enable_jit_fused=self.enable_jit_fused)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
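
Editor's note: with the new switches, flash attention and JIT-fused operators can be turned on through the plugin instead of a hand-built ShardConfig. The sketch below shows how one of the fp16 test configs added later in this patch would map onto the constructor; the Booster import path is an assumption based on the existing repository layout.

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    plugin = HybridParallelPlugin(tp_size=1,
                                  pp_size=2,
                                  num_microbatches=4,
                                  precision='fp16',
                                  initial_scale=1,
                                  enable_flash_attention=True,
                                  enable_jit_fused=True)

    # enable_all_optimization=True would instead switch on every shard-level optimization at once,
    # matching the 'enable_all_optimization' entries in the GPT-2 test configs below.
    booster = Booster(plugin=plugin)
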
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index fa797f26a4ca..2261c5be2fe8 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -18,4 +18,4 @@ requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggi
SentencePiece
datasets
ninja
-flash-attn
+flash-attn>=2.0
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 3ee1567db7fa..c94f45e91e2d 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,4 +10,4 @@ contexttimer
ninja
torch>=1.11
safetensors
-flash-attn
+flash-attn>=2.0
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 98cdc5a4b95b..cce21809d829 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -1,6 +1,5 @@
import copy
from contextlib import nullcontext
-from typing import Optional
from typing import Any, Callable, Dict, List, Optional
import torch
@@ -16,8 +15,8 @@
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.shardformer._utils import getattr_
+from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
@@ -156,10 +155,12 @@ def _criterion(outputs, inputs):
else:
data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data)
+
sharded_loss = criterion(sharded_output)
- sharded_loss.backward()
+ sharded_optimizer.backward(sharded_loss)
org_model.train()
+ data = {k: v.cuda() for k, v in data.items()}
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()
@@ -181,12 +182,12 @@ def check_output_hidden_state(org_output: Tensor,
if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
- assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
+ assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
- assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
+ assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol), \
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
@@ -213,7 +214,7 @@ def check_weight(org_model: Module,
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
- assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
+ assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
@@ -244,6 +245,7 @@ def check_grad(org_model: Module,
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
+
assert torch.allclose(
- org_grad, shard_grad, rtol=rtol, atol=atol
+ org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 1882bf7822cc..3ac8fa26d860 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -3,6 +3,7 @@
from torch import distributed as dist
import colossalai
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -38,33 +39,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'GPT2Model':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
- check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+ # check loss
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+ def unwrap(module):
+ if isinstance(module, HybridParallelModule):
+ module = module.unwrap()
+ if module.__class__.__name__ == 'GPT2Model':
+ return module
+ return module.transformer
# unwrap model
- if org_model.__class__.__name__ == 'GPT2Model':
- gpt2 = org_model
- sharded_gpt2 = sharded_model.unwrap()
- else:
- gpt2 = org_model.transformer
- sharded_gpt2 = sharded_model.unwrap().transformer
+ gpt2 = unwrap(org_model)
+ sharded_gpt2 = unwrap(sharded_model)
col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
# check grad
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
- check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
- check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+ check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+ check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 5e-3, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
- check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
+ check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
torch.cuda.empty_cache()
@@ -73,29 +90,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': True,
- 'use_lazy_init': True
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'precision': 'fp32',
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 4,
'pp_size': 1,
- 'enable_fused_normalization': True,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
- # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
-
- # TODO: add test_config for flash attention & jit operator after supporting
+ # TODO: check and debug TP+AMP
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
- test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
From 6ccecc0c6984b2fe03d3b1718a79fa170d53a430 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Thu, 10 Aug 2023 15:36:46 +0800
Subject: [PATCH 56/64] [gemini] fix tensor storage cleaning in state dict
collection (#4396)
---
colossalai/zero/gemini/gemini_optimizer.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 7d0db6b1fa23..a2085323f83e 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -1,6 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
-import gc
import math
import warnings
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
@@ -468,11 +467,6 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
shard_size)
- # Clean gathered states
- for state_shard in gathered_state_shards:
- del state_shard[0]
- gc.collect()
-
# Reshape tensors
if is_collector:
for state_name, state_tensor in collected_states.items():
From 9916a190e408296f13f40849c2458518fcd3c538 Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Fri, 11 Aug 2023 10:32:53 +0800
Subject: [PATCH 57/64] [pipeline] rewrite bert tests and fix some bugs (#4409)
* add pipeline policy and bert forward to be done
* add bertmodel pipeline forward and make tests
* add Bert_Policy and test for policy
* update formatting
* update the code
* fix bugs
* fix name conflict
* add bloom model and policy, revise the base class of policy
* revise
* revision
* add bert_for_pretraining
* add bert_for_pretraining forward and policy
* fix typos
* cancel warning
* change the immediate output to default dict
* change the default output of get_shared_params
* rewrite bert test
* fix some bugs
* del pipeline tests
* del useless print
* rewrite data repeats
---
tests/kit/model_zoo/transformers/bert.py | 3 +-
tests/test_shardformer/test_model/_utils.py | 8 +-
.../test_model/test_shard_bert.py | 129 +++++++++++-------
.../test_model/test_shard_bert_pipeline.py | 107 ---------------
4 files changed, 88 insertions(+), 159 deletions(-)
delete mode 100644 tests/test_shardformer/test_model/test_shard_bert_pipeline.py
diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py
index 9834f5425027..52158596bcf8 100644
--- a/tests/kit/model_zoo/transformers/bert.py
+++ b/tests/kit/model_zoo/transformers/bert.py
@@ -104,7 +104,8 @@ def data_gen_for_qa():
output_transform_fn = lambda x: x
# define loss function
-loss_fn_for_bert_model = lambda x: x.pooler_output.sum()
+loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
+                                                                torch.ones_like(x.last_hidden_state))
loss_fn = lambda x: x.loss
config = transformers.BertConfig(hidden_size=128,
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index cce21809d829..c9da9d32e554 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -131,6 +131,8 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
booster: Booster):
+ org_model.cuda()
+ sharded_model.cuda()
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
@@ -141,7 +143,8 @@ def _criterion(outputs, inputs):
sharded_model.train()
if booster.plugin.stage_manager is not None:
data = {
- k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
+ k: v.to('cuda').repeat(*([4] + [1] *
+ (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
data_iter = iter([data])
@@ -162,6 +165,7 @@ def _criterion(outputs, inputs):
org_model.train()
data = {k: v.cuda() for k, v in data.items()}
org_output = org_model(**data)
+
org_loss = criterion(org_output)
org_loss.backward()
@@ -226,7 +230,6 @@ def check_grad(org_model: Module,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False):
-
for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
@@ -242,7 +245,6 @@ def check_grad(org_model: Module,
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[:org_grad.shape[0], :]
-
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index afc1507e8b24..fdbcd014e1b8 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -1,65 +1,98 @@
import pytest
import torch
+from torch import distributed as dist
import colossalai
-from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.auto_policy import get_autopolicy
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_grad,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ run_forward_backward_with_hybrid_plugin,
+)
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- # unwarp model
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+ org_loss, org_output, sharded_loss, sharded_output = \
+ run_forward_backward_with_hybrid_plugin(
+ org_model,
+ sharded_model,
+ sharded_optimizer,
+ data_gen_fn,
+ output_transform_fn,
+ criterion,
+ booster)
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
+ if org_model.__class__.__name__ == 'BertModel':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+ check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+ # unwrap model
if org_model.__class__.__name__ == 'BertModel':
bert = org_model
- sharded_bert = sharded_model
+ sharded_bert = sharded_model.unwrap()
else:
bert = org_model.bert
- sharded_bert = sharded_model.bert
-
- # check forward
- org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
- output_transform_fn, loss_fn)
- assert_hf_output_close(org_output, shard_output)
-
- # do backward
- org_loss.backward()
- shard_loss.backward()
-
- assert torch.allclose(org_loss, shard_loss,
- atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
-
- # check grad
- col_layer_for_check = ['encoder.layer[0].attention.self.query', 'embeddings.word_embeddings']
- row_layer_for_check = ['encoder.layer[0].attention.output.dense']
- check_grad(bert, sharded_bert, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
- check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
-
-
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
- use_lazy_init):
+ sharded_bert = sharded_model.unwrap().bert
+
+ col_layer_for_check = ['encoder.layer[0].output.dense']
+ row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
+
+ if stage_manager is None or stage_manager.is_first_stage():
+ #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
+ #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
+ check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
+ check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+
+ # check weights after optimizer.step()
+ org_optimizer.step()
+ sharded_optimizer.step()
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
+
+ torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [{
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'use_lazy_init': True
+}, {
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_fused_normalization': False,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 4,
+ 'pp_size': 1,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': False
+}])
+def run_bert_test(test_config):
+
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
+ test_config['precision'] = 'float'
+
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- enable_flash_attention, enable_jit_fused, use_lazy_init)
- check_state_dict(org_model, sharded_model, name=name)
- check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+ clear_layout_converter()
+ Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -73,7 +106,7 @@ def check_bert(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
- spawn(check_bert, 2)
+ spawn(check_bert, 4)
if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py
deleted file mode 100644
index 3170b58a1175..000000000000
--- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.auto_policy import get_autopolicy
-from colossalai.shardformer.shard import ShardConfig
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
-
-
-def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
- stage_manager = stage_manager
- policy = get_autopolicy(model)
- policy.set_model(model)
- model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
- policy.set_shard_config(model_config)
- layers = policy.get_held_layers()
- if stage_manager.is_first_stage():
- assert len(layers) == 1 + 1
- else:
- if name == "transformers_bert":
- assert len(layers) == 1 + 1
- elif name in [
- "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification",
- "transformers_bert_for_mcq"
- ]:
- assert len(layers) == 1 + 3
- else:
- assert len(layers) == 1 + 2
-
-
-def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
- if name == 'transformers_bert_for_mcq':
- x = torch.randint(0, 1000, (2, 3, 3)).cuda()
- attention_mask = torch.ones_like(x).cuda()
- if stage_manager.stage == 0:
- output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
- assert output['hidden_states'].shape == (6, 3, 128)
- else:
- hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
- output = sharded_model(input_ids=x,
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- stage_manager=stage_manager)
- assert output[0].shape == (2, 3)
- else:
- x = torch.randint(0, 1000, (2, 3)).cuda()
- # one batch, 2 single sentences, each sentence has 3 tokens
- hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
- if stage_manager.stage == 0:
- attention_mask = torch.ones_like(x).cuda()
- output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
- assert output['hidden_states'].shape == (2, 3, 128)
- else:
- attention_mask = torch.ones((2, 3)).cuda()
- output = sharded_model(hidden_states=hidden_states,
- attention_mask=attention_mask,
- stage_manager=stage_manager)
- assert output[0].shape[0] == 2
-
-
-@parameterize('enable_fused_normalization', [False])
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_bert
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
- PP_DIM = 0
- PP_SIZE = 2
- pg_mesh = ProcessGroupMesh(PP_SIZE)
- stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
- sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
- for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
- enable_tensor_parallelism, use_lazy_init)
- check_bert_model_policy(name, org_model, stage_manager)
- check_bert_model_pipeline_forward(name, sharded_model, stage_manager)
-
- torch.cuda.empty_cache()
-
-
-def check_bert(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_bert_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_bert():
- spawn(check_bert, 2)
-
-
-if __name__ == "__main__":
- test_bert()
From fcbf80f8ea1df359e5127dc6f2ea46a0833579b6 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 11 Aug 2023 11:44:23 +0800
Subject: [PATCH 58/64] [shardformer]fix, test gpt2 for AMP+TP (#4403)
* [shardformer] gpt2 tests fix
[shardformer] test all optimizations (#4399)
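As a quick illustration of why the _utils.py hunk below switches from torch.zeros([*shape]) to torch.zeros_like, here is a minimal CPU-only sketch (hypothetical tensors, default dtype settings): under fp16/AMP the gathered buffers have to match the shard's dtype, and torch.zeros alone does not guarantee that.

    import torch

    shard = torch.ones(4, dtype=torch.float16)    # stands in for a fp16 TP shard

    # torch.zeros([*shard.shape]) uses the default dtype (float32), so the
    # gathered buffers would not match an fp16 shard produced under AMP
    assert torch.zeros([*shard.shape]).dtype == torch.float32

    # torch.zeros_like keeps dtype and device, which is what the patched
    # check_weight / check_grad helpers rely on
    assert torch.zeros_like(shard).dtype == torch.float16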
---
tests/test_shardformer/test_model/_utils.py | 8 +++-----
tests/test_shardformer/test_model/test_shard_gpt2.py | 8 +++-----
2 files changed, 6 insertions(+), 10 deletions(-)
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index c9da9d32e554..c51df07f6c11 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -210,7 +210,7 @@ def check_weight(org_model: Module,
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight_list = [
- torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+ torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group))
]
dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
sharded_weight = torch.cat(sharded_weight_list, dim=dim)
@@ -219,7 +219,7 @@ def check_weight(org_model: Module,
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
- f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
+ f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
def check_grad(org_model: Module,
@@ -236,9 +236,7 @@ def check_grad(org_model: Module,
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
- shard_grad_list = [
- torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
- ]
+ shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 3ac8fa26d860..274cfaa39ad1 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -23,7 +23,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
-
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
@@ -47,7 +46,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if org_model.__class__.__name__ == 'GPT2Model':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
- # check loss
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
def unwrap(module):
@@ -92,13 +90,14 @@ def unwrap(module):
'num_microbatches': 4,
'enable_all_optimization': True,
'use_lazy_init': True,
- 'precision': 'fp32',
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': True,
- 'use_lazy_init': False,
+ 'use_lazy_init': True,
'precision': 'fp16',
'initial_scale': 1,
}, {
@@ -112,7 +111,6 @@ def unwrap(module):
def run_gpt2_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
- # TODO: check and debug TP+AMP
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
From d86ddd9b2910ef0e9a093039d70c3789d3af3517 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Fri, 11 Aug 2023 15:09:24 +0800
Subject: [PATCH 59/64] [hotfix] fix unsafe async comm in zero (#4404)
* improve stability of zero
* fix wrong index
* add record stream (see the sketch below)
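A minimal standalone sketch (assuming a CUDA device and a dedicated communication stream; not the optimizer's actual code path) of the record_stream/wait_stream pattern that low_level_optim.py adopts below to make the overlapped reduction safe:

    import torch

    if torch.cuda.is_available():
        comm_stream = torch.cuda.Stream()
        flat_grads = torch.randn(1024, device='cuda')    # stands in for the flattened gradient bucket

        # mark flat_grads as used on comm_stream so the caching allocator does not
        # hand its memory back to the default stream while the reduction is in flight
        flat_grads.record_stream(comm_stream)

        # block comm_stream until the default-stream work that produced flat_grads
        # (i.e. the backward pass) has finished
        comm_stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(comm_stream):
            pass    # the real optimizer launches the gradient reduction here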
---
.../low_level/bookkeeping/bucket_store.py | 55 ++++++++++++-------
colossalai/zero/low_level/low_level_optim.py | 9 +++
.../test_zero/test_low_level/test_zero1_2.py | 2 +-
3 files changed, 46 insertions(+), 20 deletions(-)
diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index 98f1b78d0049..0ab10e25d407 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -13,15 +13,20 @@ class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
- # init and reset
+ # init
self.current_group_id = 0
+ self._num_elements_in_bucket = 0
        # mapping gradient slices and parameters
self.grad_to_param_mapping = dict()
+ self._grad_in_bucket = dict()
self._param_list = []
self._padding_size = []
+ for rank in range(self._world_size):
+ self._grad_in_bucket[rank] = []
- self.reset()
+ # offset_list records number of tensors in the bucket before each reduction
+ self.offset_list = [0]
def num_elements_in_bucket(self) -> int:
"""Return the total number of elements in bucket
@@ -32,6 +37,12 @@ def num_elements_in_bucket(self) -> int:
return self._num_elements_in_bucket
+ def reset_num_elements_in_bucket(self):
+ """Set the number of elements in bucket to zero.
+ """
+
+ self._num_elements_in_bucket = 0
+
def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
"""Add a param to bucket and record the padding size of a param for gradient padding
@@ -46,28 +57,32 @@ def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
self._num_elements_in_bucket += (param.numel() + padding_size)
self.current_group_id = group_id
+ # number of tensors in current bucket
+ self.offset_list[-1] += 1
+
def build_grad_in_bucket(self):
"""Orgnize parameters' gradient(padding and split), follows the paramters' splitting method
Data structure of self._grad_in_bucket:
{
rank0: [grad0_rank0, grad1_rank0, ...]
- rank1: [grad1_rank1, grad1_rank1, ...]
+ rank1: [grad0_rank1, grad1_rank1, ...]
}
"""
-
for param, padding_size in zip(self._param_list, self._padding_size):
- with torch.no_grad():
- grad = param.grad.detach().flatten()
- if padding_size > 0:
- grad = torch.nn.functional.pad(grad, [0, padding_size])
- grad_list = grad.split(grad.numel() // self._world_size)
- for rank in range(self._world_size):
- grad_current_rank = grad_list[rank].detach()
- self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
- self._grad_in_bucket[rank].append(grad_current_rank)
+ grad = param.grad.clone().detach().flatten()
+ if padding_size > 0:
+ with torch.no_grad():
+ grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
+ grad_list = grad.split(grad.numel() // self._world_size)
+ for rank in range(self._world_size):
+ grad_current_rank = grad_list[rank].clone().detach()
+ self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
+ self._grad_in_bucket[rank].append(grad_current_rank)
param.grad = None
+ self.offset_list.append(0)
+
def get_grad(self) -> Dict:
"""Return the dictionary of gradients slices, of which the keys are ranks
@@ -104,10 +119,12 @@ def get_param_id_of_grad(self, grad: Tensor) -> int:
return self.grad_to_param_mapping[id(grad)]
def reset(self):
- self.grad_to_param_mapping = dict()
- self._num_elements_in_bucket = 0
- self._param_list = []
- self._padding_size = []
- self._grad_in_bucket = dict()
+        """Reset the bucket storage after reduction, releasing only the tensors that have been reduced
+ """
+ cur_offset = self.offset_list.pop(0)
+ self._param_list = self._param_list[cur_offset:]
+ self._padding_size = self._padding_size[cur_offset:]
+ for _ in range(cur_offset):
+ del self.grad_to_param_mapping[next(iter(self.grad_to_param_mapping))]
for rank in range(self._world_size):
- self._grad_in_bucket[rank] = []
+ self._grad_in_bucket[rank] = self._grad_in_bucket[rank][cur_offset:]
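To make the data-structure comment in build_grad_in_bucket above concrete, here is a tiny self-contained illustration (hypothetical 2-rank group and a 5-element gradient) of the pad-and-split step that fills one slice of _grad_in_bucket per rank:

    import torch

    world_size = 2
    grad = torch.arange(5, dtype=torch.float32)                      # pretend this is param.grad.flatten()
    padding = (world_size - grad.numel() % world_size) % world_size  # -> 1
    grad = torch.nn.functional.pad(grad, [0, padding])
    per_rank = grad.split(grad.numel() // world_size)                # one equal slice per rank
    assert [chunk.numel() for chunk in per_rank] == [3, 3]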
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 2b3f50ed4fd4..64d6a5395120 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -242,10 +242,19 @@ def _attach_reduction_hook(self):
def _run_reduction(self):
if self._bucket_store.num_elements_in_bucket() > 0:
self._bucket_store.build_grad_in_bucket()
+
flat_grads = self._bucket_store.get_flatten_grad()
flat_grads /= self._world_size
+
+ # ready to add other tensors to bucket
+ self._bucket_store.reset_num_elements_in_bucket()
+
if self._overlap_communication:
stream = self._comm_stream
+            # prevent flat_grads' memory from being reused by the default stream while the reduction is in flight
+ flat_grads.record_stream(stream)
+            # wait for the ops already queued on the default stream to finish
+ stream.wait_stream(torch.cuda.current_stream())
else:
stream = torch.cuda.current_stream()
diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py
index 5a0609bff192..9c4474aff5c3 100644
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -137,7 +137,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=True,
initial_scale=1,
- reduce_bucket_size=262144)
+ reduce_bucket_size=1024 * 1024)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
From 1e518ae7db23b0b00f64b01c6f461a3779a57d04 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 11 Aug 2023 15:43:23 +0800
Subject: [PATCH 60/64] [shardformer] rewrite tests for
opt/bloom/llama/vit/chatglm (#4395)
* rewrite opt tests
* rewrite llama tests
* rewrite bloom & vit tests
* rewrite chatglm tests
* fix LinearCol for classifiers
* add a guard for other tp layers, fix lazy init in util (see the sketch below)
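A hypothetical standalone rendering (illustrative names, not the library API) of the divisibility guard that the linear.py and qkv_fused_linear.py hunks below add to from_native_module:

    import torch.nn as nn

    def maybe_shard_linear_col(module: nn.Linear, tp_size: int) -> nn.Linear:
        out_features = module.out_features
        if out_features < tp_size:
            # too small to split across the tensor-parallel group: keep the module unsharded
            return module
        if out_features % tp_size != 0:
            raise ValueError(
                f"out_features={out_features} is not an integer multiple of tensor parallel size {tp_size}")
        # otherwise build and return the column-sharded replacement (omitted in this sketch)
        return module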
---
colossalai/shardformer/layer/linear.py | 16 +
.../shardformer/layer/qkv_fused_linear.py | 16 +
colossalai/shardformer/modeling/opt.py | 497 +++++++++++++-
.../shardformer/policies/auto_policy.py | 6 +
colossalai/shardformer/policies/opt.py | 618 +-----------------
tests/kit/model_zoo/transformers/bloom.py | 8 +-
tests/kit/model_zoo/transformers/chatglm.py | 19 +-
tests/kit/model_zoo/transformers/vit.py | 6 +-
tests/test_shardformer/test_model/_utils.py | 35 +-
.../test_model/test_shard_bloom.py | 118 ++--
.../test_model/test_shard_bloom_pipeline.py | 90 ---
.../test_model/test_shard_chatglm.py | 179 ++---
.../test_model/test_shard_chatglm_pipeline.py | 86 ---
.../test_model/test_shard_llama.py | 144 ++--
.../test_model/test_shard_llama_pipeline.py | 89 ---
.../test_model/test_shard_opt.py | 145 ++--
.../test_model/test_shard_opt_pipeline.py | 70 --
.../test_model/test_shard_vit.py | 137 +++-
.../test_model/test_shard_vit_pipeline.py | 74 ---
19 files changed, 1072 insertions(+), 1281 deletions(-)
delete mode 100644 tests/test_shardformer/test_model/test_shard_bloom_pipeline.py
delete mode 100644 tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py
delete mode 100644 tests/test_shardformer/test_model/test_shard_llama_pipeline.py
delete mode 100644 tests/test_shardformer/test_model/test_shard_opt_pipeline.py
delete mode 100644 tests/test_shardformer/test_model/test_shard_vit_pipeline.py
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index bb36854bd772..d59b68ce4480 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -143,6 +143,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
+ tp_size = dist.get_world_size(process_group)
+ if out_features < tp_size:
+ return module
+
+ if out_features % tp_size != 0:
+ raise ValueError(
+                f"The size of out_features:{out_features} is not an integer multiple of the tensor parallel size: {tp_size}!")
+
linear_1d = Linear1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
@@ -293,6 +301,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
+ tp_size = dist.get_world_size(process_group)
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+                f"The size of in_features:{in_features} is not an integer multiple of the tensor parallel size: {tp_size}!")
+
linear_1d = Linear1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 42417f8bcc43..df942d43ee2d 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -265,6 +265,14 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
+ tp_size = dist.get_world_size(process_group)
+ if out_features < tp_size:
+ return module
+
+ if out_features % tp_size != 0:
+ raise ValueError(
+                f"The size of out_features:{out_features} is not an integer multiple of the tensor parallel size: {tp_size}!")
+
linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
@@ -420,6 +428,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
+ tp_size = dist.get_world_size(process_group)
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+                f"The size of in_features:{in_features} is not an integer multiple of the tensor parallel size: {tp_size}!")
+
linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index bdf141816737..9afdfff4d71d 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -1,7 +1,500 @@
-from typing import Optional, Tuple
+import random
+from typing import List, Optional, Tuple, Union
import torch
-from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.models.opt.modeling_opt import (
+ OPTForCausalLM,
+ OPTForQuestionAnswering,
+ OPTForSequenceClassification,
+ OPTModel,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+class OPTPipelineForwards:
+ '''
+ This class serves as a micro library for forward function substitution of OPT models
+ under pipeline setting.
+ '''
+
+ @staticmethod
+ def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ from transformers.models.opt.modeling_opt import _make_causal_mask
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ _dtype,
+ device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype,
+ tgt_len=input_shape[-1]).to(device)
+ combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
+ combined_attention_mask)
+
+ return combined_attention_mask
+
+ @staticmethod
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+ @staticmethod
+ def opt_model_forward(
+ self: OPTModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ '''
+ This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
+ '''
+
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+ from transformers.utils import logging
+ logger = logging.get_logger(__name__)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ decoder = self.decoder
+ if stage_manager.is_first_stage():
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ batch_size, seq_length = input_shape
+
+ if inputs_embeds is None:
+ inputs_embeds = decoder.embed_tokens(input_ids)
+
+ if decoder.project_in is not None:
+ inputs_embeds = decoder.project_in(inputs_embeds)
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ _dtype = inputs_embeds.dtype
+
+ else:
+ if hidden_states is None:
+                raise ValueError("hidden_states shouldn't be None for intermediate stages.")
+ input_shape = hidden_states.size()[:-1]
+ batch_size, seq_length = input_shape[0], input_shape[1]
+ device = hidden_states.device
+ _dtype = hidden_states.dtype
+
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values_length + seq_length
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
+ elif attention_mask.shape[1] != mask_seq_length:
+ raise ValueError(
+ f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+ f"{mask_seq_length} (sum of the lengths of current and past inputs)")
+
+ causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype,
+ device, past_key_values_length)
+
+ if stage_manager.is_first_stage():
+ pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
+ hidden_states = inputs_embeds + pos_embeds
+
+ if decoder.gradient_checkpointing and decoder.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+        # TODO: recording of kv-value tensors is left as () or None for now; this feature may be added in the future.
+ if past_key_values:
+ logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+ past_key_values = None
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != (len(decoder.layers)):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}.")
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+
+ torch.cuda.set_device(device)
+
+ for idx in range(start_idx, end_idx):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ decoder_layer = decoder.layers[idx]
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ dropout_probability = random.uniform(0, 1)
+ if decoder.training and (dropout_probability < decoder.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if decoder.gradient_checkpointing and decoder.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ causal_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if stage_manager.is_last_stage():
+ if decoder.final_layer_norm is not None:
+ hidden_states = decoder.final_layer_norm(hidden_states)
+ if decoder.project_out is not None:
+ hidden_states = decoder.project_out(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if stage_manager.is_last_stage():
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+ else:
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def opt_for_causal_lm_forward(
+ self: OPTForCausalLM,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward.
+ Please refer to original code of transformers for more details.
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = OPTPipelineForwards.opt_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ )
+ if stage_manager.is_last_stage():
+ logits = self.lm_head(outputs[0]).contiguous()
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def opt_for_sequence_classification_forward(
+ self: OPTForSequenceClassification,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward.
+ Please refer to original code of transformers for more details.
+ """
+
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index)
+
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0]
+
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_lengths = -1
+ logger.warning(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`")
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+ else:
+ hidden_states = transformer_outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def opt_for_question_answering_forward(
+ self: OPTForQuestionAnswering,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward.
+ Please refer to original code of transformers for more details.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index)
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+
+ logits = self.qa_outputs(hidden_states)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + transformer_outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+ else:
+ hidden_states = transformer_outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
def get_opt_flash_attention_forward():
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 2a041af19be8..eec339c02872 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -122,6 +122,12 @@ class PolicyLocation:
PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
"transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),
+
+ # ChatGLM
+ "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
+ PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"),
+ "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
+ PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
}
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 88ecd8565091..ba6036bd0658 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -1,32 +1,14 @@
-import logging
-import random
from functools import partial
-from types import MethodType
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List
-import torch
import torch.nn as nn
from torch import Tensor, nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutputWithPast,
-)
-from transformers.models.opt.modeling_opt import (
- OPTForCausalLM,
- OPTForQuestionAnswering,
- OPTForSequenceClassification,
- OPTModel,
-)
-
-from colossalai.pipeline.stage_manager import PipelineStageManager
+
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
-from .._utils import getattr_, setattr_
+from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
-from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
+from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -228,6 +210,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
num_stages = self.pipeline_stage_manager.num_stages
if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
+ return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
@@ -295,594 +278,3 @@ def get_held_layers(self) -> List[nn.Module]:
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"no shared params in OPTForSequenceClassification"
return []
-
-
-class OPTPipelineForwards:
- '''
- This class serves as a micro library for forward function substitution of OPT models
- under pipeline setting.
- '''
-
- @staticmethod
- def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- from transformers.models.opt.modeling_opt import _make_causal_mask
- combined_attention_mask = None
- if input_shape[-1] > 1:
- combined_attention_mask = _make_causal_mask(
- input_shape,
- _dtype,
- device,
- past_key_values_length=past_key_values_length,
- )
-
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype,
- tgt_len=input_shape[-1]).to(device)
- combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
- combined_attention_mask)
-
- return combined_attention_mask
-
- @staticmethod
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- """
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
-
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
- inverted_mask = 1.0 - expanded_mask
-
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
- @staticmethod
- def opt_model_forward(
- self: OPTModel,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- stage_manager: Optional[PipelineStageManager] = None,
- hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- '''
- This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
- '''
-
- from transformers.modeling_outputs import BaseModelOutputWithPast
- from transformers.utils import logging
- logger = logging.get_logger(__name__)
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (output_hidden_states
- if output_hidden_states is not None else self.config.output_hidden_states)
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- decoder = self.decoder
- if stage_manager.is_first_stage():
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
- elif input_ids is not None:
- input_shape = input_ids.size()
- input_ids = input_ids.view(-1, input_shape[-1])
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- else:
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
- batch_size, seq_length = input_shape
-
- if inputs_embeds is None:
- inputs_embeds = decoder.embed_tokens(input_ids)
-
- if decoder.project_in is not None:
- inputs_embeds = decoder.project_in(inputs_embeds)
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- _dtype = inputs_embeds.dtype
-
- else:
- if hidden_states is None:
- raise ValueError("hidden_states shouln't be None for intermediate stages.")
- input_shape = hidden_states.size()[:-1]
- batch_size, seq_length = input_shape[0], input_shape[1]
- device = hidden_states.device
- _dtype = hidden_states.dtype
-
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
- # required mask seq length can be calculated via length of past
- mask_seq_length = past_key_values_length + seq_length
- # embed positions
- if attention_mask is None:
- attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
- elif attention_mask.shape[1] != mask_seq_length:
- raise ValueError(
- f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
- f"{mask_seq_length} (sum of the lengths of current and past inputs)")
-
- causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype,
- device, past_key_values_length)
-
- if stage_manager.is_first_stage():
- pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
- hidden_states = inputs_embeds + pos_embeds
-
- if decoder.gradient_checkpointing and decoder.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
- use_cache = False
-
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
- if past_key_values:
- logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
- past_key_values = None
- if output_attentions:
- logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
- output_attentions = False
- if output_hidden_states:
- logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
- output_hidden_states = False
- if use_cache:
- logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
- use_cache = False
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = () if use_cache else None
-
- # check if head_mask has a correct number of layers specified if desired
- for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
- if attn_mask is not None:
- if attn_mask.size()[0] != (len(decoder.layers)):
- raise ValueError(
- f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for"
- f" {head_mask.size()[0]}.")
-
- start_idx, end_idx = stage_index[0], stage_index[1]
-
- torch.cuda.set_device(device)
-
- for idx in range(start_idx, end_idx):
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
- decoder_layer = decoder.layers[idx]
-
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- dropout_probability = random.uniform(0, 1)
- if decoder.training and (dropout_probability < decoder.layerdrop):
- continue
-
- past_key_value = past_key_values[idx] if past_key_values is not None else None
-
- if decoder.gradient_checkpointing and decoder.training:
-
- def create_custom_forward(module):
-
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, output_attentions, None)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(decoder_layer),
- hidden_states,
- causal_attention_mask,
- head_mask[idx] if head_mask is not None else None,
- None,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_attention_mask,
- layer_head_mask=(head_mask[idx] if head_mask is not None else None),
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- if stage_manager.is_last_stage():
- if decoder.final_layer_norm is not None:
- hidden_states = decoder.final_layer_norm(hidden_states)
- if decoder.project_out is not None:
- hidden_states = decoder.project_out(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- next_cache = next_decoder_cache if use_cache else None
-
- if stage_manager.is_last_stage():
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
- else:
- return {'hidden_states': hidden_states}
-
- @staticmethod
- def opt_for_causal_lm_forward(
- self: OPTForCausalLM,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- stage_manager: Optional[PipelineStageManager] = None,
- hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
- provide it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
- head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
- Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
-
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
- shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
- tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
- cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
- that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
- This is useful if you want more control over how to convert `input_ids` indices into associated vectors
- than the model's internal embedding lookup matrix.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, OPTForCausalLM
-
- >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
- >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-
- >>> prompt = "Hey, are you consciours? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
- ```"""
- from transformers.modeling_outputs import CausalLMOutputWithPast
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (output_hidden_states
- if output_hidden_states is not None else self.config.output_hidden_states)
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = OPTPipelineForwards.opt_model_forward(
- self.model,
- input_ids=input_ids,
- attention_mask=attention_mask,
- head_mask=head_mask,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- stage_manager=stage_manager,
- hidden_states=hidden_states,
- stage_index=stage_index,
- )
- if stage_manager.is_last_stage():
- logits = self.lm_head(outputs[0]).contiguous()
- loss = None
- if labels is not None:
- # move labels to correct device to enable model parallelism
- labels = labels.to(logits.device)
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
- else:
- hidden_states = outputs.get('hidden_states')
- return {'hidden_states': hidden_states}
-
- @staticmethod
- def opt_for_sequence_classification_forward(
- self: OPTForSequenceClassification,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- stage_manager: Optional[PipelineStageManager] = None,
- hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- from transformers.modeling_outputs import SequenceClassifierOutputWithPast
- from transformers.utils import logging
- logger = logging.get_logger(__name__)
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
- input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- stage_manager=stage_manager,
- hidden_states=hidden_states,
- stage_index=stage_index)
-
- if stage_manager.is_last_stage():
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
-
- batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0]
-
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
- else:
- sequence_lengths = -1
- logger.warning(
- f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
- "unexpected if using padding tokens in conjunction with `inputs_embeds.`")
-
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
- loss = None
- if labels is not None:
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
-
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(pooled_logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(pooled_logits, labels)
-
- if not return_dict:
- output = (pooled_logits,) + transformer_outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
- else:
- hidden_states = transformer_outputs.get('hidden_states')
- return {'hidden_states': hidden_states}
-
- @staticmethod
- def opt_for_question_answering_forward(
- self: OPTForQuestionAnswering,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- start_positions: Optional[torch.LongTensor] = None,
- end_positions: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- stage_manager: Optional[PipelineStageManager] = None,
- hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None,
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
- r"""
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
- >>> import torch
-
- >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
- >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-
- >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
- >>> # so the head will be randomly initialized, hence the predictions will be random
- >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
-
- >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
-
- >>> inputs = tokenizer(question, text, return_tensors="pt")
- >>> with torch.no_grad():
- ... outputs = model(**inputs)
-
- >>> answer_start_index = outputs.start_logits.argmax()
- >>> answer_end_index = outputs.end_logits.argmax()
-
- >>> answer_offset = len(tokenizer(question)[0])
-
- >>> predict_answer_tokens = inputs.input_ids[
- ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
- ... ]
- >>> predicted = tokenizer.decode(predict_answer_tokens)
- >>> predicted
- ' a nice puppet'
- ```"""
- from transformers.modeling_outputs import QuestionAnsweringModelOutput
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
- input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- stage_manager=stage_manager,
- hidden_states=hidden_states,
- stage_index=stage_index)
- if stage_manager.is_last_stage():
- hidden_states = transformer_outputs[0]
-
- logits = self.qa_outputs(hidden_states)
- start_logits, end_logits = logits.split(1, dim=-1)
- start_logits = start_logits.squeeze(-1).contiguous()
- end_logits = end_logits.squeeze(-1).contiguous()
-
- total_loss = None
- if start_positions is not None and end_positions is not None:
- # If we are on multi-GPU, split add a dimension
- if len(start_positions.size()) > 1:
- start_positions = start_positions.squeeze(-1)
- if len(end_positions.size()) > 1:
- end_positions = end_positions.squeeze(-1)
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
- ignored_index = start_logits.size(1)
- start_positions = start_positions.clamp(0, ignored_index)
- end_positions = end_positions.clamp(0, ignored_index)
-
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
- start_loss = loss_fct(start_logits, start_positions)
- end_loss = loss_fct(end_logits, end_positions)
- total_loss = (start_loss + end_loss) / 2
-
- if not return_dict:
- output = (start_logits, end_logits) + transformer_outputs[2:]
- return ((total_loss,) + output) if total_loss is not None else output
-
- return QuestionAnsweringModelOutput(
- loss=total_loss,
- start_logits=start_logits,
- end_logits=end_logits,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
- else:
- hidden_states = transformer_outputs.get('hidden_states')
- return {'hidden_states': hidden_states}
diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py
index 177edbef8935..2d9c882089cb 100644
--- a/tests/kit/model_zoo/transformers/bloom.py
+++ b/tests/kit/model_zoo/transformers/bloom.py
@@ -53,7 +53,8 @@ def data_gen_for_question_answering():
# inputs = tokenizer(question, text, return_tensors="pt")
input_ids = torch.tensor(
- [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
+ [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
+ dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
start_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([10], dtype=torch.int64)
@@ -73,12 +74,13 @@ def data_gen_for_question_answering():
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss
-config = transformers.BloomConfig(n_layer=1,
+config = transformers.BloomConfig(n_layer=2,
n_head=4,
vocab_size=250880,
hidden_dropout=0,
attention_dropout=0,
- hidden_size=64)
+ hidden_size=64,
+ pad_token_id=50256)
# register the following models
model_zoo.register(name='transformers_bloom',
diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm.py
index 90bb70bc7f79..c6473ee2a025 100644
--- a/tests/kit/model_zoo/transformers/chatglm.py
+++ b/tests/kit/model_zoo/transformers/chatglm.py
@@ -17,14 +17,24 @@ def data_gen():
return dict(input_ids=input_ids, attention_mask=attention_mask)
+def data_gen_for_conditional_generation():
+ # data gen for conditional generation
+ # `labels` is a copy of `input_ids`, used to compute the language-modeling loss
+ data = data_gen()
+ labels = data['input_ids'].clone()
+ data['labels'] = labels
+ return data
+
+
# define output transform function
output_transform_fn = lambda x: x
# define loss function
-loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum()
-loss_fn = lambda x: x.logits.sum()
+loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
+ torch.ones_like(x.last_hidden_state))
+loss_fn = lambda x: x.loss
-config = ChatGLMConfig(num_layers=1,
+config = ChatGLMConfig(num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
num_attention_heads=8,
@@ -33,7 +43,6 @@ def data_gen():
use_cache=True,
torch_dtype=torch.float32)
-
model_zoo.register(name='transformers_chatglm',
model_fn=lambda: ChatGLMModel(config, empty_init=False),
data_gen_fn=data_gen,
@@ -43,7 +52,7 @@ def data_gen():
model_zoo.register(name="transformers_chatglm_for_conditional_generation",
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
- data_gen_fn=data_gen,
+ data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py
index 93a8d6c615d7..a84b8d31c284 100644
--- a/tests/kit/model_zoo/transformers/vit.py
+++ b/tests/kit/model_zoo/transformers/vit.py
@@ -7,11 +7,7 @@
# Register single-sentence VIT
# ===============================
-config = transformers.ViTConfig(
- num_hidden_layers=4,
- # hidden_size=128,
- # intermediate_size=256,
- num_attention_heads=4)
+config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
# define data gen function
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index c51df07f6c11..921af2a8b1d0 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -104,27 +104,22 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
if 'use_lazy_init' in test_config:
use_lazy_init = test_config.pop('use_lazy_init')
- if use_lazy_init:
- ctx = LazyInitContext()
- else:
- ctx = nullcontext()
-
- plugin = HybridParallelPlugin(**test_config)
- booster = Booster(plugin=plugin)
-
+ ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
- org_model = model_fn().cuda()
+ org_model = model_fn()
sharded_model = copy.deepcopy(org_model)
-
if use_lazy_init:
- org_model = ctx.materialize(org_model)
+ ctx.materialize(org_model)
+ org_model = org_model.cuda()
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
- sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
+ plugin = HybridParallelPlugin(**test_config)
+ booster = Booster(plugin=plugin)
+ sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
@@ -142,11 +137,12 @@ def _criterion(outputs, inputs):
data = data_gen_fn()
sharded_model.train()
if booster.plugin.stage_manager is not None:
- data = {
- k: v.to('cuda').repeat(*([4] + [1] *
- (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
- for k, v in data.items()
- }
+ for k, v in data.items():
+ if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+ new_shape = [1] * v.dim()
+ new_shape[0] = 4
+ data[k] = v.to('cuda').repeat(*new_shape)
+
data_iter = iter([data])
sharded_output = booster.execute_pipeline(data_iter,
sharded_model,
@@ -176,7 +172,8 @@ def check_output_hidden_state(org_output: Tensor,
sharded_output: Tensor,
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
- rtol: float = 1e-3):
+ rtol: float = 1e-3,
+ dim: int = 0):
org_hidden_state = org_output.last_hidden_state
@@ -184,7 +181,7 @@ def check_output_hidden_state(org_output: Tensor,
sharded_hidden_state = sharded_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
- sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
+ sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index e11bcf92ea3c..d5a4ce083e2b 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -3,57 +3,101 @@
import colossalai
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_grad,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ run_forward_backward_with_hybrid_plugin,
+)
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- # check forward
- org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
- output_transform_fn, loss_fn)
- assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
+ org_loss, org_output, sharded_loss, sharded_output = \
+ run_forward_backward_with_hybrid_plugin(
+ org_model,
+ sharded_model,
+ sharded_optimizer,
+ data_gen_fn,
+ output_transform_fn,
+ criterion,
+ booster)
- # do backward
- org_loss.backward()
- shard_loss.backward()
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
- assert torch.allclose(org_loss, shard_loss,
- atol=1e-6), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
+
+ if org_model.__class__.__name__ == 'BloomModel':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+ check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
bloom = org_model
- sharded_bloom = sharded_model
+ sharded_bloom = sharded_model.unwrap()
else:
bloom = org_model.transformer
- sharded_bloom = sharded_model.transformer
+ sharded_bloom = sharded_model.unwrap().transformer
# check grad
- col_layer_for_check = ['h[0].self_attention.query_key_value']
- row_layer_for_check = ['h[0].self_attention.dense']
- check_grad(bloom, sharded_bloom, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
- check_grad(bloom, sharded_bloom, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
-
-
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
- use_lazy_init):
+ row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
+ col_layer_for_check = ['h[0].self_attention.dense']
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
+ check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
+
+ # check weights after optimizer.step()
+ org_optimizer.step()
+ sharded_optimizer.step()
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
+
+ torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [{
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': True
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_fused_normalization': False,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 4,
+ 'pp_size': 1,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': False
+}])
+def run_bloom_test(test_config):
+
+ # TODO: add test_config for TP+DP after supporting & debugging it
+ # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
+
+ # TODO: add test_config for flash attention & jit operator once they are supported
+
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
+ test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
+
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- enable_flash_attention, enable_jit_fused, use_lazy_init)
- check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
torch.cuda.empty_cache()
@@ -67,7 +111,7 @@ def check_bloom(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
- spawn(check_bloom, 2)
+ spawn(check_bloom, 4)
if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py
deleted file mode 100644
index 6695e8a687bd..000000000000
--- a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.auto_policy import get_autopolicy
-from colossalai.shardformer.policies.base_policy import Policy
-from colossalai.shardformer.shard import ShardConfig
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
-
-
-def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
- policy = get_autopolicy(model)
- policy.set_model(model)
- model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
- policy.set_shard_config(model_config)
- layers = policy.get_held_layers()
- if stage_manager.is_first_stage():
- assert len(layers) == 0 + 2
- else:
- if name == 'transformers_bloom':
- assert len(layers) == 1 + 1
- elif name == 'transformers_bloom_for_token_classification':
- assert len(layers) == 1 + 3
- else:
- assert len(layers) == 1 + 2
-
-
-def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
- if stage_manager.stage == 0:
- x = torch.randint(0, 1000, (1, 3)).cuda()
- attention_mask = torch.ones_like(x).cuda()
- output = sharded_model(input_ids=x, attention_mask=attention_mask)
- assert output['hidden_states'].shape == (1, 3, 64)
- else:
- attention_mask = torch.ones((1, 3)).cuda()
- hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
- output = sharded_model(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- )
- assert output[0].shape[0] == 1
-
-
-@parameterize('enable_fused_normalization', [False])
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_bloom
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
- PP_DIM = 0
- PP_SIZE = 2
- pg_mesh = ProcessGroupMesh(PP_SIZE)
- stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
- sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
- for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
- enable_tensor_parallelism, use_lazy_init)
- check_bloom_model_policy(name, org_model, stage_manager)
- check_bloom_model_pipeline_forward(name, sharded_model, stage_manager)
-
- torch.cuda.empty_cache()
-
-
-def check_bloom(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_bloom_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_bloom():
- spawn(check_bloom, 2)
-
-
-if __name__ == "__main__":
- test_bloom()
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index c455a99d26ce..69e63ffc854e 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -1,99 +1,126 @@
-import copy
-import os
-
import pytest
import torch
+from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, run_forward
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_grad,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ run_forward_backward_with_hybrid_plugin,
+)
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+ org_loss, org_output, sharded_loss, sharded_output = \
+ run_forward_backward_with_hybrid_plugin(
+ org_model,
+ sharded_model,
+ sharded_optimizer,
+ data_gen_fn,
+ output_transform_fn,
+ criterion,
+ booster)
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- # check forward
- org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
- output_transform_fn, loss_fn)
- assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
- # do backward
- org_loss.backward()
- shard_loss.backward()
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
- assert torch.allclose(org_loss, shard_loss,
- atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+ if org_model.__class__.__name__ == 'ChatGLMModel':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3, dim=1)
+
+ check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'ChatGLMModel':
chatglm_model = org_model
- shard_chatglm_model = sharded_model
+ shard_chatglm_model = sharded_model.unwrap()
else:
chatglm_model = org_model.transformer
- shard_chatglm_model = sharded_model.transformer
-
- # check attention grad
- org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
- shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
- shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight
+ shard_chatglm_model = sharded_model.unwrap().transformer
+
+ # check grad
+ row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
+ col_layer_for_check = ['encoder.layers[0].self_attention.dense']
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_grad(chatglm_model,
+ shard_chatglm_model,
+ row_layer_for_check,
+ tp_group,
+ atol=1e-6,
+ rtol=1e-3,
+ dim=0,
+ verbose=False)
+
+ check_grad(chatglm_model,
+ shard_chatglm_model,
+ col_layer_for_check,
+ tp_group,
+ atol=1e-6,
+ rtol=1e-3,
+ dim=1,
+ verbose=False)
+
+ # check weights after optimizer.step()
+ org_optimizer.step()
+ sharded_optimizer.step()
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_weight(chatglm_model,
+ shard_chatglm_model,
+ col_layer_for_check,
+ tp_group,
+ atol=1e-4,
+ rtol=1e-3,
+ dim=1,
+ verbose=False)
- if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
- shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
- shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
- all_shard_grad = torch.cat(shard_grad_list, dim=0)
- else:
- all_shard_grad = shard_grad
- assert torch.allclose(org_grad, all_shard_grad,
- atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
-
- # check embedding weights
- org_grad = chatglm_model.embedding.word_embeddings.weight.grad
- shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad
- shard_weight = shard_chatglm_model.embedding.word_embeddings.weight
-
- if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
- shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)]
- torch.distributed.all_gather(shard_grad_list, shard_grad)
- all_shard_grad = torch.cat(shard_grad_list, dim=0)
- else:
- all_shard_grad = shard_grad
+ torch.cuda.empty_cache()
- assert torch.allclose(org_grad, all_shard_grad,
- atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+@parameterize('test_config', [{
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': True
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_fused_normalization': False,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 4,
+ 'pp_size': 1,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': False
+}])
+def run_chatglm_test(test_config):
+
+ # TODO: add test_config for TP+DP after supporting & debugging it
+ # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
+
+ # TODO: add test_config for flash attention & jit operator once they are supported
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
+ test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
+
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- # create new model
- org_model = model_fn().cuda()
-
- # shard model
- shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
- enable_tensor_parallelism=enable_tensor_parallelism,
- enable_flash_attention=enable_flash_attention,
- enable_jit_fused=enable_jit_fused)
- model_copy = copy.deepcopy(org_model)
- shard_former = ShardFormer(shard_config=shard_config)
- if name == "transformers_chatglm":
- sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy())
- else:
- sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy())
- sharded_model = sharded_model.cuda()
-
- check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
torch.cuda.empty_cache()
@@ -107,7 +134,7 @@ def check_chatglm(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm():
- spawn(check_chatglm, 2)
+ spawn(check_chatglm, 4)
if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py b/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py
deleted file mode 100644
index ee474ac7be3b..000000000000
--- a/tests/test_shardformer/test_model/test_shard_chatglm_pipeline.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import copy
-import os
-
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
-
-
-@parameterize('enable_fused_normalization', [False])
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('use_lazy_init', [False])
-def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
- DP_DIM, PP_DIM = 0, 1
- DP_SIZE, PP_SIZE = 2, 2
- pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
- stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
- sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
- for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- # create new model for test
- inputs = data_gen_fn()
- inputs = {k: v.cuda() for k, v in inputs.items()}
- input_ids = inputs['input_ids']
- hidden_size = 64
- batch_size, seq_len = input_ids.shape
- hidden_state_shape = (seq_len, batch_size, hidden_size)
- if name == "transformers_chatglm":
- _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
- enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy())
- if stage_manager.is_last_stage():
- hidden_states = torch.randn(*hidden_state_shape).cuda()
- inputs['input_ids'] = None
- inputs['hidden_states'] = hidden_states
- outputs = sharded_model(**inputs)
- if stage_manager.is_last_stage():
- assert outputs[0].shape == hidden_state_shape
-
- else:
- assert outputs['hidden_states'].shape == hidden_state_shape
-
- if name == "transformers_chatglm_for_conditional_generation":
- _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
- enable_tensor_parallelism, use_lazy_init,
- ChatGLMForConditionalGenerationPolicy())
- if stage_manager.is_last_stage():
- hidden_states = torch.randn(*hidden_state_shape).cuda()
- inputs['input_ids'] = None
- inputs['hidden_states'] = hidden_states
- outputs = sharded_model(**inputs)
- if stage_manager.is_last_stage():
- assert outputs[0].shape == (batch_size, seq_len, 65024)
- else:
- assert outputs['hidden_states'].shape == hidden_state_shape
-
- torch.cuda.empty_cache()
-
-
-def check_chatglm(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_chatglm_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_chatglm():
- spawn(check_chatglm, 4)
-
-
-if __name__ == "__main__":
- test_chatglm()
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index ead14ab111e6..c5f8d22f18c9 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -2,69 +2,139 @@
import pytest
import torch
+from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_grad,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ run_forward_backward_with_hybrid_plugin,
+)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
- output_transform_fn, loss_fn)
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+ org_loss, org_output, sharded_loss, sharded_output = \
+ run_forward_backward_with_hybrid_plugin(
+ org_model,
+ sharded_model,
+ sharded_optimizer,
+ data_gen_fn,
+ output_transform_fn,
+ criterion,
+ booster)
- # forward check
- assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5)
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
- # run backward
- org_loss.backward()
- shard_loss.backward()
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
- assert torch.allclose(org_loss, shard_loss,
- atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+ if org_model.__class__.__name__ == 'LlamaModel':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+ check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
- if hasattr(org_model, 'model'):
- llama_model = org_model.model
- shard_llama_model = sharded_model.model
- else:
+ if org_model.__class__.__name__ == 'LlamaModel':
llama_model = org_model
- shard_llama_model = sharded_model
+ shard_llama_model = sharded_model.unwrap()
+ else:
+ llama_model = org_model.model
+ shard_llama_model = sharded_model.unwrap().model
# check grad
- col_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
- row_layer_for_check = ['layers[0].self_attn.o_proj']
- check_grad(llama_model, shard_llama_model, col_layer_for_check, atol=1e-6, rtol=1e-4, dim=0, verbose=False)
- check_grad(llama_model, shard_llama_model, row_layer_for_check, atol=1e-6, rtol=1e-4, dim=1, verbose=False)
+ row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
+ col_layer_for_check = ['layers[0].self_attn.o_proj']
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_grad(llama_model,
+ shard_llama_model,
+ row_layer_for_check,
+ tp_group,
+ atol=1e-6,
+ rtol=1e-4,
+ dim=0,
+ verbose=False)
+ check_grad(llama_model,
+ shard_llama_model,
+ col_layer_for_check,
+ tp_group,
+ atol=1e-6,
+ rtol=1e-4,
+ dim=1,
+ verbose=False)
+
+ # check weights after optimizer.step()
+ org_optimizer.step()
+ sharded_optimizer.step()
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_weight(llama_model,
+ shard_llama_model,
+ col_layer_for_check,
+ tp_group,
+ atol=1e-4,
+ rtol=1e-3,
+ dim=1,
+ verbose=False)
+
+ torch.cuda.empty_cache()
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init):
+@parameterize('test_config', [{
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 2,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': True
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 4,
+ 'pp_size': 1,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 1,
+ 'pp_size': 4,
+ 'num_microbatches': 4,
+ 'use_lazy_init': False
+}])
+def run_llama_test(test_config):
+
+ # TODO: add test_config for TP+DP after supporting & debugging it
+ # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
+
+ # TODO: add test_config for flash attention & jit operator once they are supported
+
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+ test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
+
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- enable_flash_attention, use_lazy_init)
- check_state_dict(org_model, sharded_model, name=name)
- check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
torch.cuda.empty_cache()
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_gpt2_llama()
+ run_llama_test()
@pytest.mark.dist
diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py
deleted file mode 100644
index 6f1f0cb34508..000000000000
--- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.auto_policy import get_autopolicy
-from colossalai.shardformer.policies.base_policy import Policy
-from colossalai.shardformer.shard import ShardConfig
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
-
-
-def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
- policy = get_autopolicy(model)
- policy.set_model(model)
- model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
- policy.set_shard_config(model_config)
- layers = policy.get_held_layers()
- if stage_manager.is_first_stage():
- assert len(layers) == 2 + 1
- else:
- if name == "transformers_llama":
- assert len(layers) == 2 + 1
- else:
- assert len(layers) == 2 + 2
-
-
-def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
- x = torch.randint(0, 1000, (2, 3)).cuda()
- if stage_manager.stage == 0:
- attention_mask = torch.ones_like(x).cuda()
- output = sharded_model(input_ids=x, attention_mask=attention_mask)
- assert output['hidden_states'].shape == (2, 3, 128)
- else:
- hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
- attention_mask = torch.ones((2, 3)).cuda()
- output = sharded_model(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- )
- assert output[0] is not None
-
-
-@parameterize('enable_fused_normalization', [False])
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_llama
-def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
- PP_DIM = 0
- PP_SIZE = 2
- pg_mesh = ProcessGroupMesh(PP_SIZE)
- stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
- sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
-
- for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
- enable_tensor_parallelism, use_lazy_init)
- check_llama_model_policy(name, org_model, stage_manager)
- check_llama_model_pipeline_forward(name, sharded_model, stage_manager)
-
- torch.cuda.empty_cache()
-
-
-def check_llama(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_llama_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_llama():
- spawn(check_llama, 2)
-
-
-if __name__ == "__main__":
- test_llama()
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 99a278d4303a..d8fa8104bb07 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -1,64 +1,129 @@
-import copy
import os
import pytest
import torch
+from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_grad,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ run_forward_backward_with_hybrid_plugin,
+)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
- output_transform_fn, loss_fn)
- assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5)
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+ org_loss, org_output, sharded_loss, sharded_output = \
+ run_forward_backward_with_hybrid_plugin(
+ org_model,
+ sharded_model,
+ sharded_optimizer,
+ data_gen_fn,
+ output_transform_fn,
+ criterion,
+ booster)
+
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
- # run backward
- org_loss.backward()
- shard_loss.backward()
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
- assert torch.allclose(org_loss, shard_loss,
- atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+ if org_model.__class__.__name__ == 'OPTModel':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+ check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
- if hasattr(org_model, 'model'):
- opt_model = org_model.model
- shard_opt_model = sharded_model.model
- else:
+ if org_model.__class__.__name__ == 'OPTModel':
opt_model = org_model
- shard_opt_model = sharded_model
+ shard_opt_model = sharded_model.unwrap()
+ else:
+ opt_model = org_model.model
+ shard_opt_model = sharded_model.unwrap().model
# check grad
- col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
- row_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
- check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
- check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
-
-
-@parameterize('use_lazy_init', [False, True])
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention,
- enable_jit_fused):
+ row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
+ col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_grad(opt_model,
+ shard_opt_model,
+ row_layer_for_check,
+ tp_group,
+ atol=1e-6,
+ rtol=1e-3,
+ dim=0,
+ verbose=False)
+ check_grad(opt_model,
+ shard_opt_model,
+ col_layer_for_check,
+ tp_group,
+ atol=1e-6,
+ rtol=1e-3,
+ dim=1,
+ verbose=False)
+
+ # check weights after optimizer.step()
+ org_optimizer.step()
+ sharded_optimizer.step()
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_weight(opt_model,
+ shard_opt_model,
+ col_layer_for_check,
+ tp_group,
+ atol=1e-3,
+ rtol=1e-3,
+ dim=1,
+ verbose=False)
+
+ torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [{
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': True
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_fused_normalization': False,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 4,
+ 'pp_size': 1,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': False
+}])
+def run_opt_test(test_config):
+
+ # TODO: add test_config for TP+DP after supporting & debugging it
+ # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
+
+ # TODO: add test_config for flash attention & jit operator once they are supported
+
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
+ test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
+
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- enable_flash_attention, enable_jit_fused, use_lazy_init)
- check_state_dict(org_model, sharded_model, name=name)
- check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_opt_pipeline.py b/tests/test_shardformer/test_model/test_shard_opt_pipeline.py
deleted file mode 100644
index 0684418d0a8d..000000000000
--- a/tests/test_shardformer/test_model/test_shard_opt_pipeline.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_pipeline_model
-
-
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- # TODO: add tests for forward/backward later
- pass
-
-
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('enable_fused_normalization', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_opt
-def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
- DP_DIM, PP_DIM = 0, 1
- DP_SIZE, PP_SIZE = 2, 2
- pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
- stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
- sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
- for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
- inputs = data_gen_fn()
- inputs = {k: v.cuda() for k, v in inputs.items()}
- input_ids, _ = inputs['input_ids'], inputs['attention_mask']
- batch_size, seq_len = input_ids.shape
- hidden_size = 128
- hidden_state_shape = (batch_size, seq_len, hidden_size)
-
- if not stage_manager.is_first_stage():
- # change inputs if not the first stage
-
- hidden_states = torch.zeros(*hidden_state_shape).cuda()
- inputs['input_ids'] = None
- inputs['hidden_states'] = hidden_states
-
- _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
- enable_tensor_parallelism, use_lazy_init)
- sharded_model.train()
-
- output = sharded_model(**inputs)
- if stage_manager.is_last_stage():
- assert output[0] is not None
- else:
- assert output['hidden_states'].shape == hidden_state_shape
- torch.cuda.empty_cache()
-
-
-def check_opt(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_opt_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_opt():
- spawn(check_opt, 4)
-
-
-if __name__ == "__main__":
- test_opt()
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index d179c8a8ee32..8a78d7c2b8ce 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -1,60 +1,127 @@
-import os
-
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_grad,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ run_forward_backward_with_hybrid_plugin,
+)
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+ org_loss, org_output, sharded_loss, sharded_output = \
+ run_forward_backward_with_hybrid_plugin(
+ org_model,
+ sharded_model,
+ sharded_optimizer,
+ data_gen_fn,
+ output_transform_fn,
+ criterion,
+ booster)
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- # check forward
- org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
- output_transform_fn, loss_fn)
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
- assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage():
- # do backward
- org_loss.backward()
- shard_loss.backward()
+ if org_model.__class__.__name__ == 'ViTModel':
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
- assert torch.allclose(org_loss, shard_loss,
- atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+ check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'ViTModel':
vit_model = org_model
- shard_vit_model = sharded_model
+ shard_vit_model = sharded_model.unwrap()
else:
vit_model = org_model.vit
- shard_vit_model = sharded_model.vit
+ shard_vit_model = sharded_model.unwrap().vit
# check grad
- col_layer_for_check = ['encoder.layer[0].attention.attention.query']
- row_layer_for_check = ['encoder.layer[0].attention.output.dense']
- check_grad(vit_model, shard_vit_model, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False)
- check_grad(vit_model, shard_vit_model, row_layer_for_check, atol=1e-5, rtol=1e-3, dim=1, verbose=False)
+ row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
+ col_layer_for_check = ['encoder.layer[0].attention.output.dense']
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_grad(vit_model,
+ shard_vit_model,
+ row_layer_for_check,
+ tp_group,
+ atol=1e-5,
+ rtol=1e-3,
+ dim=0,
+ verbose=False)
+ check_grad(vit_model,
+ shard_vit_model,
+ col_layer_for_check,
+ tp_group,
+ atol=1e-5,
+ rtol=1e-3,
+ dim=1,
+ verbose=False)
+
+ # check weights after optimizer.step()
+ org_optimizer.step()
+ sharded_optimizer.step()
+ if stage_manager is None or stage_manager.is_first_stage():
+ check_weight(vit_model,
+ shard_vit_model,
+ col_layer_for_check,
+ tp_group,
+ atol=5e-3,
+ rtol=1e-3,
+ dim=1,
+ verbose=False)
+ torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [{
+ 'tp_size': 2,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 1,
+ 'pp_size': 2,
+ 'num_microbatches': 4,
+ 'enable_fused_normalization': False,
+ 'use_lazy_init': False
+}, {
+ 'tp_size': 4,
+ 'pp_size': 1,
+ 'enable_fused_normalization': True,
+ 'use_lazy_init': False
+}])
+def run_vit_test(test_config):
+
+ # TODO: add test_config for TP+DP after supporting & debugging it
+ # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
+
+ # TODO: add test_config for flash attention & jit operator once they are supported
+ # TODO: fix bug when setting lazy_init for Conv2D layers in ViT models
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+ test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
+
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
- enable_flash_attention, enable_jit_fused)
- check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+ clear_layout_converter()
+ Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -68,7 +135,7 @@ def check_vit(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
- spawn(check_vit, 2)
+ spawn(check_vit, 4)
if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py
deleted file mode 100644
index 114992a2a2a5..000000000000
--- a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_pipeline_model
-
-
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
- # TODO: add tests for forward/backward later
- pass
-
-
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('enable_fused_normalization', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_vit
-def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
- DP_DIM, PP_DIM = 0, 1
- DP_SIZE, PP_SIZE = 2, 2
- pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
- stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
- sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
-
- for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-
- inputs = data_gen_fn()
- inputs = {k: v.cuda() for k, v in inputs.items()}
- pixel_values = inputs['pixel_values']
- batch_size = len(pixel_values)
- hidden_size = 768
- hidden_state_shape = (batch_size, 197, hidden_size)
-
- if not stage_manager.is_first_stage():
- # change inputs if not the first stage
- hidden_states = torch.randn(*hidden_state_shape).cuda()
- # inputs['pixel_values'] = None
- inputs['hidden_states'] = hidden_states
-
- _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
- enable_tensor_parallelism, use_lazy_init)
- sharded_model.train()
-
- output = sharded_model(**inputs)
- if stage_manager.is_last_stage():
- if name != 'transformers_vit':
- assert output.loss is not None
- else:
- assert output['hidden_states'].shape == hidden_state_shape, \
- f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}'
-
- torch.cuda.empty_cache()
-
-
-def check_vit(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_vit_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_vit():
- spawn(check_vit, 4)
-
-
-if __name__ == "__main__":
- test_vit()
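The rewritten ViT test above builds everything through `build_model_from_hybrid_plugin` and `run_forward_backward_with_hybrid_plugin` from `tests/test_shardformer/test_model/_utils.py`, whose bodies are not part of this patch. The sketch below is a hedged illustration of how such a helper could map a `test_config` dict onto `HybridParallelPlugin` and `Booster`; the function name and the handling of test-only keys are assumptions, not the repository's actual implementation.

```python
# Hypothetical sketch of a hybrid-plugin test helper; the real
# build_model_from_hybrid_plugin in _utils.py may differ in detail.
import copy

from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def build_model_from_hybrid_plugin_sketch(model_fn, loss_fn, test_config):
    # Two identical copies: one kept as the unsharded reference,
    # the other handed to the plugin to be tensor/pipeline parallelized.
    org_model = model_fn().cuda()
    sharded_model = copy.deepcopy(org_model)

    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
    criterion = loss_fn

    # Keys such as tp_size, pp_size, num_microbatches, precision and
    # initial_scale are forwarded to the plugin; 'use_lazy_init' is a
    # test-only flag handled by the harness, so it is dropped here.
    plugin_kwargs = dict(test_config)
    plugin_kwargs.pop('use_lazy_init', None)
    plugin = HybridParallelPlugin(**plugin_kwargs)

    booster = Booster(plugin=plugin)
    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(
        sharded_model, sharded_optimizer, criterion)
    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
```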
From d4a3a101012cf5f4433ba5cd76c6f4de58aab34e Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 11 Aug 2023 16:40:06 +0800
Subject: [PATCH 61/64] [shardformer] update tests for all optimization (#4413)
---
colossalai/shardformer/modeling/bert.py | 5 ++-
tests/kit/model_zoo/transformers/bert.py | 29 +++++++++-----
.../test_model/test_shard_bert.py | 39 +++++++++++++------
3 files changed, 50 insertions(+), 23 deletions(-)
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index b9d4b5fda7af..eaafd67b8968 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1048,9 +1048,12 @@ def forward(
final_attention_mask = final_attention_mask * scale + attention_mask
else:
final_attention_mask = attention_mask
+
+ if final_attention_mask is not None:
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
tgt_len = key_layer.size()[2]
- final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)
+ final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len,
+ tgt_len).contiguous()
query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
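The bert.py hunk above wraps the mask handling in an `if final_attention_mask is not None:` guard and appends `.contiguous()` to the `expand(...)` call. `Tensor.expand` returns a zero-stride view, and downstream code that assumes a dense memory layout can reject it; the snippet below is a generic illustration of that hazard, not a reproduction of the patched attention kernel.

```python
# expand() broadcasts without copying, so the result is a non-contiguous,
# zero-stride view; .contiguous() materializes a dense tensor.
import torch

mask = torch.zeros(1, 1, 1, 6)        # broadcastable attention mask
expanded = mask.expand(2, 4, 5, 6)    # no copy: expanded dims have stride 0
print(expanded.is_contiguous())       # False
print(expanded.stride())              # (0, 0, 0, 1)

try:
    expanded.view(-1)                 # flattening the zero-stride view fails
except RuntimeError as err:
    print("view() rejected the expanded mask:", err)

dense = expanded.contiguous()         # real (2, 4, 5, 6) storage, safe to reshape
print(dense.is_contiguous())          # True
print(dense.view(-1).shape)           # torch.Size([240])
```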
diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py
index 52158596bcf8..e16d3b269ba0 100644
--- a/tests/kit/model_zoo/transformers/bert.py
+++ b/tests/kit/model_zoo/transformers/bert.py
@@ -69,21 +69,30 @@ def data_gen_for_mcq():
# data['labels'] = torch.tensor([0], dtype=torch.int64)
input_ids = torch.tensor([[[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
- 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
+ 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102, 5442,
+ 1012, 102, 102
],
[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
- 2218, 1999, 1996, 2192, 1012, 102, 0, 0
+ 2218, 1999, 1996, 2192, 1012, 102, 0, 0, 1012, 102, 0, 0
]]])
- token_type_ids = torch.tensor(
- [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
- 0]]])
- attention_mask = torch.tensor(
- [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
- 0]]])
+ token_type_ids = torch.tensor([[[
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1
+ ],
+ [
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
+ ]]])
+ attention_mask = torch.tensor([[[
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1
+ ],
+ [
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
+ ]]])
labels = torch.tensor([0], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index fdbcd014e1b8..0a24e46d28f2 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -36,10 +36,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'BertModel':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
- check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'BertModel':
bert = org_model
@@ -51,17 +55,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
col_layer_for_check = ['encoder.layer[0].output.dense']
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
#check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
#check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
- check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
- check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+ check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+ check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 5e-3, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
- check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
+ check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
torch.cuda.empty_cache()
@@ -70,23 +82,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
- 'use_lazy_init': True
+ 'use_lazy_init': True,
+ 'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 2,
- 'num_microbatches': 4,
- 'enable_fused_normalization': False,
- 'use_lazy_init': False
+ 'num_microbatches': 2,
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 4,
'pp_size': 1,
- 'enable_fused_normalization': True,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}])
def run_bert_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
- test_config['precision'] = 'float'
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
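From this patch onward, every comparison in the shardformer tests picks its tolerances from `test_config['precision']`: tighter per-check pairs for fp32 and a flat `5e-3, 5e-3` for fp16/bf16. The helper below is only a refactoring sketch that centralizes the repeated `if/else`; the fp32 values mirror the numbers hard-coded in the BERT test above and are not part of the patch itself.

```python
# Hypothetical helper consolidating the precision-dependent tolerances used
# throughout these tests; the fp32 values follow test_shard_bert.py above.
from typing import Dict, Tuple

_FP32_TOLERANCES: Dict[str, Tuple[float, float]] = {
    'output': (1e-5, 1e-3),   # check_output_hidden_state / check_loss
    'grad': (1e-4, 1e-3),     # check_grad
    'weight': (5e-3, 1e-3),   # check_weight after optimizer.step()
}
_LOW_PRECISION_TOLERANCES: Tuple[float, float] = (5e-3, 5e-3)


def get_tolerances(test_config: dict, check: str) -> Tuple[float, float]:
    """Return (atol, rtol) for an 'output', 'grad' or 'weight' comparison."""
    if test_config['precision'] == 'fp32':
        return _FP32_TOLERANCES[check]
    return _LOW_PRECISION_TOLERANCES


# Example usage inside check_forward_backward:
#   atol, rtol = get_tolerances(test_config, 'grad')
#   check_grad(bert, sharded_bert, col_layer_for_check, tp_group,
#              atol=atol, rtol=rtol, dim=1, verbose=False)
```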
From ac8d4ed8664979beae149fd203339dd94000bad0 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Mon, 14 Aug 2023 15:49:13 +0800
Subject: [PATCH 62/64] [shardformer] update t5 tests to use all
 optimizations (#4407)
* [shardformer] gpt2 tests fix
[shardformer] test all optimizations (#4399)
* [shardformer] update t5 to use all optimizations
---
colossalai/shardformer/README.md | 2 +-
tests/kit/model_zoo/transformers/t5.py | 8 ++--
.../test_model/test_shard_t5.py | 39 +++++++++++++------
3 files changed, 33 insertions(+), 16 deletions(-)
diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 1c11b4b85444..18e00a6a663d 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -31,7 +31,7 @@
### Quick Start
-The sample API usage is given below(If you enable the use of flash attention, please install xformers.):
+The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers' `cutlass_op` provides a supplementary optimization, which requires the sequence length to be a multiple of 8):
``` python
from colossalai.shardformer import ShardConfig, Shard
diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py
index 435cb6f46937..175d48963480 100644
--- a/tests/kit/model_zoo/transformers/t5.py
+++ b/tests/kit/model_zoo/transformers/t5.py
@@ -16,8 +16,8 @@ def data_gen_for_encoder_only():
# config = T5Config(decoder_start_token_id=0)
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
# input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
- input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
- attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
+ input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12, 1627, 5, 1, 12]]).long()
+ attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -26,7 +26,7 @@ def data_gen_for_conditional_generation():
#
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
data = data_gen_for_encoder_only()
- labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
+ labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long()
data['labels'] = labels
return data
@@ -35,7 +35,7 @@ def data_gen_for_t5_model():
# decoder_inputs_ids is obtained with the following code
# decoder_input_ids = model._shift_right(input_ids)
data = data_gen_for_encoder_only()
- decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
+ decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long()
data['decoder_input_ids'] = decoder_input_ids
return data
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index d807ffa06296..fb065b42250b 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -37,11 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
- check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
t5 = org_model
@@ -50,14 +54,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
# check weights and gradients
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
- check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0)
+ check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
- check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+ check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
torch.cuda.empty_cache()
@@ -66,23 +78,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
- 'enable_fused_normalization': True,
- 'use_lazy_init': True
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
- 'use_lazy_init': False
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 4,
'pp_size': 1,
- 'enable_fused_normalization': True,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}, {
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
- 'use_lazy_init': False
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}])
@clear_cache_before_run()
def run_t5_test(test_config):
@@ -93,7 +111,6 @@ def run_t5_test(test_config):
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
- test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
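The README change in this patch notes that flash attention needs the `flash_attn` package and that the xformers `cutlass_op` fallback expects sequence lengths divisible by 8, which appears to be why the T5 fixtures above were lengthened from 12 to 16 tokens. Below is a minimal usage sketch consistent with the `ShardConfig`/`ShardFormer` calls visible elsewhere in this series; the BERT model choice and configuration values are illustrative assumptions only.

```python
# Hedged example of turning on flash attention through ShardFormer.
# Assumes flash_attn is installed and a CUDA device is available.
import torch
from transformers import BertConfig, BertModel

from colossalai.shardformer import ShardConfig, ShardFormer

model = BertModel(BertConfig(num_hidden_layers=2)).cuda().half()  # fp16 for flash attention

shard_config = ShardConfig(
    enable_tensor_parallelism=False,   # single-device example, no TP group
    enable_fused_normalization=False,
    enable_flash_attention=True,       # requires the flash_attn package
)
shard_former = ShardFormer(shard_config)
sharded_model, _shared_params = shard_former.optimize(model)

# Keep the sequence length a multiple of 8 so the cutlass_op path applies.
input_ids = torch.randint(0, 30522, (2, 16), device='cuda')
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    outputs = sharded_model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.last_hidden_state.shape)   # torch.Size([2, 16, 768])
```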
From 82ea190b547ba9eab38126f8598cbbc21bb7bb37 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Mon, 14 Aug 2023 15:51:13 +0800
Subject: [PATCH 63/64] [shardformer] update bloom/llama/vit/chatglm tests
(#4420)
[shardformer] update bloom/llama/vit/chatglm tests
[shardformer] update opt tests
---
.../test_model/test_shard_bloom.py | 43 ++++++++++------
.../test_model/test_shard_chatglm.py | 48 ++++++++++-------
.../test_model/test_shard_gpt2.py | 16 +++---
.../test_model/test_shard_llama.py | 49 +++++++++++-------
.../test_model/test_shard_opt.py | 51 +++++++++++--------
.../test_model/test_shard_vit.py | 48 ++++++++++-------
6 files changed, 157 insertions(+), 98 deletions(-)
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index d5a4ce083e2b..145ccf97c388 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -36,11 +36,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
-
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'BloomModel':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
- check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
@@ -54,14 +57,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
col_layer_for_check = ['h[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
- check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
- check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-6, 1e-5
+ else:
+ atol, rtol = 5e-3, 5e-3
+ check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
+ check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
- check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+ check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
torch.cuda.empty_cache()
@@ -70,29 +81,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': True,
- 'use_lazy_init': True
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': False,
- 'use_lazy_init': False
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
- 'enable_fused_normalization': True,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}])
def run_bloom_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
- # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
-
- # TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
- test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index 69e63ffc854e..e9c74b300daa 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -37,11 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ChatGLMModel':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3, dim=1)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
- check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'ChatGLMModel':
@@ -55,12 +59,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-6, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
check_grad(chatglm_model,
shard_chatglm_model,
row_layer_for_check,
tp_group,
- atol=1e-6,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=0,
verbose=False)
@@ -68,8 +76,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
shard_chatglm_model,
col_layer_for_check,
tp_group,
- atol=1e-6,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=1,
verbose=False)
@@ -77,12 +85,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
check_weight(chatglm_model,
shard_chatglm_model,
col_layer_for_check,
tp_group,
- atol=1e-4,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=1,
verbose=False)
@@ -93,29 +105,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': True,
- 'use_lazy_init': True
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': False,
- 'use_lazy_init': False
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
- 'enable_fused_normalization': True,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}])
def run_chatglm_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
- # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
-
- # TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
- test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 274cfaa39ad1..8b7a6bf29c8b 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -63,22 +63,22 @@ def unwrap(module):
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
# check grad
- if test_config['precision'] == 'fp32':
- atol, rtol = 1e-4, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
- if test_config['precision'] == 'fp32':
- atol, rtol = 5e-3, 1e-3
- else:
- atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 5e-3, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index c5f8d22f18c9..fa4ee43e3114 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -41,11 +41,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'LlamaModel':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
- check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'LlamaModel':
@@ -59,20 +63,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
col_layer_for_check = ['layers[0].self_attn.o_proj']
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-6, 1e-4
+ else:
+ atol, rtol = 5e-3, 5e-3
check_grad(llama_model,
shard_llama_model,
row_layer_for_check,
tp_group,
- atol=1e-6,
- rtol=1e-4,
+ atol=atol,
+ rtol=rtol,
dim=0,
verbose=False)
check_grad(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
- atol=1e-6,
- rtol=1e-4,
+ atol=atol,
+ rtol=rtol,
dim=1,
verbose=False)
@@ -80,12 +88,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-4, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
check_weight(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
- atol=1e-4,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=1,
verbose=False)
@@ -96,33 +108,34 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
- 'enable_fused_normalization': True,
- 'use_lazy_init': True
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
- 'use_lazy_init': False
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
- 'enable_fused_normalization': True,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}, {
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
- 'use_lazy_init': False
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}])
def run_llama_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
- # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
-
- # TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
- test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index d8fa8104bb07..403c3e75f52c 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -41,11 +41,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
-
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'OPTModel':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
- check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'OPTModel':
@@ -56,23 +59,27 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
shard_opt_model = sharded_model.unwrap().model
# check grad
- row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
+ row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-6, 1e-3
+ else:
+ atol, rtol = 3e-2, 3e-2
check_grad(opt_model,
shard_opt_model,
row_layer_for_check,
tp_group,
- atol=1e-6,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=0,
verbose=False)
check_grad(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
- atol=1e-6,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=1,
verbose=False)
@@ -80,12 +87,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-3, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
check_weight(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
- atol=1e-3,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=1,
verbose=False)
@@ -96,29 +107,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': True,
- 'use_lazy_init': True
+ 'enable_all_optimization': True,
+ 'use_lazy_init': True,
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': False,
- 'use_lazy_init': False
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
- 'enable_fused_normalization': True,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}])
def run_opt_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
- # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
-
- # TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
- test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index 8a78d7c2b8ce..919bceffc847 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -37,11 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ViTModel':
- check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
- check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'ViTModel':
@@ -55,20 +59,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
check_grad(vit_model,
shard_vit_model,
row_layer_for_check,
tp_group,
- atol=1e-5,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=0,
verbose=False)
check_grad(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
- atol=1e-5,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=1,
verbose=False)
@@ -76,12 +84,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
+ if test_config['precision'] == 'fp32':
+ atol, rtol = 5e-3, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
check_weight(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
- atol=5e-3,
- rtol=1e-3,
+ atol=atol,
+ rtol=rtol,
dim=1,
verbose=False)
@@ -92,30 +104,30 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': True,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp16',
+ 'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
- 'enable_fused_normalization': False,
- 'use_lazy_init': False
+ 'enable_all_optimization': False,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
- 'enable_fused_normalization': True,
- 'use_lazy_init': False
+ 'enable_all_optimization': True,
+ 'use_lazy_init': False,
+ 'precision': 'fp32',
}])
def run_vit_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
- # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
-
- # TODO: add test_config for flash attention & jit operator after supporting
     # TODO: fix bug when setting lazy_init for Conv2D Layers in ViT models
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
- test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
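All of the test files touched in this patch share the same launch scaffold: a `@parameterize`d `run_*_test` over `test_config`, a `check_*` worker that calls `colossalai.launch`, and a pytest entry that spawns four ranks so the `tp_size=4` and `tp_size=2, pp_size=2` configurations have enough devices. A condensed sketch of that scaffold, with the model-specific comparison body stubbed out, is:

```python
# Condensed scaffold shared by the shardformer model tests in this series;
# check_forward_backward is model-specific and only stubbed here.
import pytest

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    ...    # per-model loss/grad/weight comparisons, as in the diffs above


@parameterize('test_config', [{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'precision': 'fp16',
    'initial_scale': 1,
}])
def run_model_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)


def check_model(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_model_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_model():
    spawn(check_model, 4)    # 4 ranks: enough for tp_size * pp_size == 4
```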
From 9d1a6d22e39044f5f3737625356c9e372241043f Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Mon, 14 Aug 2023 17:43:33 +0800
Subject: [PATCH 64/64] [misc] resolve code factor issues (#4433)
---
colossalai/booster/booster.py | 2 +-
colossalai/shardformer/layer/utils.py | 2 -
colossalai/shardformer/modeling/bert.py | 8 +-
colossalai/shardformer/modeling/bloom.py | 12 +-
colossalai/shardformer/modeling/chatglm.py | 2 +-
colossalai/shardformer/modeling/gpt2.py | 2 +-
colossalai/shardformer/modeling/llama.py | 6 +-
colossalai/shardformer/modeling/opt.py | 2 +-
colossalai/shardformer/modeling/t5.py | 6 +-
colossalai/shardformer/modeling/vit.py | 2 +-
colossalai/shardformer/shard/shard_config.py | 1 -
.../test_tracer/test_hf_model/test_hf_gpt.py | 2 +-
.../test_layer/test_qkv_fused_linear_1d.py | 2 +-
.../test_model/test_pure_pipeline.py | 171 ------------------
.../test_model/test_shard_bloom.py | 2 +-
.../test_model/test_shard_chatglm.py | 2 +-
.../test_model/test_shard_gpt2.py | 2 +-
.../test_model/test_shard_llama.py | 2 +-
.../test_model/test_shard_opt.py | 2 +-
.../test_model/test_shard_t5.py | 4 +-
.../test_model/test_shard_vit.py | 4 +-
21 files changed, 32 insertions(+), 206 deletions(-)
delete mode 100644 tests/test_shardformer/test_model/test_pure_pipeline.py
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 8a28b1286cfa..adb8f62a5084 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -139,7 +139,7 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
loss (torch.Tensor): The loss to be backpropagated.
optimizer (Optimizer): The optimizer to be updated.
"""
- # TODO: implement this method with plugin
+ # TODO(frank lee): implement this method with plugin
optimizer.backward(loss)
def execute_pipeline(self,
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 09cb7bfe1407..577bef076a7e 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -29,8 +29,6 @@ class Randomizer:
_INDEX = 0
def __init__(self, seed: int):
- # TODO: remove colossalai.context.random
-
self.seed = seed
# Handle CUDA rng state
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index eaafd67b8968..5bd1c531cc68 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -57,7 +57,7 @@ def bert_model_forward(
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
stage_index: Optional[List[int]] = None,
):
- # TODO: add explaination of the output here.
+    # TODO(jianghai): add explanation of the output here.
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -113,7 +113,7 @@ def bert_model_forward(
batch_size, seq_length = input_shape
device = hidden_states.device
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -272,7 +272,7 @@ def bert_for_pretraining_forward(
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai) left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -534,7 +534,7 @@ def bert_for_next_sentence_prediction_forward(
stage_index: Optional[List[int]] = None,
**kwargs,
):
- #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
+ # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 57c45bc6adfa..12276635ecfa 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -252,7 +252,7 @@ def custom_forward(*inputs):
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
- # TODO: deal with all_hidden_states, all_self_attentions, presents
+ # TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -307,7 +307,7 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -402,7 +402,7 @@ def bloom_for_sequence_classification_forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -431,7 +431,7 @@ def bloom_for_sequence_classification_forward(
all_cross_attentions = None
if stage_manager.is_last_stage():
batch_size = hidden_states.shape[0]
- #update batch size
+ # update batch size
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
@@ -525,7 +525,7 @@ def bloom_for_token_classification_forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -611,7 +611,7 @@ def bloom_for_question_answering_forward(
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py
index a95966c3b99e..409e2e1f5497 100644
--- a/colossalai/shardformer/modeling/chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm.py
@@ -152,7 +152,7 @@ def chatglm_model_forward(
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index a12a9796fa8a..47835d5d5468 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -57,7 +57,7 @@ def gpt2_model_forward(
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 2f54daac586a..f1d2998bbee4 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -65,7 +65,7 @@ def llama_model_forward(
seq_length_with_past = seq_length
past_key_values_length = 0
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -216,7 +216,7 @@ def llama_for_causal_lm_forward(
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -301,7 +301,7 @@ def llama_for_sequence_classification_forward(
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index 9afdfff4d71d..b4251f33b457 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -148,7 +148,7 @@ def opt_model_forward(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index d622da452366..9cc071f91dfc 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -50,7 +50,7 @@ def t5_stack_forward(
logger = logging.get_logger(__name__)
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
@@ -285,7 +285,7 @@ def t5_model_forward(
logger = logging.get_logger(__name__)
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
@@ -422,7 +422,7 @@ def t5_for_conditional_generation_forward(
logger = logging.get_logger(__name__)
- # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index eb0ea4c7502b..9fc0b7488803 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -96,7 +96,7 @@ def pp_forward(
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
- # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+ # TODO(FoolPlayer): maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index ec6e0cd0d4be..0c28f115d018 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -29,7 +29,6 @@ class ShardConfig:
enable_flash_attention: bool = False
enable_jit_fused: bool = False
- # TODO: add support for tensor parallel
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
index e29afe786c46..1cd3b90db917 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
@@ -15,7 +15,7 @@ def test_gpt():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
- # TODO: support the following models
+ # TODO(ver217): support the following models
# 1. GPT2DoubleHeadsModel
# as they are not supported, let's skip them
if model.__class__.__name__ in ['GPT2DoubleHeadsModel', 'GPT2ForQuestionAnswering']:
diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
index 6f24dc9608bd..b5709d1451f2 100644
--- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
@@ -27,7 +27,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
return rearanged_tensor
-# TODO: solve lazy_init True is not working
+# TODO(FoolPlayer): solve lazy_init True is not working
@parameterize('lazy_init', [False])
def check_linear_conv_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py
deleted file mode 100644
index 31e76ef5107c..000000000000
--- a/tests/test_shardformer/test_model/test_pure_pipeline.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import copy
-import random
-from typing import Any, Callable, Iterator, List, Optional, Tuple
-
-import numpy as np
-import pytest
-import torch
-import torch.distributed as dist
-from torch.nn import Module
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import (
- assert_hf_output_close,
- clear_cache_before_run,
- parameterize,
- rerun_if_address_is_in_use,
- spawn,
-)
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
-
-DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
-
-
-class PipelineOptimizer(OptimizerWrapper):
-
- def __init__(self, optim: Optimizer, model: Module):
- super().__init__(optim)
- params = set(model.parameters())
- new_param_groups = []
- for group in optim.param_groups:
- params = [p for p in group['params'] if p in params]
- new_param_groups.append({**group, 'params': params})
- optim.__setstate__({'param_groups': new_param_groups})
- # TODO: support amp
-
-
-class PipelinedModel(ModelWrapper):
-
- def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None:
- self.stage_manager = stage_manager
- shardformer = ShardFormer(shard_config)
- module, self.shared_params = shardformer.optimize(module)
- self.shared_param_process_groups = []
- super().__init__(module)
-
-
-def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0):
- sampler = DistributedSampler(
- dataset,
- # rank=self.pg_mesh.coordinate(DP_AXIS),
- shuffle=shuffle)
-
- # Deterministic dataloader
- def seed_worker(worker_id):
- worker_seed = seed
- np.random.seed(worker_seed)
- torch.manual_seed(worker_seed)
- random.seed(worker_seed)
-
- return DataLoader(
- dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- )
-
-
-def execute_pipeline(
- data_iter: Iterator,
- model: PipelinedModel,
- criterion: Callable[[Any, Any], torch.Tensor],
- optimizer: PipelineOptimizer,
- return_loss: bool = True,
- return_outputs: bool = False,
- schedule: OneForwardOneBackwardSchedule = None,
-) -> dict:
- # return loss or outputs if needed
- outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs)
- return outputs
-
-
-class data_loader():
-
- def __getitem__(self, x):
- return torch.ones((4, 128), dtype=torch.int).cuda() * 10
-
-
-def loss(y, x):
- return (y[0].float().mean() - x[0].float().mean())
-
-
-@parameterize('enable_fused_normalization', [False])
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('use_lazy_init', [False])
-def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
- PP_DIM = 0
- PP_SIZE = 2
- RANK_TO_COORDINATE = {
- 0: (0, 0),
- 1: (0, 1),
- 2: (1, 0),
- 3: (1, 1),
- }
- PP_RANKS_IN_GROUP = {
- 0: [0, 1],
- 1: [0, 1],
- 2: [2, 3],
- 3: [2, 3],
- }
-
- pg_mesh = ProcessGroupMesh(PP_SIZE)
- stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
- sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
- for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
- if name != 'transformers_llama':
- continue
- num_microbatches = 2
- org_model = model_fn().cuda()
- data_iter = iter(data_loader())
-
- model_copy = copy.deepcopy(org_model)
- batch = next(data_iter)
- with torch.no_grad():
- y = model_copy(batch)
- org_loss = loss(y, batch)
- optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
- schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
- shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
- enable_tensor_parallelism=enable_tensor_parallelism,
- pipeline_stage_manager=stage_manager)
- pipelined_model = PipelinedModel(org_model, shard_config, stage_manager)
- pp_optimizer = PipelineOptimizer(optimizer, pipelined_model)
- results = execute_pipeline(data_iter, pipelined_model, loss, pp_optimizer, schedule=schedule)
-
- if stage_manager.is_last_stage():
- assert results['loss'] == org_loss
- else:
- assert results['loss'] is None
- assert results['outputs'] is None
- torch.cuda.empty_cache()
-
-
-def check_llama(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_llama_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_llama():
- spawn(check_llama, 2)
-
-
-if __name__ == "__main__":
- test_llama()
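Note: the deleted test_pure_pipeline.py above built its dataloader deterministically by seeding every worker with a fixed seed. A minimal, single-process sketch of that pattern is shown below for reference (illustrative only: the TensorDataset contents and batch size are placeholders mimicking the deleted dummy data_loader, and the original additionally wrapped a DistributedSampler and ran under colossalai.launch):

    import random

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    SEED = 1024  # same default seed as the deleted prepare_dataloader

    def seed_worker(worker_id):
        # Seed numpy, torch and random in each worker, mirroring the deleted seed_worker.
        np.random.seed(SEED)
        torch.manual_seed(SEED)
        random.seed(SEED)

    # Dummy data shaped like the deleted data_loader class: (4, 128) int tensors filled with 10.
    dataset = TensorDataset(torch.ones(8, 128, dtype=torch.int) * 10)
    loader = DataLoader(dataset, batch_size=4, worker_init_fn=seed_worker, drop_last=False)

    for (batch,) in loader:
        print(batch.shape)  # torch.Size([4, 128])
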
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index 145ccf97c388..ed0d1d8e401d 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -101,7 +101,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
}])
def run_bloom_test(test_config):
- # TODO: add test_config for TP+DP after supporting & debugging it
+ # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index e9c74b300daa..bb77759048b3 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -125,7 +125,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
}])
def run_chatglm_test(test_config):
- # TODO: add test_config for TP+DP after supporting & debugging it
+ # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 8b7a6bf29c8b..ca086bf12776 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -110,7 +110,7 @@ def unwrap(module):
@clear_cache_before_run()
def run_gpt2_test(test_config):
- # TODO: add test_config for TP+DP after supporting & debugging it
+ # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index fa4ee43e3114..30ebdfbe5cd9 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -133,7 +133,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
}])
def run_llama_test(test_config):
- # TODO: add test_config for TP+DP after supporting & debugging it
+ # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 403c3e75f52c..8d1154d82638 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -127,7 +127,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
}])
def run_opt_test(test_config):
- # TODO: add test_config for TP+DP after supporting & debugging it
+ # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index fb065b42250b..066f7ee815b4 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -105,10 +105,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@clear_cache_before_run()
def run_t5_test(test_config):
- # TODO: add plugin_config for TP+DP after supporting & debugging it
+ # TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
- # TODO: add test_config for flash attention & jit operator after supporting
+ # TODO(baizhou): add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index 919bceffc847..18df8ef555f2 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -124,8 +124,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
}])
def run_vit_test(test_config):
- # TODO: add test_config for TP+DP after supporting & debugging it
- # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models
+ # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
+    # TODO(baizhou): fix bug when setting lazy_init for Conv2D Layers in ViT models
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')