diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index d112d61dd91d..8d98130f8a32 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -208,7 +208,7 @@ jobs: - name: Execute Unit Testing run: | - CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. --durations=10 tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 1778d64ee287..63c0fbbb975d 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -44,7 +44,7 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index 0aa9dffeb632..87dd9ef500fe 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -35,7 +35,7 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 15ac4f1a92bb..3f8fc96395c9 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -32,7 +32,7 @@ jobs: name: Test for PyTorch Compatibility needs: matrix_preparation if: github.repository == 'hpcaitech/ColossalAI' - runs-on: [self-hosted, gpu] + runs-on: [self-hosted, 8-gpu] strategy: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py index ea3c7b5851e2..f9dee1bae935 100644 --- a/applications/Chat/tests/test_dataset.py +++ b/applications/Chat/tests/test_dataset.py @@ -14,29 +14,43 @@ from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer SFT_DATASET = [ { - "instruction": "Provide a list of the top 10 most popular mobile games in Asia", - "input": "", - "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": 0 + "instruction": + "Provide a list of the top 10 most popular mobile games in Asia", + "input": + "", + "output": + "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": + 0 }, { - "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level", - "input": "", - "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", - "id": 1 + "instruction": + "Please provide an action plan for reducing carbon footprint on a corporate level", + "input": + "", + "output": + "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", + "id": + 1 }, { - "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise", - "input": "", - "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", - "id": 2 + "instruction": + "Write a persuasive email to your boss explaining why you should have a pay raise", + "input": + "", + "output": + "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", + "id": + 2 }, ] PROMPT_DATASET = [ { - "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", - "id": 0 + "instruction": + "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", + "id": + 0 }, { "instruction": "Write a descriptive paragraph about a memorable vacation you went on", @@ -73,9 +87,7 @@ def make_tokenizer(model: str): return tokenizer -def check_content(input_ids_stripped: torch.Tensor, - tokenizer: PreTrainedTokenizer, - model: str): +def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str): if model == "opt": # NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt. assert input_ids_stripped[0] == tokenizer.eos_token_id @@ -98,13 +110,10 @@ def check_content(input_ids_stripped: torch.Tensor, assert input_ids_stripped != tokenizer.mask_token_id -@pytest.mark.cpu @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) @pytest.mark.parametrize("max_length", [32, 1024]) @pytest.mark.parametrize("max_datasets_size", [2]) -def test_prompt_dataset(model: str, - max_datasets_size: int, - max_length: int): +def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): with tempfile.TemporaryDirectory() as tmp_dir: dataset_name = "prompt_dataset.json" with open(os.path.join(tmp_dir, dataset_name), "w") as f: @@ -127,19 +136,12 @@ def test_prompt_dataset(model: str, check_content(input_ids.masked_select(attention_mask), tokenizer, model) -@pytest.mark.cpu @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) -@pytest.mark.parametrize(["dataset_path", "subset"], [ - ("Anthropic/hh-rlhf", "harmless-base"), - ("Dahoas/rm-static", None) -]) +@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), + ("Dahoas/rm-static", None)]) @pytest.mark.parametrize("max_datasets_size", [32]) @pytest.mark.parametrize("max_length", [32, 1024]) -def test_reward_dataset(model: str, - dataset_path: str, - subset: Optional[str], - max_datasets_size: int, - max_length: int): +def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int): data = load_dataset(dataset_path, data_dir=subset) assert max_datasets_size <= len(data["train"]) \ and max_datasets_size <= len(data["test"]) @@ -196,15 +198,12 @@ def test_reward_dataset(model: str, assert torch.all(r_mask) -@pytest.mark.cpu + @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"]) @pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) @pytest.mark.parametrize("max_dataset_size", [2]) @pytest.mark.parametrize("max_length", [32, 1024]) -def test_sft_dataset(model: str, - dataset_path: Optional[str], - max_dataset_size: int, - max_length: int): +def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int): tokenizer = make_tokenizer(model) if dataset_path == "yizhongw/self_instruct": data = load_dataset(dataset_path, "super_natural_instructions") @@ -253,10 +252,7 @@ def test_sft_dataset(model: str, if __name__ == "__main__": - test_sft_dataset(model="bloom", - dataset_path="yizhongw/self_instruct", - max_dataset_size=2, - max_length=256) + test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256) test_reward_dataset(model="gpt2", dataset_path="Anthropic/hh-rlhf", @@ -266,4 +262,5 @@ def test_sft_dataset(model: str, test_prompt_dataset(model="opt", max_datasets_size=2, - max_length=128) \ No newline at end of file + max_length=128) + diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py index 7b13becc3656..b98b3615cd28 100644 --- a/applications/Chat/tests/test_models.py +++ b/applications/Chat/tests/test_models.py @@ -16,17 +16,19 @@ from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer -@pytest.mark.gpu @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seq_len", [32]) -@pytest.mark.parametrize("actor_maker", [ - lambda: BLOOMActor(), - lambda: GPTActor(), +@pytest.mark.parametrize( + "actor_maker", + [ + lambda: BLOOMActor(), + lambda: GPTActor(), # HACK: skip llama due to long execution time # lambda: LlamaActor(), lambda: OPTActor(), # lambda: ChatGLMActor(), ]) + @pytest.mark.parametrize("generate_kwargs", [{ "max_length": 64, "use_cache": True, @@ -34,23 +36,15 @@ "temperature": 1.0, "top_k": 50, }]) -def test_generation(actor_maker: Callable[[], Actor], - batch_size: int, - seq_len: int, - generate_kwargs: Dict[str, Any] - ): +def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): actor = actor_maker() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() sequences = generate(actor.cuda(), input_ids, **generate_kwargs) assert sequences.shape == (batch_size, generate_kwargs["max_length"]) -@pytest.mark.cpu def test_utils(): - fn_input = { - "tensor": torch.ones((10, )), - "mask": torch.randint(0, 2, (10, )) - } + fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))} fn_output = masked_mean(dim=0, **fn_input) assert fn_output.dim() == 0 assert torch.allclose(fn_output, torch.tensor(1.0)) @@ -58,14 +52,14 @@ def test_utils(): batch_size = 4 num_labels = 10 fn_input = { - "r": torch.ones((batch_size, )), + "r": torch.ones((batch_size,)), "kl_coef": 1.0, "log_probs": torch.randn((batch_size, num_labels)), "log_probs_base": torch.randn((batch_size, num_labels)), "action_mask": torch.randint(0, 2, (batch_size, num_labels)) } fn_output = compute_reward(**fn_input) - assert fn_output.shape == (batch_size, ) + assert fn_output.shape == (batch_size,) batch_size = 4 seq_len = 32 @@ -82,17 +76,11 @@ def test_utils(): assert fn_output.shape == (batch_size, num_actions) -@pytest.mark.cpu @pytest.mark.parametrize("lora_rank", [4]) @pytest.mark.parametrize("num_dim", [32]) @pytest.mark.parametrize("num_layers", [4]) -def test_lora(lora_rank: int, - num_dim: int, - num_layers: int): - model = nn.ModuleList( - [nn.Linear(num_dim, num_dim) - for _ in range(num_layers)] - ) +def test_lora(lora_rank: int, num_dim: int, num_layers: int): + model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)]) lora_model = convert_to_lora_module(model, lora_rank) assert isinstance(lora_model, nn.ModuleList) for i in range(num_layers): @@ -105,8 +93,7 @@ def test_lora(lora_rank: int, assert isinstance(lora_model[i], LoraLinear) assert torch.allclose(old_model[i].weight, lora_model[i].weight) assert torch.allclose(old_model[i].bias, lora_model[i].bias) - assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, - lora_model[i].lora_B @ lora_model[i].lora_A) + assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A) optimizer = torch.optim.Adam(lora_model.parameters()) x = torch.randn(8, num_dim) for i in range(num_layers): @@ -122,12 +109,13 @@ def test_lora(lora_rank: int, lora_model[i].lora_B @ lora_model[i].lora_A) -@pytest.mark.cpu @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [128]) -@pytest.mark.parametrize("models_maker", [ - lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), - lambda: (GPTActor(), GPTCritic(), GPTRM()), +@pytest.mark.parametrize( + "models_maker", + [ + lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), + lambda: (GPTActor(), GPTCritic(), GPTRM()), # HACK: skip llama due to long execution time # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), lambda: (OPTActor(), OPTCritic(), OPTRM()), @@ -178,13 +166,10 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], assert rm_output.shape == (batch_size, ) -@pytest.mark.cpu @pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("seq_len", [128]) @pytest.mark.parametrize("num_labels", [100]) -def test_loss(batch_size: int, - seq_len: int, - num_labels: int): +def test_loss(batch_size: int, seq_len: int, num_labels: int): loss = GPTLMLoss() loss_input = { "logits": torch.randn(batch_size, seq_len, num_labels), @@ -194,54 +179,43 @@ def test_loss(batch_size: int, loss = PolicyLoss() loss_input = { - "log_probs": torch.randn(batch_size, ), - "old_log_probs": torch.randn(batch_size, ), - "advantages": torch.randn(batch_size, ) + "log_probs": torch.randn(batch_size,), + "old_log_probs": torch.randn(batch_size,), + "advantages": torch.randn(batch_size,) } loss_output = loss(**loss_input) loss = ValueLoss() loss_input = { - "values": torch.randn(batch_size, ), - "old_values": torch.randn(batch_size, ), - "reward": torch.randn(batch_size, ) + "values": torch.randn(batch_size,), + "old_values": torch.randn(batch_size,), + "reward": torch.randn(batch_size,) } loss_output = loss(**loss_input) loss = LogSigLoss() loss_input = { - "chosen_reward": torch.randn(batch_size, ), - "reject_reward": torch.randn(batch_size, ), + "chosen_reward": torch.randn(batch_size,), + "reject_reward": torch.randn(batch_size,), } loss_output = loss(**loss_input) loss = LogExpLoss() loss_input = { - "chosen_reward": torch.randn(batch_size, ), - "reject_reward": torch.randn(batch_size, ), + "chosen_reward": torch.randn(batch_size,), + "reject_reward": torch.randn(batch_size,), } loss_output = loss(**loss_input) if __name__ == "__main__": - generate_kwargs = dict(max_length=40, - use_cache=True, - do_sample=True, - temperature=1.0, - top_k=50) - test_generation(lambda: LlamaActor(), - batch_size=4, - seq_len=32, - generate_kwargs=generate_kwargs) + generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50) + test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs) test_utils() test_lora(lora_rank=2, num_dim=8, num_layers=2) - test_models(models_maker=lambda: (BLOOMActor(), - BLOOMCritic(), - BLOOMRM()), - batch_size=8, - seq_len=128) + test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128) test_loss(batch_size=8, seq_len=128, num_labels=100) \ No newline at end of file diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 54d815ce701e..de03ba27bfda 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -15,6 +15,7 @@ get_model_base_filenames, get_optimizer_base_filenames, load_shard_state_dict, + save_config_file, save_state_dict, save_state_dict_shards, ) @@ -107,6 +108,7 @@ def save_sharded_model(self, if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) + save_config_file(model.module, checkpoint_path) logging.info(f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}.") diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 28a19af0ce91..d33e3485c39c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,19 +1,22 @@ import random from contextlib import nullcontext -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from functools import partial +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union import numpy as np import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from torch.nn import Module +from torch.nn import Module, SyncBatchNorm +from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer -from colossalai.checkpoint_io import CheckpointIO +from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule @@ -26,26 +29,52 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + class HybridParallelModule(ModelWrapper): - def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None: + def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, + ddp_config: dict) -> None: + self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group + shardformer = ShardFormer(shard_config) module, self.shared_params = shardformer.optimize(module) - # TODO(ver217): add input type cast + + # setting process groups for shared parameters self.shared_param_process_groups = [] for shared_param in self.shared_params: if len(shared_param) > 0: self.shared_param_process_groups.append( self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))) + + # setting mixed_precision + self.mixed_precision = None if precision == 'fp16': - module = module.half().cuda() + self.mixed_precision = torch.float16 elif precision == 'bf16': - module = module.to(dtype=torch.bfloat16).cuda() - else: - module = module.cuda() # train without AMP - # TODO(ver217): support TP+DP + self.mixed_precision = torch.bfloat16 + if self.mixed_precision is not None: + module = module.to(self.mixed_precision) + module = module.cuda() + + # setting input type cast when using mixed precision + self.convert_fn = None + if self.mixed_precision is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) + + # setting ddp configs + if use_ddp: + # convert model to sync bn + module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) + # wrap the model with PyTorch DDP + module = DDP(module, process_group=dp_group, **ddp_config) + super().__init__(module) def sync_shared_params(self): @@ -68,19 +97,62 @@ def sync_grads(self): dist.all_reduce(p.grad, group=self.dp_group) p.grad.div_(self.dp_group.size()) + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + return super().forward(*args, **kwargs) + + def unwrap(self): + module = super().unwrap() + if isinstance(module, DDP): + module = module.module + return module + + +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. + + if optim is None: + return {} + param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} + start_index = 0 + for group in optim.param_groups: + + packed_group = {k: v for k, v in group.items() if k != 'params'} + packed_group['params'] = [] + + for param_id, param in enumerate(group['params'], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + packed_group['params'].append(param_id) + param_info['param2id'][id(param)] = param_id + param_info['id2param'][param_id] = id(param) + param_info['param2shape'][id(param)] = original_shape + + param_info['param_groups'].append(packed_group) + start_index += len(group['params']) + + return param_info + def init_pipeline_optimizer(optim: Optimizer, model: Module): - params = set(model.parameters()) + model_params = set(model.parameters()) new_param_groups = [] for group in optim.param_groups: - params = [p for p in group['params'] if p in params] + params = [p for p in group['params'] if p in model_params] new_param_groups.append({**group, 'params': params}) optim.__setstate__({'param_groups': new_param_groups}) class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool): + def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim) @@ -92,6 +164,7 @@ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, precision: str = 'fp16', initial_scale: float = 2**16, min_scale: float = 1, @@ -101,6 +174,7 @@ def __init__(self, hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -114,6 +188,7 @@ def __init__( optimizer: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2., @@ -131,6 +206,7 @@ def __init__( dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -140,34 +216,100 @@ def __init__( class HybridParallelPlugin(PipelinePluginBase): + """ + Plugin for Hybrid Parallel Training. + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import HybridParallelPlugin + + >>> model, train_dataset, optimizer, criterion = ... + >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + + >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase. + num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + microbatch_size (int, optional): Microbatch size when using pipeline parallelism. + Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. + If ``num_microbatches`` is provided, this will be ignored. Defaults to None. + initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. + min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. + growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. + backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. + growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. + hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. + max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. + max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. + ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. + zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + """ + + def __init__(self, + tp_size: int, + pp_size: int, + precision: str = 'fp16', + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True) -> None: - def __init__( - self, - tp_size: int, - pp_size: int, - precision: str = 'fp16', - zero_stage: int = 0, - cpu_offload: bool = False, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - num_microbatches: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - ) -> None: super().__init__() assert dist.get_world_size() % ( tp_size * pp_size ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' - # TODO(ver217): support zero - assert zero_stage == 0, 'zero is not support yet' + + if enable_sequence_parallelism: + assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' + self.tp_size = tp_size self.pp_size = pp_size self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -178,24 +320,30 @@ def __init__( self.enable_fused_normalization = enable_fused_normalization self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism' + assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) - self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) + self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, enable_all_optimization=self.enable_all_optimization, enable_fused_normalization=self.enable_fused_normalization, enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused) + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, @@ -205,6 +353,20 @@ def __init__( min_scale=min_scale, max_scale=max_scale, ) + + self.ddp_config = dict(broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph) + + self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2)) + self.max_norm = max_norm @property @@ -237,32 +399,44 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): - model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group) + use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, + self.ddp_config) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: optimizer = HybridParallelAMPOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, precision=self.precision, max_norm=self.max_norm, **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, + optimizer.master_to_working_map) else: optimizer = HybridParallelNaiveOptimizer(optimizer, model, - use_pipeline=self.enable_pipeline_parallelism) + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info) else: + assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." optimizer = HybridParallelZeroOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, - partition_grad=(self.zero_stage == 2), - cpu_offload=self.cpu_offload, + param_info=param_info, dp_process_group=self.dp_group, tp_process_group=self.tp_group, verbose=True, clip_grad_norm=self.max_norm, + **self.zero_config, **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, + optimizer._param_store.master_to_working_param) + return model, optimizer, criterion, dataloader, lr_scheduler def execute_pipeline(self, @@ -339,7 +513,8 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return None + self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return self.checkpoint_io def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index c25048e25754..07b1f81dace6 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,5 +1,6 @@ from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO +from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO from .index_file import CheckpointIndexFile -__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO'] +__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO'] diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 34210ea52162..faaf1d22722a 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -23,6 +23,7 @@ load_state_dict, load_state_dict_into_model, load_states_into_optimizer, + save_config_file, save_param_groups, save_state_dict, save_state_dict_shards, @@ -183,6 +184,7 @@ def save_sharded_model(self, index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint_path, is_master=True) logging.info(f"The model is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}.") diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py new file mode 100644 index 000000000000..fef5b0d16d60 --- /dev/null +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -0,0 +1,702 @@ +import copy +import gc +import logging +import os +from pathlib import Path +from shutil import rmtree +from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +from colossalai.interface import OptimizerWrapper + +from .general_checkpoint_io import GeneralCheckpointIO +from .index_file import CheckpointIndexFile +from .utils import ( + StateDictSharder, + gather_distributed_param, + get_model_base_filenames, + get_optimizer_base_filenames, + is_safetensors_available, + load_shard_state_dict, + load_state_dict_into_model, + load_states_into_optimizer, + save_config_file, + save_param_groups, + save_state_dict_shards, + search_tp_partition_dim, + sharded_optimizer_loading_epilogue, +) + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + + +class HypridParallelCheckpointIO(GeneralCheckpointIO): + """ + CheckpointIO for Hybrid Parallel Training. + + Args: + dp_group (ProcessGroup): Process group along data parallel dimension. + pp_group (ProcessGroup): Process group along pipeline parallel dimension. + tp_group (ProcessGroup): Process group along tensor parallel dimension. + zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. + verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True. + """ + + def __init__(self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True) -> None: + super().__init__() + self.dp_group = dp_group + self.pp_group = pp_group + self.tp_group = tp_group + self.dp_rank = dist.get_rank(self.dp_group) + self.tp_rank = dist.get_rank(self.tp_group) + self.pp_rank = dist.get_rank(self.pp_group) + self.dp_size = dist.get_world_size(dp_group) + self.pp_size = dist.get_world_size(pp_group) + self.tp_size = dist.get_world_size(tp_group) + self.use_zero = (zero_stage > 0) + self.verbose = verbose + self.working_to_master_map = None + self.master_to_working_map = None + + @staticmethod + def _model_sharder(model: nn.Module, + prefix: str = '', + keep_vars: bool = False, + size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + # Gather tensor pieces when using tensor parallel. + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + @staticmethod + def _optimizer_sharder(optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, + size_per_shard: int = 1024): + + # An internel method that breaks state_dict of optimizer into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + + for param, state in optimizer.optim.state.items(): + + if param is None: + continue + + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info['param2id'][id(working_param)] + original_shape = param_info['param2shape'][id(working_param)] + state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False) + + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_model(self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 save the model. + if self.dp_rank != 0: + return + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_rank == 0 are responsible for model saving. + state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + use_pp_format=True) + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + save_config_file(model, checkpoint) + rmtree(tmp_index_file_folder) + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): + """ + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + + def _load(name: str): + if name not in weight_map: + raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + missing_keys = [] + + load_state_dict_into_model(model, + state_dict, + missing_keys=missing_keys, + strict=strict, + load_sub_module=True) + loaded_file.add(filename) + + # Load parameters. + for name, _ in model.named_parameters(): + _load(name) + + # Load buffers. + non_persistent_buffers = set() + for n, m in model.named_modules(): + non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set) + for name, buf in model.named_buffers(): + if buf is not None and name not in non_persistent_buffers: + _load(name) + + # Load extra states. + extra_state_key = _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + _load(extra_state_key) + + # Update master params if mixed-precision training is enabled. + with torch.no_grad(): + if self.working_to_master_map is not None: + for param in model.parameters(): + if (param is None) or (id(param) not in self.working_to_master_map): + continue + master_param = self.working_to_master_map[id(param)] + if self.use_zero: + # master_param is sharded under Zero setting + padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size + if padding_size > 0: + padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padded_param = param.data.view(-1) + sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank] + master_param.data.copy_(sharded_param.data) + else: + master_param.data.copy_(param.data) + + if self.verbose: + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + def save_sharded_optimizer(self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024): + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. + if not self.use_zero and self.dp_rank != 0: + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + master_to_working_map=self.master_to_working_map, + size_per_shard=size_per_shard) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.dp_rank == 0 and self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + # Store index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose: + logging.info(f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True) + + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. + final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + + def _get_param_id_from_optimizer_param(param: torch.Tensor, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info['param2id'][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. + id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg['params']: + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + id_map[param_id] = param + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory.') + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({'param_groups': updated_groups}) + + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg['params']: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] + + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + if self.master_to_working_map is not None: + working_param = self.master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info['param2shape'][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state(state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose: + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save lr scheduler to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], + master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]): + """ + Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings. + This mapping can only be created when mixied precision is used. + The created mappings should be mappings from integer parameter addresses to parameter objects. + + Args: + working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. + master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects. + """ + self.working_to_master_map = dict() + for k, v in working_to_master_map.items(): + if isinstance(k, torch.Tensor): + self.working_to_master_map[id(k)] = v + elif isinstance(k, int): + self.working_to_master_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + + self.master_to_working_map = dict() + for k, v in master_to_working_map.items(): + if isinstance(k, torch.Tensor): + self.master_to_working_map[id(k)] = v + elif isinstance(k, int): + self.master_to_working_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + + @staticmethod + def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, + dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, + inplace: bool) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # First gather Zero shards. + if use_zero: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().cpu() + + return state_ + + def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size, + original_shape: torch.Size, device: torch.device, + inplace: bool) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. + + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # Shard state along tensor parallel group. + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. + if self.use_zero: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 77ff7784a514..6dadaba3e64f 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,4 +1,5 @@ # coding=utf-8 +import copy import os import re from collections import abc as container_abcs @@ -10,10 +11,17 @@ import torch import torch.nn as nn from torch.optim import Optimizer +from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype +from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.tensor.d_tensor import is_distributed_tensor +from colossalai.tensor.d_tensor import ( + is_customized_distributed_tensor, + is_distributed_tensor, + to_global, + to_global_for_customized_distributed_tensor, +) SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -88,8 +96,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False +def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]: + """ + Given the current shape of parameter and the shape of parameter before sharding, + return the dimension along which the parameter is sharded when using tensor parallel. + If tensor parallel is not used, return None. + + Args: + current_shape (torch.Size): The current shape of parameter after sharding. + original_shape (torch.Size): The shape of parameter before sharding. + tp_size (int): The size of tp group. + + Returns: + Optional[int]: The dimension along which parameter is partitioned. + """ + partition_dim = None + for dim, length in enumerate(original_shape): + if length > current_shape[dim]: + partition_dim = dim + break + if partition_dim is not None: + assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ + f"The parameter isn't evenly distributed among tensor parallel group: \ + shape before sharding {original_shape}, shape after sharding {current_shape}" + + return partition_dim + + # ====================================== -# Helper functions for saving shard file +# Helper classes and functions for saving shard file # ====================================== def unwrap_optimizer(optimizer: OptimizerWrapper): ''' @@ -104,12 +139,97 @@ def unwrap_optimizer(optimizer: OptimizerWrapper): return unwrapped_optim +class StateDictSharder: + + def __init__(self, size_per_shard: int) -> None: + self.max_shard_size = size_per_shard + self.current_block = OrderedDict() + self.current_block_size = 0 + + def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + + tensor_size = calculate_tensor_size(tensor) + ret_block = None + ret_block_size = 0 + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[name] = tensor + self.current_block_size += tensor_size + return ret_block, ret_block_size + + def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: + + # A state might contain more than one tensors. + # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' + state_size = 0 + isDTensor = False + for state_tensor in state.values(): + + # When state_tensor is not of Tensor class, + # e.g., a SGD optimizer with momentum set to 0 can have None as state + # The calculation of tensor size should be skipped to avoid error. + if not isinstance(state_tensor, torch.Tensor): + continue + + # If the states are stored as DTensors, mark isDTensor as true. + if is_distributed_tensor(state_tensor): + isDTensor = True + state_size += calculate_tensor_size(state_tensor) + + ret_block = None + ret_block_size = 0 + + # directly return if state is stored as distributed tensor + if isDTensor: + return ret_block, ret_block_size + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[param_id] = state + self.current_block_size += state_size + return ret_block, ret_block_size + + +def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor: + """ + Gather the complete parameter for saving if passed in param is distributed under tp setting. + + Args: + param (torch.Tensor): A model parameter, might be d_tensor. + keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. + + Returns: + torch.Tensor: the complete parameter + """ + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + return to_global(param_) + elif is_customized_distributed_tensor(param_): + return to_global_for_customized_distributed_tensor(param_) + else: + return param_ + + def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, index_file: "CheckpointIndexFile", base_filename: str, is_master: bool, - use_safetensors: bool = False) -> int: + use_safetensors: bool = False, + use_pp_format: bool = False) -> int: ''' Save sharded state dict only on master rank, this method can be used by both model and optimizer states. Args: @@ -117,18 +237,21 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] checkpoint (str): The path of checkpoint directory as string. index_file (CheckpointIndexFile): The index file object to be updated. base_filename (str): Decides the prefix of filenames of shards. - is_master (bool): Whether current rank is master. - use_safetensors (bool): Whether to use safetensors to save checkpoint. + is_master (bool): Whether current rank is main process. + use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. Returns: int: the total size of shards ''' total_size = 0 + shard_filenames = [] for idx, shard_pair in enumerate(sharded_state_dict): + shard, current_size = shard_pair if not is_master: + del shard continue - shard, current_size = shard_pair shard_file = get_shard_filename(base_filename, idx) total_size = total_size + current_size for key in shard.keys(): @@ -137,6 +260,11 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] # Only save on master rank. save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors) + shard_filenames.append(shard_file) + del shard + + # Clean folder, deleted unneeded files. + clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format) return total_size @@ -146,28 +274,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. """ - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for key, weight in state_dict.items(): - ret_block = None - ret_block_size = 0 if not is_distributed_tensor(weight): - weight_size = calculate_tensor_size(weight) - - # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - current_block[key] = weight - current_block_size += weight_size + block, block_size = state_dict_sharder.append_param(key, weight) - if ret_block != None: - yield ret_block, ret_block_size + if block != None: + yield block, block_size - yield current_block, current_block_size + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: @@ -178,47 +295,207 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. states = state_dict['state'] - - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for param_id, state in states.items(): + block, block_size = state_dict_sharder.append_optim_state(param_id, state) + if block != None: + yield block, block_size - ret_block = None - ret_block_size = 0 + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue +# ====================================== +# Helper functions for saving state dict +# ====================================== - # If the states are stored as DTensors, mark isDTensor as true. - if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - if not isDTensor: +def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: + """ + Save state dict to checkpoint. + + Args: + state_dict (dict): state dict. + checkpoint_file_path (str): path to the checkpoint file. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + """ + if use_safetensors: + assert is_safetensors_available(), "safetensors is not available." + assert checkpoint_file_path.endswith('.safetensors'), \ + "safetensors only supports .safetensors suffix for checkpoint file." + from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) + else: + torch.save(state_dict, checkpoint_file_path) + + +def save_param_groups(state_dict: dict, group_file_path: str) -> None: + """ + Save information of param_groups to given file path. + + Args: + state_dict (dict): state dict. + group_file_path (str): path to the group file. + """ + param_groups = state_dict["param_groups"] + torch.save(param_groups, group_file_path) + + +def clean_folder(checkpoint_path: str, + weights_name: str, + shard_filenames: List[str], + is_master: bool = True, + use_pp_format: bool = False): + """ + Clean the unneeded files in checkpoint directory after shards of state_dict have been saved. + + Args: + checkpoint_path (str): Path to the checkpoint directory. + weights_name (str): Decides the prefix of filenames of weight shards. + shard_filenames (List[str]): The list of saved shard filenames which should not be removed. + is_master (bool, optional): Whether current rank is main process. Defaults to True. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. + + """ + if is_master: + for filename in os.listdir(checkpoint_path): + full_filename = os.path.join(checkpoint_path, filename) + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + if not use_pp_format: + reg = re.compile(r"(.*?)-\d{5}") + else: + # When this checkpoint is created by pipeline parallel process, the pattern is a little different. + reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}") + if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) + and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None): + os.remove(full_filename) + + +def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True): + """ + Save config.json/generation_config.json if model is a Huggingface pretrained model. + This method can only be called when a model is saved in a sharded way. + + Args: + model (nn.Module): The model whose config should be saved if it's a huggingface model. + checkpoint_path (str): Path to the checkpoint directory. + is_master (bool): Whether current rank is main process. + """ + if not isinstance(model, PreTrainedModel): + return + + model = unwrap_huggingface_model(model) + + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + dtype = get_parameter_dtype(model) + model.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model.config.architectures = [model.__class__.__name__] + + # Save the config + if is_master: + model.config.save_pretrained(checkpoint_path) + if model.can_generate(): + model.generation_config.save_pretrained(checkpoint_path) + + +def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: + """ + Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains + only one tensor. + + Args: + tensor (Tensor): tensor to be saved. + index_file (CheckpointIndexFile): path to the checkpoint file. + size_per_shard (int): size per shard in MB. + """ + root_path = index_file.root_path + output_root_path = root_path.joinpath('dtensor') + + # create directory + output_root_path.mkdir(exist_ok=True) + + # save tensor to this directory + # TODO(YuliangLiu): get index of the tensor shard + # e.g. index = + index = 0 + + # save tensor to file + ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) + ckpt_file_path = output_root_path.joinpath(ckpt_file_name) + + # dtensor ckpt file always contains only one tensor + state_dict = {name: tensor} + save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) + + # update the weight map + # * means all shards + ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) + index_file.append_weight_map(name, ckpt_file_name_in_weight_map) + + +def get_checkpoint_file_suffix(use_safetensors: bool) -> str: + """ + Get checkpoint file suffix. + + Args: + use_safetensors (bool): whether to use safetensors to save the checkpoint. + + Returns: + str: checkpoint file suffix. + """ + if use_safetensors: + return '.safetensors' + else: + return '.bin' + + +def generate_checkpoint_shard_file_name(index: int, + total_number: int, + use_safetensors: bool, + prefix: str = None) -> str: + """ + Generate checkpoint shard file name. + + Args: + index (int): index of the shard. + total_number (int): total number of shards. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + prefix (str): prefix of the shard file name. Default: None. - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 + Returns: + str: checkpoint shard file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) - current_block[param_id] = state - current_block_size += state_size + if prefix is None: + return f"{index:05d}-of-{total_number:05d}.{suffix}" + else: + return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" - if ret_block != None: - yield ret_block, ret_block_size - yield current_block, current_block_size +def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: + """ + Generate dtensor file name. + + Args: + param_name (str): name of the distributed parameter. + index (int): index of the shard. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + + Returns: + str: dtensor file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) + return f'{param_name}.{index}.{suffix}' + + +# ======================================== +# Helper functions for loading state dict +# ======================================== def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): @@ -331,17 +608,21 @@ def update_group(group, new_group): return id_map -def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict): +def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False): r"""Copies states from `state_dict` into an Optimizer object. Args: optimizer(Optimizer): An initialized Optimizer object to be loaded - state_dict(dict): a mapping from tensor index (an integer) + state_dict(dict): A mapping from tensor index (an integer) to its states to be loaded (a mapping from state name to a tensor). - id_map(dict): a mapping from tensor index (an integer) + id_map(dict): A mapping from tensor index (an integer) to its corresponding parameter (a tensor) whose states will be updated. + strict(bool, optional): If set to True, only load the parameters with its id in id_map. Defaults to False. """ + # Ensure that the keys of state_dict are integers. + state_dict = {int(k): v for k, v in state_dict.items()} + def cast(param, value, key=None): r"""Make a deep copy of value, casting all tensors to device of param.""" if isinstance(value, torch.Tensor): @@ -368,7 +649,7 @@ def cast(param, value, key=None): if k in id_map: param = id_map[k] new_states[param] = cast(param, v) - else: + elif not strict: new_states[k] = v optimizer.state.update(new_states) @@ -386,165 +667,6 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer): optimizer.defaults.setdefault('differentiable', False) -# ====================================== -# Helper functions for saving state dict -# ====================================== - - -def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: - """ - Save state dict to checkpoint. - - Args: - state_dict (dict): state dict. - checkpoint_file_path (str): path to the checkpoint file. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - """ - if use_safetensors: - assert is_safetensors_available(), "safetensors is not available." - assert checkpoint_file_path.endswith('.safetensors'), \ - "safetensors only supports .safetensors suffix for checkpoint file." - from safetensors.torch import save_file as safe_save_file - safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) - else: - torch.save(state_dict, checkpoint_file_path) - - -def save_param_groups(state_dict: dict, group_file_path: str) -> None: - """ - Save information of param_groups to given file path. - - Args: - state_dict (dict): state dict. - group_file_path (str): path to the group file. - """ - param_groups = state_dict["param_groups"] - torch.save(param_groups, group_file_path) - - -def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: - """ - Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains - only one tensor. - - Args: - tensor (Tensor): tensor to be saved. - index_file (CheckpointIndexFile): path to the checkpoint file. - size_per_shard (int): size per shard in MB. - """ - root_path = index_file.root_path - output_root_path = root_path.joinpath('dtensor') - - # create directory - output_root_path.mkdir(exist_ok=True) - - # save tensor to this directory - # TODO(YuliangLiu): get index of the tensor shard - # e.g. index = - index = 0 - - # save tensor to file - ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) - ckpt_file_path = output_root_path.joinpath(ckpt_file_name) - - # dtensor ckpt file always contains only one tensor - state_dict = {name: tensor} - save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) - - # update the weight map - # * means all shards - ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) - index_file.append_weight_map(name, ckpt_file_name_in_weight_map) - - -def get_checkpoint_file_suffix(use_safetensors: bool) -> str: - """ - Get checkpoint file suffix. - - Args: - use_safetensors (bool): whether to use safetensors to save the checkpoint. - - Returns: - str: checkpoint file suffix. - """ - if use_safetensors: - return '.safetensors' - else: - return '.bin' - - -def generate_checkpoint_shard_file_name(index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None) -> str: - """ - Generate checkpoint shard file name. - - Args: - index (int): index of the shard. - total_number (int): total number of shards. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - prefix (str): prefix of the shard file name. Default: None. - - Returns: - str: checkpoint shard file name. - """ - suffix = get_checkpoint_file_suffix(use_safetensors) - - if prefix is None: - return f"{index:05d}-of-{total_number:05d}.{suffix}" - else: - return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" - - -def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: - """ - Generate dtensor file name. - - Args: - param_name (str): name of the distributed parameter. - index (int): index of the shard. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - - Returns: - str: dtensor file name. - """ - suffix = get_checkpoint_file_suffix(use_safetensors) - return f'{param_name}.{index}.{suffix}' - - -def save_state_dict_as_shard( - state_dict: dict, - checkpoint_path: str, - index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None, -) -> None: - """ - Save state dict as shard. - - Args: - state_dict (dict): state dict. - checkpoint_path (str): path to the checkpoint file. - index (int): index of the shard. - total_number (int): total number of shards. - prefix (str): prefix of the shard file name. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - """ - # generate the shard name - shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix) - shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute() - - # save the shard - save_state_dict(state_dict, str(shard_file_path), use_safetensors) - - -# ======================================== -# Helper functions for loading state dict -# ======================================== - - def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: """ Check whether the checkpoint has an index file. @@ -654,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int): get shard file name """ shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") - shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") + shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors") return shard_file diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 1dfd261d5d01..623160003767 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -94,17 +94,23 @@ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]: return np.unravel_index(rank, shape) @staticmethod - def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int: + def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int: """Convert a coordinate to a rank. + mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html. + with wrap, index out of range would be wrapped around. + For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns (i % 2) Args: coords (Tuple[int, ...]): Coordinate to be converted. shape (Tuple[int, ...]): Shape of the process group mesh. + mode (Optional[str]): The mode for numpy.ravel_multi_index. Returns: int: Rank of the coordinate. """ - return np.ravel_multi_index(coord, shape) + + assert mode in ["raise", "wrap", "clip"] + return np.ravel_multi_index(coord, shape, mode) def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup: """Get the process group with the given ranks. It the process group doesn't exist, it will be created. diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index af7a00b5c720..aed85cf91512 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -173,14 +173,10 @@ def recv_forward(self, prev_rank: int = None) -> Any: Returns: Any: The input tensor or input tensor list. """ - if self.stage_manager.is_first_stage(): - input_tensor = None - else: - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - input_tensor = _recv_object(prev_rank, cur_rank, - self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) return input_tensor @@ -193,14 +189,11 @@ def recv_backward(self, next_rank: int = None) -> Any: Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.stage_manager.is_last_stage(): - output_tensor_grad = None - else: - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - output_tensor_grad = _recv_object(next_rank, cur_rank, - self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + output_tensor_grad = _recv_object(next_rank, cur_rank, + self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) return output_tensor_grad @@ -211,12 +204,10 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None: output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - if not self.stage_manager.is_last_stage(): - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - cur_rank = self.stage_manager.get_rank() - _send_object(output_object, cur_rank, next_rank, - self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) def send_backward(self, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. @@ -225,9 +216,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor """ - if not self.stage_manager.is_first_stage(): - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - cur_rank = self.stage_manager.get_rank() - _send_object(input_object, cur_rank, prev_rank, - self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + cur_rank = self.stage_manager.get_rank() + _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 3ed9239272f1..583558551b3c 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -1,9 +1,59 @@ -from typing import Any, List, Optional +from collections import OrderedDict +from typing import Any, List, Optional, Tuple import torch import torch.cuda from torch.nn import Module -from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import ( + SUPPORTED_NODES, + LeafSpec, + TreeSpec, + _is_leaf, + _register_pytree_node, + tree_flatten, + tree_map, + tree_unflatten, +) + + +# this register are for torch under version 1.13.1, maybe removed in the future +def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]: + return list(d.values()), list(d.keys()) + + +def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]': + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten) + + +def tree_map_hf(fn: Any, pytree: Any): + flat_args, spec = tree_flatten_hf(pytree) + return tree_unflatten([fn(i) for i in flat_args], spec) + + +# use this flatten function to handle the ModelingOutput Class instance. +def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values an a TreeSpec that can be used + to reconstruct the pytree. + """ + if isinstance(pytree, OrderedDict): + node_type = OrderedDict + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(pytree) + + # Recursively flatten the children + result: List[Any] = [] + children_specs: List['TreeSpec'] = [] + for child in child_pytrees: + flat, child_spec = tree_flatten_hf(child) + result += flat + children_specs.append(child_spec) + return result, TreeSpec(node_type, context, children_specs) + else: + result, tree_spec = tree_flatten(pytree) + return result, tree_spec def to_device(x: Any, device: Optional[torch.device] = None) -> Any: @@ -104,7 +154,7 @@ def detach(x: Any) -> Any: return x -def merge_batch(data: List[Any]) -> Any: +def merge_batch(data: List[Any], batch_size_dim=0) -> Any: """Merge micro batches into a batch. Args: @@ -118,12 +168,17 @@ def merge_batch(data: List[Any]) -> Any: flattened_data = [] tree_spec = None for d in data: - elems, tree_spec = tree_flatten(d) + # elems should be an instance of OrderedDict + elems, tree_spec = tree_flatten_hf(d) flattened_data.append(elems) merged_data = [] + for elem_batch in zip(*flattened_data): if isinstance(elem_batch[0], torch.Tensor): - merged_data.append(torch.cat(elem_batch, dim=0)) + if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs + merged_data.append(None) + else: + merged_data.append(torch.cat(elem_batch, dim=batch_size_dim)) else: merged_data.append(list(elem_batch)) return tree_unflatten(merged_data, tree_spec) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py new file mode 100644 index 000000000000..35a33491b03c --- /dev/null +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -0,0 +1,370 @@ +from functools import partial +from typing import Any, Callable, Iterable, List, Optional, Union + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import tree_map + +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils.cuda import get_current_device + +from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device +from .base import PipelineSchedule + + +class InterleavedSchedule(PipelineSchedule): + + def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None: + self.num_model_chunks = num_model_chunks + assert num_microbatches % self.num_model_chunks == 0, \ + "Number of microbatches should be an integer multiple of number of model chunks" + super().__init__(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager) + self.num_microbatches = num_microbatches + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None + self.microbatch_offset: Optional[int] = None + self.microbatch_size: Optional[int] = None + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch + self.batch_size = get_batch_size(batch) + self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] + assert self.batch_size % self.num_microbatches == 0, \ + "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + + def load_micro_batch(self, model_chunk_id: int) -> Any: + """Load a micro batch from the current batch. + + Args: + microbatch_id (int): the current model chunk idx. + + Returns: + Any: Micro batch. + """ + micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) + self.microbatch_offset[model_chunk_id] += self.microbatch_size + return tree_map(partial(to_device, device=get_current_device()), micro_batch) + + def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: + """Helper method to get the model chunk ID given the iteration number. + + Args: + microbatch_id (int): the current microbatch idx + forward (bool): if is the forward process + + Returns: + int: The model chunk idx of the input microbatch_id + """ + microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) + model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages + if not forward: + model_chunk_id = (self.num_model_chunks - model_chunk_id - 1) + return model_chunk_id + + def is_first_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the first stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the first stage. + """ + if self.stage_manager.is_first_stage() and model_chunk_id == 0: + return True + return False + + def is_last_stage(self, model_chunk_id: int) -> bool: + """Is the current virtual stage the last stage + + Args: + model_chunk_id (int): The current model chunk idx. + + Returns: + bool: Whether the current virtual stage is the last stage. + """ + if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1: + return True + return False + + def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.is_first_stage(model_chunk_id): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.is_last_stage(model_chunk_id): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.is_last_stage(model_chunk_id): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For interleaved 1F1B. + + Args: + model_chunk_id (int): The current model chunk idx. + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.is_first_stage(model_chunk_id): + self.comm.send_backward(input_object, prev_rank) + + def forward_step(self, + model_chunk: Module, + model_chunk_id: int, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]: + """Forward one step of the pipeline + Args: + model (Module): Model Chunk to be run + input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. + criterion (Callable): Criterion to calculate loss. + accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. + outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) + + # for the first stage, input_obj is None + # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + + if self.is_last_stage(model_chunk_id): + loss = criterion(output_obj, micro_batch) / self.num_microbatches + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj + + def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]: + """Backward one step of the pipeline + + Args: + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None. + output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor). + output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None. + + Returns: + Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None. + """ + + # Retain the grad on the input_obj. + tree_map(retain_grad, input_obj) + + # Backward pass. + if output_obj_grad is None: + optimizer.backward(output_obj) + else: + if "backward_tensor_keys" not in output_obj: + for k, grad in output_obj_grad.items(): + optimizer.backward_by_grad(output_obj[k], grad) + else: + for k, grad in output_obj_grad.items(): + output_obj[k].grad = grad + for k in output_obj["backward_tensor_keys"]: + tensor_to_backward = output_obj[k] + optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad) + + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + input_obj_grad = {} + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad + + def forward_backward_step(self, + model_chunk: Module, + optimizer: OptimizerWrapper, + data_iter: Iterable, + criterion: Callable[..., Any], + return_loss: bool = False, + return_outputs: bool = False) -> dict: + """Runs interleaved 1F1B schedule, with communication between pipeline stages. + + Args: + model_chunk (List[Module]): Model Chunk to be trained. + optimizer (OptimizerWrapper): Optimizer to be used. + data_iter (Iterable): Data iterator. + criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. + return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. + return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. + + Returns: + dict: A dict with keys: 'loss' and 'outputs'. + """ + forward_only = not torch.is_grad_enabled() + + self.load_batch(data_iter) + num_model_chunks = len(model_chunk) + + # num_warmup_microbatches is the step when not all the processes are working + num_microbatches = self.num_microbatches * num_model_chunks + if forward_only: + num_warmup_microbatches = num_microbatches + else: + num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 + num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + + num_microbatches_remaining = num_microbatches - num_warmup_microbatches + + # Input, output tensors only need to be saved when doing backward passes + input_objs = None + output_objs = None + + if not forward_only: + input_objs = [[] for _ in range(num_model_chunks)] + output_objs = [[] for _ in range(num_model_chunks)] + + outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + + if return_loss and self.stage_manager.is_last_stage(): + accum_loss = torch.zeros(1, device=get_current_device()) + else: + accum_loss = None + + # for ranks except the first one, get into recv state + # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining) + input_obj = self.recv_forward(0) + input_objs[0].append(input_obj) + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + model_chunk_id = self.get_model_chunk_id(i, forward=True) + + # recv first on first rank to avoid sending or recving at the same time + if self.stage_manager.is_first_stage(): + input_obj = self.recv_forward(model_chunk_id) + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + self.send_forward(model_chunk_id, output_obj) + if not forward_only: + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + else: + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if not forward_only: + output_objs[model_chunk_id].append(output_obj) + self.send_forward(model_chunk_id, output_obj) + if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches: + break + else: + model_chunk_id = self.get_model_chunk_id(i + 1, forward=True) + + input_obj = self.recv_forward(model_chunk_id) + if not forward_only: + input_objs[model_chunk_id].append(input_obj) + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) + last_iteration = (i == (num_microbatches_remaining - 1)) + + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if forward_only: + self.send_forward(model_chunk_id, output_obj) + + if not last_iteration: + input_obj = self.recv_forward(model_chunk_id) + + else: + self.send_forward(model_chunk_id, output_obj) + # Add input_obj and output_obj to end of list. + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + + model_chunk_id = self.get_model_chunk_id(i, forward=False) + output_obj_grad = self.recv_backward(model_chunk_id) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + # backward + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + + if last_iteration: + input_obj = None + else: + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True) + input_obj = self.recv_forward(model_chunk_id) + model_chunk_id = self.get_model_chunk_id(i, forward=False) + self.send_backward(model_chunk_id, input_obj_grad) + + # Run cooldown backward passes. + if not forward_only: + for i in range(num_microbatches_remaining, num_microbatches): + model_chunk_id = self.get_model_chunk_id(i, forward=False) + # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + + output_obj_grad = self.recv_backward(model_chunk_id) + input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + self.send_backward(model_chunk_id, input_obj_grad) + + if outputs is not None: + outputs = merge_batch(outputs) + return {'loss': accum_loss, 'outputs': outputs} diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index ade3cf456fe3..5db1c7f30d7f 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -6,25 +6,47 @@ from torch.nn import Module from torch.utils._pytree import tree_map -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.cuda import get_current_device -from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device +from ._utils import ( + detach, + get_batch_size, + get_micro_batch, + merge_batch, + model_forward, + retain_grad, + to_device, + tree_map_hf, +) from .base import PipelineSchedule class OneForwardOneBackwardSchedule(PipelineSchedule): - def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None: + def __init__(self, + stage_manager: PipelineStageManager, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None) -> None: + """1F1B pipeline schedule. + + Args: + stage_manager (PipelineStageManager): Pipeline stage manager + num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None. + microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. + """ super().__init__(stage_manager) + assert num_microbatches is not None or microbatch_size is not None, \ + "Either num_microbatches or microbatch_size should be provided" self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches + self.microbatch_size = microbatch_size self.batch: Optional[Any] = None self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None - self.microbatch_size: Optional[int] = None + self._use_microbatch_size = num_microbatches is None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -39,9 +61,14 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" - self.microbatch_size = self.batch_size // self.num_microbatches + if not self._use_microbatch_size: + assert self.batch_size % self.num_microbatches == 0, \ + "Batch size should divided by the number of microbatches" + self.microbatch_size = self.batch_size // self.num_microbatches + else: + assert self.batch_size % self.microbatch_size == 0, \ + "Batch size should divided by the microbatch size" + self.num_microbatches = self.batch_size // self.microbatch_size def load_micro_batch(self) -> Any: """Load a micro batch from the current batch. @@ -53,6 +80,62 @@ def load_micro_batch(self) -> Any: self.microbatch_offset += self.microbatch_size return tree_map(partial(to_device, device=get_current_device()), micro_batch) + def recv_forward(self, prev_rank: int = None) -> Any: + """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. + For 1F1B. + + Args: + prev_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input tensor or input tensor list. + """ + if self.stage_manager.is_first_stage(): + input_tensor = None + else: + input_tensor = self.comm.recv_forward(prev_rank) + + return input_tensor + + def recv_backward(self, next_rank: int = None) -> Any: + """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. + For 1F1B. + + Args: + next_rank (int, optional): The rank of the source of the tensor. + + Returns: + Any: The input gradient tensor or gradient tensor list. + """ + if self.stage_manager.is_last_stage(): + output_tensor_grad = None + else: + output_tensor_grad = self.comm.recv_backward(next_rank) + + return output_tensor_grad + + def send_forward(self, output_object: Any, next_rank: int = None) -> None: + """Sends the input tensor to the next stage in pipeline. + For 1F1B. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.stage_manager.is_last_stage(): + self.comm.send_forward(output_object, next_rank) + + def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + """Sends the gradient tensor to the previous stage in pipeline. + For 1F1B. + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor + """ + if not self.stage_manager.is_first_stage(): + self.comm.send_backward(input_object, prev_rank) + def forward_step(self, model: Module, input_obj: Optional[dict], @@ -72,16 +155,16 @@ def forward_step(self, Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). """ micro_batch = self.load_micro_batch() - # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict output_obj = model_forward(model, micro_batch, input_obj) if self.stage_manager.is_last_stage(): + loss = criterion(output_obj, micro_batch) / self.num_microbatches if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: - outputs.append(tree_map(detach, output_obj)) + outputs.append(tree_map_hf(detach, output_obj)) return loss else: return output_obj @@ -102,7 +185,6 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], # Retain the grad on the input_obj. tree_map(retain_grad, input_obj) - # Backward pass. if output_obj_grad is None: optimizer.backward(output_obj) @@ -171,11 +253,11 @@ def forward_backward_step(self, # Run warmup forward passes. for i in range(num_warmup_microbatches): - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - self.comm.send_forward(output_obj) + self.send_forward(output_obj) if not forward_only: input_objs.append(input_obj) @@ -185,7 +267,7 @@ def forward_backward_step(self, # If all microbatches are run in warmup / cooldown phase, then no need to # receive this tensor here. if num_microbatches_remaining > 0: - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() # Run 1F1B in steady state. for i in range(num_microbatches_remaining): @@ -193,15 +275,15 @@ def forward_backward_step(self, output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) if forward_only: - self.comm.send_forward(output_obj) + self.send_forward(output_obj) if not last_iteration: - input_obj = self.comm.recv_forward() + input_obj = self.recv_forward() else: # TODO adjust here - self.comm.send_forward(output_obj) - output_obj_grad = self.comm.recv_backward() + self.send_forward(output_obj) + output_obj_grad = self.recv_backward() # Add input_obj and output_obj to end of list. input_objs.append(input_obj) @@ -216,8 +298,8 @@ def forward_backward_step(self, if last_iteration: input_obj = None else: - input_obj = self.comm.recv_forward() - self.comm.send_backward(input_obj_grad) + input_obj = self.recv_forward() + self.send_backward(input_obj_grad) # Run cooldown backward passes. if not forward_only: @@ -225,10 +307,12 @@ def forward_backward_step(self, input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) - output_obj_grad = self.comm.recv_backward() + output_obj_grad = self.recv_backward() input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) - self.comm.send_backward(input_obj_grad) + self.send_backward(input_obj_grad) if outputs is not None: - outputs = merge_batch(outputs) + if isinstance(model, ModelWrapper): + model = model.unwrap() + outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0)) return {'loss': accum_loss, 'outputs': outputs} diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index fe228e2270dd..6ba7dc629958 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -17,28 +17,24 @@ class PipelineStageManager: Attributes: num_stages (int): Number of stages in the pipeline. stage (int): The current stage. - num_virtual_stages (int): Number of virtual stages in the pipeline. - virtual_stage (int): The current virtual stage. """ - def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: + def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None: self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis - self.num_virtual_stages: Optional[int] = None - self.virtual_stage: Optional[int] = None self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} # init prev and next coord coord = self.pg_mesh.coordinate() - if self.stage > 0: - prev_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] - self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape) - if self.stage < self.num_stages - 1: - next_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] - self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape) + # the prev rank of rank0 is the last rank + prev_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] + self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap') + # the next rank of the last rank is rank0 + next_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] + self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap') # init p2p process groups stages = list(range(self.num_stages)) @@ -48,32 +44,28 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - def is_first_stage(self, virtual: bool = False) -> bool: - """Is the current stage the first stage. + if is_virtual: + # add the process group of the first rank and the last rank + # only used in interleaved pipeline for now + group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) + if self.stage in [stages[0], stages[-1]]: + ranks_in_group = self.pg_mesh.get_ranks_in_group(group) + self.p2p_groups[tuple(ranks_in_group)] = group - Args: - virtual (bool, optional): Whether to consider virtual stages. Defaults to False. + def is_first_stage(self) -> bool: + """Is the current stage the first stage. Returns: bool: Whether the current stage is the first stage. """ - if virtual: - assert self.num_virtual_stages is not None - return self.virtual_stage == 0 return self.stage == 0 - def is_last_stage(self, virtual: bool = False) -> bool: + def is_last_stage(self) -> bool: """Is the current stage the last stage. - Args: - virtual (bool, optional): Whether to consider virtual stages. Defaults to False. - Returns: bool: Whether the current stage is the last stage. """ - if virtual: - assert self.num_virtual_stages is not None - return self.virtual_stage == self.num_virtual_stages - 1 return self.stage == self.num_stages - 1 @property @@ -108,7 +100,6 @@ def get_prev_rank(self) -> int: Returns: int: Rank of the previous stage. """ - assert not self.is_first_stage(), "Cannot get previous rank in the first stage." return self.prev_rank def get_next_rank(self) -> int: @@ -117,39 +108,8 @@ def get_next_rank(self) -> int: Returns: int: Rank of the next stage. """ - assert not self.is_last_stage(), "Cannot get next rank in the last stage." return self.next_rank - def set_num_virtual_stages(self, num_virtual_stages: int) -> None: - """Set the number of virtual stages. - - Args: - num_virtual_stages (int): Number of virtual stages. - """ - self.num_virtual_stages = num_virtual_stages - - def set_virtual_stage(self, virtual_stage: int) -> None: - """Set the virtual stage. - - Args: - virtual_stage (int): Virtual stage. - """ - self.virtual_stage = virtual_stage - - @contextmanager - def switch_virtual_stage(self, virtual_stage: int) -> None: - """A context manager to switch virtual stage. - - Args: - virtual_stage (int): Target virtual stage. - """ - old_stage = self.virtual_stage - try: - self.set_virtual_stage(virtual_stage) - yield - finally: - self.set_virtual_stage(old_stage) - def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: """Get the p2p process group between two ranks. The order of the two ranks does not matter. diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 7dc15f0a0635..2e48a79dc1d7 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -429,12 +429,13 @@ As shown in the figures above, when the sequence length is around 1000 or greate ### Convergence -To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results. -| accuracy | f1 | loss | GPU number | model shard | + +| accuracy | f1 | loss | GPU number | model sharded | | :------: | :-----: | :-----: | :--------: | :---------: | -| 0.82594 | 0.87441 | 0.09913 | 4 | True | -| 0.81884 | 0.87299 | 0.10120 | 2 | True | -| 0.81855 | 0.87124 | 0.10357 | 1 | False | +| 0.84589 | 0.88613 | 0.43414 | 4 | True | +| 0.83594 | 0.88064 | 0.43298 | 1 | False | + Overall, the results demonstrate that using shardformers during model training does not affect the convergence. diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7e97bee01b33..45b305733813 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,3 +1,5 @@ +from typing import Any + import torch import torch.distributed as dist import torch.nn.functional as F @@ -141,6 +143,240 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None +class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + ctx.overlap = overlap + + input_parallel = _gather(input_, dim, process_group) + + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + overlap = ctx.overlap + + if not overlap: + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, + device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished + gather_handle.wait() + + # do reduce-scatter in async way + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = grad_output.t().matmul(input_parallel) + # wait until reduce-scatter finished + reducescatter_handle.wait() + + return output, grad_weight, grad_bias, None, None, None, None + + +class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.dim = dim + ctx.process_group = process_group + + # do reduce-scatter + new_shape = list(input_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + return _gather(grad_output, dim, process_group), None, None + + +class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """ + This class is designed for matmul operation with gather forward and reduce-scatter backward. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + ctx.overlap = overlap + + input_parallel = _gather(input_, dim, process_group) + + output = torch.matmul(input_parallel, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + overlap = ctx.overlap + + if not overlap: + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_parallel.dtype, + device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + else: + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + + # do all gather in is async way + gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True) + # calculate gradient and prepare data asynchronously with all-gather + # calculate + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + grad_bias = grad_output.sum(dim=0) if use_bias else None + # prepare data + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous() + # wait until all-gather finished + gather_handle.wait() + + # do reduce-scatter in async way + reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + input_parallel = torch.cat(tensor_list, dim=dim).contiguous() + # calculate gradient + if len(input_parallel.shape) > 2: + input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + grad_weight = input_parallel.t().matmul(grad_output) + # wait until reduce-scatter finished + reducescatter_handle.wait() + + return output, grad_weight, grad_bias, None, None, None, None + + class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. @@ -200,6 +436,26 @@ def backward(ctx, grad_output): return _reduce(grad_output, ctx.process_group), None +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + def _reduce(input_, process_group): # skip if only one rank involved if dist.get_world_size(process_group) == 1: @@ -235,9 +491,8 @@ def _gather(input_, dim=-1, process_group=None): return input_ # all gather - rank = dist.get_rank(process_group) + input_ = input_.contiguous() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ torch.distributed.all_gather(tensor_list, input_, group=process_group) # concat @@ -246,24 +501,27 @@ def _gather(input_, dim=-1, process_group=None): return output -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. +def _reduce_scatter(input_, dim=1, process_group=None): + """ Do reduce-scatter operation. Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + dim (int): The dimension to perform reduce-scatter. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. """ + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ - @staticmethod - def forward(ctx, input_, dim, process_group): - ctx.process_group = process_group - ctx.dim = dim - return _gather(input_, dim, process_group) + # reduce-scatter + new_shape = list(input_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + new_shape[dim] = new_shape[dim] // world_size + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_, group=process_group) - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.dim, ctx.process_group), None, None + return output def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): @@ -274,6 +532,22 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, + overlap): + return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_reduce_scatter, dim, overlap) + + +def linear_reducescatter_forward_gather_backward(input_, process_group, dim): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + +def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, + overlap): + return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_reduce_scatter, dim, overlap) + + def gather_forward_split_backward(input_, dim, process_group): return _GatherForwardSplitBackward.apply(input_, dim, process_group) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d59b68ce4480..111d51b3f8d8 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -24,6 +24,8 @@ from ._operation import ( gather_forward_split_backward, + linear_gather_forward_reducescatter_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, split_forward_gather_backward, @@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule): gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (`typing.Callable`): @@ -69,6 +73,9 @@ def __init__(self, device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = False, + seq_parallel: bool = False, + seq_parallel_dim: int = 1, + overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, @@ -80,6 +87,9 @@ def __init__(self, self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim + self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group @@ -180,7 +190,12 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + if self.seq_parallel: + output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, + self.seq_parallel_dim, self.overlap) + else: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if self.gather_output: # All-gather across the partitions. @@ -203,6 +218,8 @@ class Linear1D_Row(ParallelModule): bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): @@ -221,6 +238,8 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, + seq_parallel_dim: int = 1, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -238,6 +257,8 @@ def __init__(self, self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel + self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -373,7 +394,11 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = F.linear(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, + self.seq_parallel_dim) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index bda147b121ab..4f391920e29b 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module +from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, @@ -56,13 +57,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - destination[prefix + name] = to_global(param_) - elif is_customized_distributed_tensor(param_): - destination[prefix + name] = to_global_for_customized_distributed_tensor(param_) - else: - destination[prefix + name] = param_ + destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars) for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index df942d43ee2d..5ce77805f9b8 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,7 +25,9 @@ from ._operation import ( gather_forward_split_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, + matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, reduce_backward, reduce_forward, @@ -150,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): device (`torch.device`): The device of parameters, defaults to None. n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False @@ -173,6 +176,8 @@ def __init__(self, process_group: ProcessGroup = None, async_communication: bool = False, gather_output: bool = False, + seq_parallel: bool = False, + overlap: bool = False, skip_bias_add: bool = False, n_fused: int = 3, weight: Optional[Parameter] = None, @@ -185,6 +190,8 @@ def __init__(self, self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel + self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device self.n_fused = n_fused @@ -296,15 +303,19 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - input_parallel = reduce_backward(input_, self.process_group) - # input_parallel = input_ # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, - self.async_communication) + if self.seq_parallel: + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, 1, self.overlap) + else: + # Set up backprop all-reduce. + input_parallel = reduce_backward(input_, self.process_group) + output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, + self.async_communication) if self.gather_output: # All-gather across the partitions. @@ -329,6 +340,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. @@ -346,6 +358,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -363,6 +376,7 @@ def __init__(self, self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -499,7 +513,10 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = torch.matmul(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 5bd1c531cc68..30855a622adb 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -29,6 +29,8 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward class BertPipelineForwards: @@ -56,6 +58,7 @@ def bert_model_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): # TODO(jianghai): add explaination of the output here. r""" @@ -177,6 +180,17 @@ def bert_model_forward( start_idx, end_idx = stage_index[0], stage_index[1] # layer_outputs layer_outputs = hidden_states if hidden_states is not None else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) + for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: encoder_attention_mask = encoder_extended_attention_mask @@ -223,11 +237,17 @@ def custom_forward(*inputs): all_cross_attentions = all_cross_attentions + \ (layer_outputs[2],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config is not None and shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # end of a stage loop - sequence_output = layer_outputs[0] if layer_outputs is not None else None + sequence_output = hidden_states if hidden_states is not None else None if stage_manager.is_last_stage(): pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -268,6 +288,7 @@ def bert_for_pretraining_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) @@ -294,6 +315,7 @@ def bert_for_pretraining_forward( stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -350,6 +372,7 @@ def bert_lm_head_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -404,7 +427,8 @@ def bert_lm_head_model_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -457,6 +481,7 @@ def bert_for_masked_lm_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -491,6 +516,7 @@ def bert_for_masked_lm_forward( hidden_states=hidden_states, stage_manager=stage_manager, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): @@ -532,6 +558,7 @@ def bert_for_next_sentence_prediction_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **kwargs, ): # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: @@ -594,7 +621,8 @@ def bert_for_next_sentence_prediction_forward( return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -636,6 +664,7 @@ def bert_for_sequence_classification_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -666,7 +695,8 @@ def bert_for_sequence_classification_forward( return_dict=return_dict, hidden_states=hidden_states, stage_manager=stage_manager, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -726,6 +756,7 @@ def bert_for_token_classification_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -742,21 +773,20 @@ def bert_for_token_classification_forward( logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -799,6 +829,7 @@ def bert_for_multiple_choice_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -843,6 +874,7 @@ def bert_for_multiple_choice_forward( hidden_states=hidden_states, stage_manager=stage_manager, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -886,6 +918,7 @@ def bert_for_question_answering_forward( hidden_states: Optional[torch.Tensor] = None, stage_manager: Optional[PipelineStageManager] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): # NOTE: the arg start_position and end_position are used only for the last stage r""" @@ -909,21 +942,20 @@ def bert_for_question_answering_forward( logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward( - self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - ) + outputs = BertPipelineForwards.bert_model_forward(self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -1101,3 +1133,153 @@ def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.T return hidden_states return forward + + +def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + embedding_output = split_forward_gather_backward(embedding_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + sequence_output = gather_forward_split_backward(sequence_output, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 12276635ecfa..66f24dc6088b 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -23,6 +23,10 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -111,6 +115,7 @@ def bloom_model_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']: @@ -205,6 +210,13 @@ def bloom_model_forward( past_key_values_length=past_key_values_length, ) + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + start_idx, end_idx = stage_index[0], stage_index[1] for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx): @@ -248,6 +260,12 @@ def custom_forward(*inputs): all_self_attentions = all_self_attentions + \ (outputs[2 if use_cache else 1],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if stage_manager.is_last_stage(): # Add last hidden state hidden_states = self.ln_f(hidden_states) @@ -287,6 +305,7 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -327,7 +346,8 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) past_key_values = None all_hidden_states = None all_self_attentions = None @@ -380,6 +400,7 @@ def bloom_for_sequence_classification_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments, ): r""" @@ -424,6 +445,7 @@ def bloom_for_sequence_classification_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -503,6 +525,7 @@ def bloom_for_token_classification_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, **deprecated_arguments, ): r""" @@ -547,6 +570,7 @@ def bloom_for_token_classification_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -597,6 +621,7 @@ def bloom_for_question_answering_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -632,6 +657,7 @@ def bloom_for_question_answering_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None all_hidden_states = None @@ -700,8 +726,7 @@ def forward( fused_qkv = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, tgt_len, _ = hidden_states.size() - assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + batch_size, tgt_len, _ = query_layer.size() _, kv_length, _, _ = key_layer.size() @@ -896,3 +921,156 @@ def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor: return self.bloom_gelu_forward(x, bias) return forward + + +def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): + + from transformers import BloomModel + + def forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm2.py similarity index 69% rename from colossalai/shardformer/modeling/chatglm.py rename to colossalai/shardformer/modeling/chatglm2.py index 409e2e1f5497..16dcf87c8cfc 100644 --- a/colossalai/shardformer/modeling/chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -9,6 +9,8 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -146,6 +148,7 @@ def chatglm_model_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) output_hidden_states = (output_hidden_states @@ -198,6 +201,11 @@ def chatglm_model_forward( all_self_attentions = None all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] + + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -214,6 +222,11 @@ def chatglm_model_forward( hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) + + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): @@ -233,23 +246,22 @@ def chatglm_model_forward( return {'hidden_states': hidden_states} @staticmethod - def chatglm_for_conditional_generation_forward( - self: ChatGLMForConditionalGeneration, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): + def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None): logger = logging.get_logger(__name__) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) @@ -266,6 +278,7 @@ def chatglm_for_conditional_generation_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -296,3 +309,91 @@ def chatglm_for_conditional_generation_forward( ) else: return transformer_outputs + + +def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] + inputs_embeds = split_forward_gather_backward(inputs_embeds, + dim=0, + process_group=shard_config.tensor_parallel_process_group) + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + hidden_states = gather_forward_split_backward(hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 47835d5d5468..8ed367b25349 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -21,6 +21,8 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig class GPT2PipelineForwards: @@ -47,7 +49,8 @@ def gpt2_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. @@ -145,7 +148,7 @@ def gpt2_model_forward( if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds - hidden_states = self.drop(hidden_states) + hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) @@ -159,6 +162,13 @@ def gpt2_model_forward( all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] for i in range(start_idx, end_idx): @@ -212,6 +222,12 @@ def custom_forward(*inputs): if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) @@ -257,7 +273,8 @@ def gpt2_lmhead_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set @@ -285,7 +302,8 @@ def gpt2_lmhead_model_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -335,7 +353,8 @@ def gpt2_double_heads_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: r""" mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - @@ -367,7 +386,8 @@ def gpt2_double_heads_model_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -421,7 +441,8 @@ def gpt2_for_question_answering_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -449,7 +470,8 @@ def gpt2_for_question_answering_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -508,7 +530,8 @@ def gpt2_for_token_classification_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, TokenClassifierOutput]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -534,7 +557,8 @@ def gpt2_for_token_classification_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -578,7 +602,8 @@ def gpt2_for_sequence_classification_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -613,7 +638,8 @@ def gpt2_for_sequence_classification_forward( return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, - stage_index=stage_index) + stage_index=stage_index, + shard_config=shard_config) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): @@ -696,7 +722,6 @@ def forward( output_attentions: Optional[bool] = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: _, tgt_len, _ = hidden_states.size() - assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): @@ -753,3 +778,210 @@ def forward( return outputs return forward + + +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 0a16c6f788da..62f8f7b4763e 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -1,7 +1,26 @@ -from typing import Optional, Tuple +import logging +import random +from typing import Dict, List, Optional, Set, Tuple, Union import torch from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, +) +from transformers.models.whisper.modeling_whisper import ( + WhisperEncoder, + WhisperForAudioClassification, + WhisperForConditionalGeneration, + WhisperModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager def get_whisper_flash_attention_forward(): @@ -247,3 +266,697 @@ def forward( return outputs return forward + + +class WhisperPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + ''' + + @staticmethod + def whisper_encoder_forward( + self: WhisperEncoder, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor`)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + logger = logging.get_logger(__name__) + + stage = stage_manager.stage + at_first_stage = (stage == 0) + at_last_stage = (stage == decoder_starting_stage - 1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Process inputs if at the first stage of encoder. + if at_first_stage: + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + else: + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + encoder_layer = self.layers[idx] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + None, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + + else: + return {'hidden_states': hidden_states, 'head_mask': head_mask} + + @staticmethod + def whisper_decoder_forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + logger = logging.get_logger(__name__) + stage = stage_manager.stage + at_first_stage = (stage == decoder_starting_stage) + at_last_stage = (stage == stage_manager.num_stages - 1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}.") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if at_first_stage: + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if input_ids is not None: + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + else: + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, + past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + + else: + + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + input_shape = hidden_states.size()[:-1] + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states, + past_key_values_length) + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + decoder_layer = self.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + None, # encoder attention mask + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, # past_key_value + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] + if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + else: + return { + 'head_mask': head_mask, + 'cross_attn_head_mask': cross_attn_head_mask, + 'hidden_states': hidden_states, + } + + @staticmethod + def whisper_model_forward( + self: WhisperModel, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Returns: + + Example: + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, WhisperModel + >>> from datasets import load_dataset + + >>> model = WhisperModel.from_pretrained("openai/whisper-base") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 512] + ```""" + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + in_decoder = stage_manager.stage >= decoder_starting_stage + if not in_decoder: + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + if stage_manager.stage == decoder_starting_stage - 1: + # last stage of encoder + return {'encoder_hidden_states': encoder_outputs[0]} + else: + return encoder_outputs + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + at_last_decoder_stage = stage_manager.is_last_stage() + at_first_decoder_stage = stage_manager.stage == decoder_starting_stage + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") + + if not at_first_decoder_stage and hidden_states is None: + raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + # Directly return outputs of overloaded Whisper forward if not at last stage. + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + return decoder_outputs + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) + + @staticmethod + def whisper_for_conditional_generation_forward( + self: WhisperForConditionalGeneration, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, + self.config.decoder_start_token_id) + in_decoder = stage_manager.stage >= decoder_starting_stage + at_last_decoder_stage = stage_manager.is_last_stage() + outputs = WhisperPipelineForwards.whisper_model_forward(self.model, + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + if not in_decoder: + return outputs + + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + outputs['encoder_hidden_states'] = encoder_hidden_states + return outputs + + lm_logits = self.proj_out(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @staticmethod + def whisper_for_audio_classification_forward( + self: WhisperForAudioClassification, + input_features: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward. + Please refer to original code of transformers for more details. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # audio_classification only holds encoder + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + if not stage_manager.is_last_stage(): + return encoder_outputs + + if self.config.use_weighted_layer_sum: + hidden_states = torch.stack(encoder_outputs, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = hidden_states.mean(dim=1) + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + encoder_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index eec339c02872..2fe49f0d5afe 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -125,9 +125,9 @@ class PolicyLocation: # ChatGLM "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": - PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"), + PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"), "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": - PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"), + PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"), } diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 69493bfb6007..961c6a5259fe 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -11,17 +11,12 @@ from colossalai.pipeline.stage_manager import PipelineStageManager +from ..layer.parallel_module import ParallelModule from ..shard.shard_config import ShardConfig __all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] -class ParallelModule(): - - def __init__(self): - pass - - @dataclass class SubModuleReplacementDescription: r""" diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index ace9ada3904f..a141b7bd8fdf 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -10,6 +10,7 @@ from .._utils import getattr_, setattr_ from ..modeling.bert import ( BertPipelineForwards, + bert_sequence_parallel_forward_fn, get_bert_flash_attention_forward, get_jit_fused_bert_output_forward, get_jit_fused_bert_self_output_forward, @@ -47,13 +48,15 @@ def module_policy(self): from transformers.models.bert.modeling_bert import ( BertEmbeddings, BertLayer, + BertModel, BertOutput, BertSelfAttention, BertSelfOutput, ) policy = {} - + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ "attention.self.all_head_size": @@ -69,14 +72,26 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attention.self.query", target_module=col_nn.Linear1D_Col, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="attention.self.key", target_module=col_nn.Linear1D_Col, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="attention.self.value", target_module=col_nn.Linear1D_Col, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="attention.self.dropout", @@ -85,6 +100,7 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -93,10 +109,15 @@ def module_policy(self): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, ), SubModuleReplacementDescription( suffix="output.dense", target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="output.dropout", @@ -115,6 +136,12 @@ def module_policy(self): ) ]) + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=BertModel) + # optimization configuration if self.shard_config.enable_fused_normalization: # Handle bert layer @@ -141,20 +168,26 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_bert_flash_attention_forward(), - }) + }, + policy=policy, + target_key=BertSelfAttention) # use jit operator if self.shard_config.enable_jit_fused: - policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bert_self_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BertOutput] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BertSelfOutput) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bert_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=BertOutput) return policy @@ -205,7 +238,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 50356302e93e..2e5388ab0490 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -285,34 +285,30 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[Blip2Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_blip2_flash_attention_forward(), - }) + }, + policy=policy, + target_key=Blip2Attention) # use jit operator if self.shard_config.enable_jit_fused: - policy[Blip2QFormerSelfOutput] = ModulePolicyDescription( - method_replacement={ - 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerSelfOutput) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_blip2_QFormer_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=Blip2QFormerOutput) return policy def postprocess(self): - binding_map = { - 'language_model.model.decoder.embed_tokens': 'language_model.lm_head', - } - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - return self.model diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index b35764db3870..7c418d02bcb6 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -12,6 +12,7 @@ BloomPipelineForwards, build_bloom_alibi_tensor_fn, get_bloom_flash_attention_forward, + get_bloom_sequence_parallel_forward_fn, get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_gelu_forward, get_jit_fused_bloom_mlp_forward, @@ -43,6 +44,8 @@ def module_policy(self): policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, @@ -53,11 +56,14 @@ def module_policy(self): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - ), + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'overlap': overlap + }), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - ), + kwargs={'seq_parallel': use_sequence_parallel}), SubModuleReplacementDescription( suffix="self_attention.attention_dropout", target_module=col_nn.DropoutForParallelInput, @@ -65,11 +71,14 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, - ), + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'overlap': overlap + }), SubModuleReplacementDescription( suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row, - ), + kwargs={'seq_parallel': use_sequence_parallel}), ]) policy[BloomModel] = ModulePolicyDescription( @@ -116,26 +125,40 @@ def module_policy(self): policy=policy, target_key=BloomBlock) + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=BloomModel) + if self.shard_config.enable_flash_attention: - policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_bloom_flash_attention_forward(), - 'dropout_add': get_dropout_add_func() - }) + 'dropout_add': get_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention) # enable jit fused operator if self.shard_config.enable_jit_fused: - policy[BloomAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BloomMLP] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BloomAttention) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_mlp_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[BloomGelu] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=BloomMLP) + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_bloom_gelu_forward(), 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(), - }) + }, + policy=policy, + target_key=BloomGelu) return policy @@ -154,7 +177,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/shardformer/policies/chatglm.py b/colossalai/shardformer/policies/chatglm2.py similarity index 67% rename from colossalai/shardformer/policies/chatglm.py rename to colossalai/shardformer/policies/chatglm2.py index e6b458936637..44898847056a 100644 --- a/colossalai/shardformer/policies/chatglm.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -7,7 +7,7 @@ import colossalai.shardformer.layer as col_nn from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards +from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -15,7 +15,11 @@ GLMBlock, ) -from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward +from ..modeling.chatglm2 import ( + get_chatglm_sequence_parallel_forward_fn, + get_flash_core_attention_forward, + get_jit_fused_glm_block_forward, +) from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -37,6 +41,11 @@ def preprocess(self): new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) + if self.pipeline_stage_manager is not None: + # the batch_size_dim is bounded to Model + bsz_dim = 1 + setattr(self.model, 'batch_size_dim', bsz_dim) + return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -45,8 +54,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: - policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, sub_module_replacement=[ SubModuleReplacementDescription( @@ -55,36 +65,43 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) ]) - policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.projection_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads) // - self.shard_config.tensor_parallel_size, - "self_attention.qkv_hidden_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // - self.shard_config.tensor_parallel_size, - "self_attention.core_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.core_attention.hidden_size_per_partition": - self.model.config.kv_channels * self.model.config.num_attention_heads // - self.shard_config.tensor_parallel_size, - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="self_attention.core_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[GLMBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.projection_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads) // + self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": + (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // + self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": + self.model.config.kv_channels * self.model.config.num_attention_heads // + self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'seq_parallel_dim': 0, + 'overlap': overlap + }), + SubModuleReplacementDescription(suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={ + 'seq_parallel': use_sequence_parallel, + 'seq_parallel_dim': 0 + }), + SubModuleReplacementDescription( + suffix="self_attention.core_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) + # optimization configuration if self.shard_config.enable_fused_normalization: if not self.model.config.rmsnorm: @@ -124,16 +141,27 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # use flash attention if self.shard_config.enable_flash_attention: - policy[CoreAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_flash_core_attention_forward(), - }) + }, + policy=policy, + target_key=CoreAttention) + + # use sequence parallel + if use_sequence_parallel: + self.append_or_create_method_replacement( + description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, + policy=policy, + target_key=ChatGLMModel) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[GLMBlock] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_glm_block_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=GLMBlock) return policy @@ -178,7 +206,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 20e5fa372c8f..5093fd469af8 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -6,7 +6,7 @@ import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ -from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward +from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -37,7 +37,8 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model policy = {} - + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( @@ -49,47 +50,55 @@ def module_policy(self): target_module=col_nn.DropoutForParallelInput, ), ]) - policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ - "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + + policy[GPT2Block] = ModulePolicyDescription( + attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, + ), + SubModuleReplacementDescription(suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + "seq_parallel": use_sequence_parallel, + "overlap": overlap + }, + ), + SubModuleReplacementDescription(suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + kwargs={ + "seq_parallel": use_sequence_parallel, + }), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -117,9 +126,15 @@ def module_policy(self): target_key=GPT2Block) if self.shard_config.enable_flash_attention: - policy[GPT2Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_gpt2_flash_attention_forward(), - }) + }, + policy=policy, + target_key=GPT2Attention) + + if self.shard_config.enable_sequence_parallelism: + policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} + return policy def postprocess(self): @@ -160,7 +175,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 5ee95f3be8fa..c417e5d017bd 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -35,6 +36,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[LlamaDecoderLayer] = ModulePolicyDescription( attribute_replacement={ @@ -105,9 +110,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=LlamaModel) if self.shard_config.enable_flash_attention: - policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_llama_flash_attention_forward(), - }) + }, + policy=policy, + target_key=LlamaAttention) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index ba6036bd0658..abe491bfaace 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -39,6 +40,9 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ @@ -100,16 +104,20 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[OPTAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_opt_flash_attention_forward(), - }) + }, + policy=policy, + target_key=OPTAttention) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_opt_decoder_layer_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=OPTDecoderLayer) return policy diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index b1eba0432b49..9753d5a737b9 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -199,12 +199,16 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[SamAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_sam_flash_attention_forward(), - }) - policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=SamAttention) + self.append_or_create_method_replacement(description={ 'forward': get_sam_vision_flash_attention_forward(), - }) + }, + policy=policy, + target_key=SamVisionAttention) return policy diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 2ef52c214c6b..92cbd3f72b83 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,6 +1,8 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Optional, Tuple +import numpy as np from torch import Tensor, nn from colossalai.shardformer.layer import ( @@ -58,6 +60,10 @@ def module_policy(self): policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( @@ -178,24 +184,33 @@ def module_policy(self): # use flash attention if self.shard_config.enable_flash_attention: - policy[T5Attention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_t5_flash_attention_forward(), - }) + }, + policy=policy, + target_key=T5Attention) # use jit operator if self.shard_config.enable_jit_fused: - policy[T5LayerFF] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_T5_layer_ff_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=T5LayerFF) + self.append_or_create_method_replacement(description={ 'forward': get_T5_layer_self_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={ + }, + policy=policy, + target_key=T5LayerSelfAttention) + self.append_or_create_method_replacement(description={ 'forward': get_T5_layer_cross_attention_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=T5LayerCrossAttention) + return policy def postprocess(self): @@ -228,13 +243,7 @@ def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int, def objective(num_encoder_stages): return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) - num_encoder_stages = 0 - optimal_diff = 2**31 - 1 - for i in range(1, num_stages): - attempt = objective(i) - if attempt < optimal_diff: - num_encoder_stages = i - optimal_diff = attempt + num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 617720ee7950..b4fb8692e684 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, Dict, List, Union import torch.nn as nn @@ -32,6 +33,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -90,16 +95,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # use flash attention if self.shard_config.enable_flash_attention: - policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_vit_flash_self_attention_forward(), - }) + }, + policy=policy, + target_key=ViTSelfAttention) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[ViTOutput] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_vit_output_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=ViTOutput) return policy def new_model_class(self): diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 2ac7a49fd27b..5d496f08e1db 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -1,10 +1,17 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Tuple + +import numpy as np import torch.nn as nn +from torch import Tensor import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.whisper import ( + WhisperPipelineForwards, get_jit_fused_whisper_decoder_layer_forward, get_jit_fused_whisper_encoder_layer_forward, get_whisper_flash_attention_forward, @@ -12,7 +19,8 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification' + 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', + 'WhisperForAudioClassificationPolicy' ] @@ -26,7 +34,6 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - # TODO: vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size if vocab_size % world_size != 0: @@ -45,6 +52,11 @@ def module_policy(self): policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn( + "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ "self_attn.embed_dim": @@ -191,20 +203,26 @@ def module_policy(self): # enable flash attention if self.shard_config.enable_flash_attention: - policy[WhisperAttention] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_whisper_flash_attention_forward(), - }) + }, + policy=policy, + target_key=WhisperAttention) # use jit fused operator if self.shard_config.enable_jit_fused: - policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={ - 'forward': get_jit_fused_whisper_encoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }) - policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={ + self.append_or_create_method_replacement(description={ 'forward': get_jit_fused_whisper_decoder_layer_forward(), 'dropout_add': get_jit_fused_dropout_add_func(), - }) + }, + policy=policy, + target_key=WhisperDecoderLayer) + self.append_or_create_method_replacement(description={ + 'forward': get_jit_fused_whisper_encoder_layer_forward(), + 'dropout_add': get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperEncoderLayer) return policy @@ -223,6 +241,146 @@ def add_lm_head_policy(self, base_policy): def postprocess(self): return self.model + @staticmethod + def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int, + num_stages: int) -> Tuple[List[int], int]: + """ + Distribute whisper layers into stages when pipeline parallel is used. + Return the layer distribution as a list and the starting stage of decoder. + If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. + """ + + # number of encoder layers must be a positive integer + if num_encoder_layers <= 0: + raise ValueError("The number of encoder layers for whisper must be a positive integer.") + + # number of layers should be large enough to fill in every stage + if num_encoder_layers + num_decoder_layers < num_stages: + raise ValueError("The total number of layers can't be smaller than number of stages.") + + # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist + if num_decoder_layers == 0: + return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + + # the number of stages distributed between encoder and decoder is optmized in this way: + # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) + # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 + def objective(num_encoder_stages): + return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) + + num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 + num_decoder_stages = num_stages - num_encoder_stages + + encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + return encoder_distribution + decoder_distribution, num_encoder_stages + + @staticmethod + def get_whisper_stage_index(layers_per_stage: List[int], stage: int, + decoder_starting_stage: int) -> Tuple[bool, int, int]: + """ + Input the distribution of layers among stages, the current stage and the first stage of decoder. + Return the starting/ending idx of layers in encoder/decoder + """ + if stage < decoder_starting_stage: + return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + else: + return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + + def get_held_layers(self) -> List[nn.Module]: + + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == 'WhisperModel': + model = self.model + elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + # whisper for audio classification holds encoder only + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + held_layers = [] + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, + decoder_starting_stage) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in whisper's encoder + if stage_manager.is_first_stage(): + held_layers.append(encoder.embed_positions) + held_layers.append(encoder.conv1) + held_layers.append(encoder.conv2) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.layer_norm) + held_layers.extend(encoder.layers[start_idx:end_idx]) + else: + # current stage is in whisper's decoder + # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, + # the case encoder and decoder put in same stage should be add in the future. + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.embed_positions) + if stage_manager.is_last_stage(): + held_layers.append(decoder.layer_norm) + held_layers.extend(decoder.layers[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == 'WhisperModel': + model = self.model + elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, + decoder_starting_stage) + + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + # WhisperModel class WhisperModelPolicy(WhisperPolicy): @@ -230,6 +388,24 @@ class WhisperModelPolicy(WhisperPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers import WhisperModel + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperModel, + new_forward=WhisperPipelineForwards.whisper_model_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + "no shared params in whisper model" + return [] + # WhisperForConditionalGeneration class WhisperForConditionalGenerationPolicy(WhisperPolicy): @@ -238,20 +414,82 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + from transformers import WhisperForConditionalGeneration + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration, + new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward, + policy=policy) + return policy def postprocess(self): - binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) return self.model + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.proj_out) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + module = self.model + model = module.model + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + stage_manager = self.pipeline_stage_manager + if stage_manager is not None and stage_manager.num_stages > 1: + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers, + stage_manager.num_stages) + shared_params = [] + shared_embedding = {} + if id(module.proj_out) == id(model.decoder.embed_tokens): + shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens + shared_embedding[stage_manager.num_stages - 1] = module.proj_out + if len(shared_embedding) > 0: + shared_params.append(shared_embedding) + return shared_params + return [] + # WhisperForAudioClassification class WhisperForAudioClassificationPolicy(WhisperPolicy): def __init__(self) -> None: super().__init__() + + def preprocess(self): + return self.model + + def module_policy(self): + from transformers import WhisperForAudioClassification + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperForAudioClassification, + new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.projector) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + return [] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0c28f115d018..c5c3d185e950 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -20,6 +20,8 @@ class ShardConfig: enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. enable_all_optimization (bool): Whether to turn on all optimization, default is False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None @@ -28,6 +30,8 @@ class ShardConfig: enable_all_optimization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False + enable_sequence_parallelism: bool = False + enable_sequence_overlap: bool = False # pipeline_parallel_size: int # data_parallel_size: int @@ -40,6 +44,11 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): + if not self.enable_tensor_parallelism and self.enable_sequence_parallelism: + raise ValueError( + "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True") + if not self.enable_sequence_parallelism and self.enable_sequence_overlap: + raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True") if not self.enable_tensor_parallelism: self._tensor_parallel_size = 1 else: @@ -57,3 +66,5 @@ def _turn_on_all_optimization(self): self.enable_fused_normalization = True self.enable_flash_attention = True self.enable_jit_fused = True + self.enable_sequence_parallelism = True + self.enable_sequence_overlap = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 0ed745a1fc4a..9ed384266a80 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -92,22 +92,21 @@ def _recursive_replace_layer( param_replacement (List[Callable]): The function list to get parameter shard information in policy method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy + include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None """ - # released layers are not shardable - can_replace_param_or_layer = include is None or module in include if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ (module.__class__ == origin_cls): if attr_replacement is not None: self._replace_attr(module, attr_replacement) - if param_replacement is not None and can_replace_param_or_layer: + if param_replacement is not None and (include is None or module in include): self._replace_param(module, param_replacement) if method_replacement is not None: self._replace_method(module, method_replacement) - if sub_module_replacement is not None and can_replace_param_or_layer: - self._replace_sub_module(module, sub_module_replacement) + if sub_module_replacement is not None: + self._replace_sub_module(module, sub_module_replacement, include) for name, child in module.named_children(): self._recursive_replace_layer(child, @@ -154,18 +153,17 @@ def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Calla bound_method = MethodType(new_method, module) setattr(module, method_name, bound_method) - def _replace_sub_module( - self, - org_layer: nn.Module, - sub_module_replacement: List[SubModuleReplacementDescription], - ) -> None: + def _replace_sub_module(self, + org_layer: nn.Module, + sub_module_replacement: List[SubModuleReplacementDescription], + include: Optional[Set[nn.Module]] = None) -> None: r""" Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Args: org_layer (torch.nn.Module): The origin layer object to shard sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list - + include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None """ for description in sub_module_replacement: suffix = description.suffix @@ -174,9 +172,12 @@ def _replace_sub_module( assert target_module is not None, 'target_module should not be None' - # TODO: support different parallel mode native_sub_module = getattr_(org_layer, suffix, ignore=True) + # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled. + if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include): + continue + assert not isinstance(native_sub_module, target_module), \ f"The module with suffix {suffix} has been replaced, please check the policy" diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 0cd90459b76a..741a977d1ea0 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -10,8 +10,9 @@ from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder from colossalai.interface import ModelWrapper + from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger from colossalai.nn.parallel.data_parallel import _cast_float, free_storage @@ -733,7 +734,7 @@ def state_dict_shard(self, Yields: Iterator[OrderedDict]: A generator of state dict shard """ - sharder = _StateDictSharder(max_shard_size) + sharder = StateDictSharder(max_shard_size) # get the mapping between copies and fp16 parameters fp16_to_fp32 = dict() @@ -755,7 +756,7 @@ def state_dict_shard(self, gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) gathered_param = gathered_param_buffer.pop(fp32_param) - block, block_size = sharder.append(prefix + name, gathered_param) + block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: yield block, block_size @@ -766,7 +767,7 @@ def state_dict_shard(self, for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block, block_size = sharder.append(prefix + name, buffer) + block, block_size = sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size # save extra states @@ -774,32 +775,10 @@ def state_dict_shard(self, if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = self.get_extra_state() - block, block_size = sharder.append(extra_state_key, extra_state) + block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size yield sharder.current_block, sharder.current_block_size -class _StateDictSharder: - - def __init__(self, max_shard_size: int) -> None: - self.max_shard_size = max_shard_size - self.current_block = OrderedDict() - self.current_block_size = 0 - - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: - tensor_size = calculate_tensor_size(tensor) - ret_block = None - ret_block_size = 0 - - # before we return the current block and create a new block, - # we need to ensure that the current block is not empty - if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: - ret_block = self.current_block - ret_block_size = self.current_block_size - self.current_block = OrderedDict() - self.current_block_size = 0 - self.current_block[name] = tensor - self.current_block_size += tensor_size - return ret_block, ret_block_size diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 175b97647e16..0c593deff225 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -10,7 +10,7 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam @@ -692,49 +692,17 @@ def state_shard(self, Iterator[OrderedDict]: A generator of state dict shard of optimizer states. """ - current_block = {} - current_block_size = 0 - + sharder = StateDictSharder(max_shard_size) for param_id in self.id_to_real_params.keys(): dist.barrier() state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) - ret_block = None - ret_block_size = 0 - - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue - - # If the states are stored as DTensors, mark isDTensor as true. - if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - - if not isDTensor: - - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - - current_block[param_id] = state - current_block_size += state_size - - if ret_block != None: - yield ret_block, ret_block_size + block, block_size = sharder.append_optim_state(param_id, state) + if block is not None: + yield block, block_size - yield current_block, current_block_size + yield sharder.current_block, sharder.current_block_size def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: raise NotImplementedError('Gemini does not support clip_grad_by_value') diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index d9d6298d745a..0bdd6a3e2370 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -338,6 +338,24 @@ def backward(self, loss, retain_graph=False): self.zero_grad() + def backward_by_grad(self, tensor, grad): + assert not(self._partition_grads and not self.require_grad_sync), \ + "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + + if self.mixed_precision_mixin is not None: + grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) + torch.autograd.backward(tensor, grad) + + if not self.require_grad_sync: + return + self._reduce_grad(self._partition_grads) + + # clear reduced grads + if self._overlap_communication: + torch.cuda.synchronize() + + self.zero_grad() + def zero_grad(self, set_to_none=True): """ Set parameter gradients to zero. If set_to_none = True, gradient @@ -363,7 +381,6 @@ def zero_grad(self, set_to_none=True): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' - if not self.require_grad_sync: return diff --git a/examples/language/bert/README.md b/examples/language/bert/README.md index da38e8375bf0..6601edb7960e 100644 --- a/examples/language/bert/README.md +++ b/examples/language/bert/README.md @@ -7,13 +7,15 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be bash test_ci.sh ``` -### Results on 2-GPU +### Bert-Finetune Results + +| Plugin | Accuracy | F1-score | GPU number | +| -------------- | -------- | -------- | -------- | +| torch_ddp | 84.4% | 88.6% | 2 | +| torch_ddp_fp16 | 84.7% | 88.8% | 2 | +| gemini | 84.0% | 88.4% | 2 | +| hybrid_parallel | 84.5% | 88.6% | 4 | -| Plugin | Accuracy | F1-score | -| -------------- | -------- | -------- | -| torch_ddp | 84.4% | 88.6% | -| torch_ddp_fp16 | 84.7% | 88.8% | -| gemini | 84.0% | 88.4% | ## Benchmark ``` diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 59f10a77c22d..c4d541c978a8 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -1,12 +1,14 @@ import argparse -from typing import List, Union +from contextlib import nullcontext +from typing import Callable, List, Union import evaluate import torch import torch.distributed as dist import torch.nn as nn from data import GLUEDataBuilder -from torch.optim import Optimizer +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from tqdm import tqdm from transformers import ( @@ -18,8 +20,9 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -32,14 +35,26 @@ WEIGHT_DECAY = 0.01 WARMUP_FRACTION = 0.1 +output_transform_fn = lambda x: x +criterion = lambda x: x.loss + def move_to_cuda(batch): return {k: v.cuda() for k, v in batch.items()} @torch.no_grad() -def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, - task_name: str, eval_splits: List[str], coordinator: DistCoordinator): +def evaluate_model( + model: nn.Module, + optimizer, + criterion, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + booster: Booster, + coordinator: DistCoordinator, +): metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) model.eval() @@ -47,23 +62,66 @@ def evaluate_subset(dataloader: DataLoader): accum_loss = torch.zeros(1, device=get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) - outputs = model(**batch) - val_loss, logits = outputs[:2] - accum_loss.add_(val_loss) - - if num_labels > 1: - preds = torch.argmax(logits, axis=1) - elif num_labels == 1: - preds = logits.squeeze() - labels = batch["labels"] - - metric.add_batch(predictions=preds, references=labels) + batch_size = batch["input_ids"].shape[0] + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + pg_mesh = booster.plugin.pg_mesh + pp_group = booster.plugin.pp_group + current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) + current_rank = dist.get_rank() + #TODO pass dataloader to execute_pipeline directly + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + criterion, + optimizer, + return_loss=True, + return_outputs=True) + + if booster.plugin.stage_manager.is_last_stage(): + val_loss = outputs["loss"] + + logits = outputs["outputs"]["logits"] + + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + dist.broadcast(preds, src=current_rank, group=pp_group) + dist.broadcast(val_loss, src=current_rank, group=pp_group) + + metric.add_batch(predictions=preds, references=labels) + elif current_rank in current_pp_group_ranks: + val_loss = torch.empty((1,), device=get_current_device()) + preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device()) + + dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group) + dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group) + + accum_loss.add_(val_loss) + metric.add_batch(predictions=preds, references=labels) + + else: + batch = move_to_cuda(batch) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + metric.add_batch(predictions=preds, references=labels) results = metric.compute() dist.all_reduce(accum_loss.div_(len(dataloader))) - if coordinator.is_master(): + if coordinator.is_master() and results is not None: results['loss'] = accum_loss.item() / coordinator.world_size + return results if isinstance(test_dataloader, DataLoader): @@ -77,25 +135,43 @@ def evaluate_subset(dataloader: DataLoader): return final_results -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, + train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + is_pp_last_stage = hasattr( + booster.plugin, + "stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage() + with tqdm(train_dataloader, + desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', + disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: for batch in pbar: # Forward pass batch = move_to_cuda(batch) - outputs = model(**batch) - loss = outputs[0] + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + #TODO pass train_dataloader to execute_pipeline directly + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + if booster.plugin.stage_manager.is_last_stage(): + loss = outputs['loss'] + pbar.set_postfix({'loss': loss.item()}) + else: + outputs = model(**batch) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({'loss': loss.item()}) - # Backward and optimize - booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() lr_scheduler.step() - # Print log info - pbar.set_postfix({'loss': loss.item()}) - def main(): # ============================== @@ -107,7 +183,7 @@ def main(): '--plugin', type=str, default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], + choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], help="plugin to use") parser.add_argument( "--model_type", @@ -116,6 +192,7 @@ def main(): help="bert or albert", ) parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") args = parser.parse_args() if args.model_type == 'bert': @@ -145,6 +222,17 @@ def main(): plugin = GeminiPlugin(initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == 'hybrid_parallel': + + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin(tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision='fp16', + initial_scale=1) booster = Booster(plugin=plugin, **booster_kwargs) @@ -165,8 +253,9 @@ def main(): # bert pretrained model cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) + if model_name == "bert-base-uncased": - model = BertForSequenceClassification.from_pretrained(model_name, config=cfg) + model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() elif model_name == "albert-xxlarge-v2": model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg) else: @@ -196,19 +285,27 @@ def main(): num_training_steps=total_steps, ) + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + # ============================== # Boost with ColossalAI # ============================== - model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler) + model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=_criterion, + lr_scheduler=lr_scheduler) # ============================== # Train model # ============================== for epoch in range(NUM_EPOCHS): - train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, - coordinator) + results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task, + data_builder.eval_splits, booster, coordinator) if coordinator.is_master(): print(results) diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh index 7fc6daabb2f3..394ff831b855 100755 --- a/examples/language/bert/test_ci.sh +++ b/examples/language/bert/test_ci.sh @@ -3,6 +3,6 @@ set -xe pip install -r requirements.txt -for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do +for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" done diff --git a/pytest.ini b/pytest.ini index b869bb4fa116..38ad7d76de50 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,5 @@ [pytest] markers = - cpu: tests which can run on CPU - gpu: tests which requires a single GPU - dist: tests which are run in a multi-GPU or multi-machine environment - experiment: tests for experimental features + dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs) + largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs) addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 823ca032fc30..2a492361b13b 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -2,7 +2,7 @@ from .bert import * from .blip2 import * from .bloom import * -from .chatglm import * +from .chatglm2 import * from .gpt import * from .llama import * from .opt import * diff --git a/tests/kit/model_zoo/transformers/chatglm.py b/tests/kit/model_zoo/transformers/chatglm2.py similarity index 93% rename from tests/kit/model_zoo/transformers/chatglm.py rename to tests/kit/model_zoo/transformers/chatglm2.py index c6473ee2a025..d543df00bdfa 100644 --- a/tests/kit/model_zoo/transformers/chatglm.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -12,8 +12,8 @@ def data_gen(): - input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]]) + input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]) return dict(input_ids=input_ids, attention_mask=attention_mask) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py new file mode 100644 index 000000000000..e43908e0c651 --- /dev/null +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -0,0 +1,164 @@ +import pytest +import torch +import torch.distributed as dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import ( + assert_close_loose, + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +# TODO (Baizhou): Add test cases for shard=False +@clear_cache_before_run() +@parameterize('shard', [True]) +@parameterize('model_name', ['transformers_gpt']) +@parameterize('size_per_shard', [32]) +@parameterize('test_config', [{ + 'tp_size': 4, + 'pp_size': 1, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 2, + 'pp_size': 1, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 +}]) +def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): + + (model_fn, data_gen_fn, output_transform_fn, loss_fn, + _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = loss_fn + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + def _preprocess_data(data): + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + + model = model_fn().cuda() + optimizer = Adam(model.parameters(), lr=1e-3) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + else: + output = model(**_preprocess_data(data)) + loss = criterion(output) + optimizer.backward(loss) + + optimizer.step() + + with shared_tempdir() as tempdir: + + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + new_model = model_fn().cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False) + dist.barrier() + + # Check whether the loaded model & optimizer works smoothly. + model.train() + new_model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + booster.execute_pipeline(_preprocess_data(data), + new_model, + _criterion, + new_optimizer, + return_loss=True, + return_outputs=False) + else: + old_model_loss = criterion(model(**_preprocess_data(data))) + optimizer.backward(old_model_loss) + new_model_loss = criterion(new_model(**_preprocess_data(data))) + new_optimizer.backward(new_model_loss) + + optimizer.step() + new_optimizer.step() + + # Check updated weights. + stage_manager = booster.plugin.stage_manager + + if stage_manager is None or stage_manager.is_first_stage(): + assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) + assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data, + new_model.unwrap().h[0].mlp.c_fc.weight.data, + atol=5e-3, + rtol=5e-3) + + dist.barrier() + Randomizer.reset_index() + clear_layout_converter() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_hybrid_ckpIO(world_size): + spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py new file mode 100644 index 000000000000..bd041a5e2fd3 --- /dev/null +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -0,0 +1,83 @@ +import os + +import pytest +import torch +import torch.distributed as dist +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + + +@clear_cache_before_run() +@parameterize('model_name', ['transformers_gpt']) +@parameterize('plugin_type', ['ddp', 'zero', 'gemini']) +def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): + (model_fn, data_gen_fn, output_transform_fn, loss_fn, + _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = loss_fn + + if plugin_type == 'ddp': + plugin = TorchDDPPlugin() + elif plugin_type == 'zero': + plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32) + elif plugin_type == 'gemini': + plugin = GeminiPlugin(precision="fp16", initial_scale=32) + else: + raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.") + + booster = Booster(plugin=plugin) + + model = model_fn().cuda() + model_huggingface_cls = model.__class__ + optimizer = HybridAdam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + loss = criterion(output) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + + model_ckpt_path = f"{tempdir}/model" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + new_model = model_huggingface_cls.from_pretrained(model_ckpt_path) + new_model = new_model.cuda() + new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + + if plugin_type == 'gemini': + check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False), + new_model.unwrap().state_dict(only_rank_0=False), False) + else: + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + dist.barrier() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_from_pretrained() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_huggingface_compatibility(world_size): + spawn(run_dist, world_size) diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py index 550af2a4ae81..38b5e3f5f4fc 100644 --- a/tests/test_config/test_load_config.py +++ b/tests/test_config/test_load_config.py @@ -8,7 +8,6 @@ from colossalai.context.config import Config -@pytest.mark.cpu def test_load_config(): filename = Path(__file__).parent.joinpath('sample_config.py') config = Config.from_file(filename) diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py index 9f26a5af53ce..d25668afd430 100644 --- a/tests/test_context/test_hybrid_parallel.py +++ b/tests/test_context/test_hybrid_parallel.py @@ -143,7 +143,6 @@ def run_dist(rank, world_size, port, backend, port_list, host): reset_seeds() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_context(): """ diff --git a/tests/test_data/test_cifar10_dataset.py b/tests/test_data/test_cifar10_dataset.py index 4b9ca61d9f17..dfa9fa211ef0 100644 --- a/tests/test_data/test_cifar10_dataset.py +++ b/tests/test_data/test_cifar10_dataset.py @@ -5,11 +5,10 @@ from pathlib import Path import pytest -from torchvision import transforms, datasets from torch.utils.data import DataLoader +from torchvision import datasets, transforms -@pytest.mark.cpu def test_cifar10_dataset(): # build transform transform_pipeline = [transforms.ToTensor()] diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py index 2ad3fd696c39..7beef707c096 100644 --- a/tests/test_data/test_data_parallel_sampler.py +++ b/tests/test_data/test_data_parallel_sampler.py @@ -53,7 +53,6 @@ def run_data_sampler(rank, world_size, port): torch.cuda.empty_cache() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_data_sampler(): spawn(run_data_sampler, 4) diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py index 239e79dff7d8..283b5cc35279 100644 --- a/tests/test_data/test_deterministic_dataloader.py +++ b/tests/test_data/test_deterministic_dataloader.py @@ -64,7 +64,6 @@ def run_data_sampler(rank, world_size, port): torch.cuda.empty_cache() -@pytest.mark.cpu @rerun_if_address_is_in_use() def test_data_sampler(): spawn(run_data_sampler, 4) diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py new file mode 100644 index 000000000000..0cbb852b97a0 --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -0,0 +1,39 @@ +from colossalai.shardformer.policies.t5 import T5BasePolicy + + +def test_t5_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], + 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], + 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], + 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + } + + for i in range(num_test_cases): + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i], + test_dict['num_decoder_layers'][i], + test_dict['num_stages'][i]) + assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + + +def test_t5_pipeline_layers(): + num_test_cases = 4 + test_dict = { + 'num_encoder_layers': [2, 3, 2, 4], + 'num_decoder_layers': [2, 0, 2, 8], + 'num_stages': [2, 2, 4, 4], + 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]]] + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + + for stage in range(test_dict['num_stages'][i]): + start_idx, end_idx = test_dict['layers_per_stage'][i][stage] + predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, + decoder_starting_stage) + assert start_idx == predicted_start + assert end_idx == predicted_end diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py new file mode 100644 index 000000000000..395519e97898 --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -0,0 +1,44 @@ +from colossalai.shardformer.policies.whisper import WhisperPolicy + + +def test_whisper_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], + 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], + 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], + 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + } + + for i in range(num_test_cases): + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i], + test_dict['num_decoder_layers'][i], + test_dict['num_stages'][i]) + assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + + +def test_whisper_pipeline_layers(): + num_test_cases = 4 + test_dict = { + 'num_encoder_layers': [2, 3, 2, 4], + 'num_decoder_layers': [2, 0, 2, 8], + 'num_stages': [2, 2, 4, 4], + 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]]] + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + + for stage in range(test_dict['num_stages'][i]): + start_idx, end_idx = test_dict['layers_per_stage'][i][stage] + predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage, + decoder_starting_stage) + assert start_idx == predicted_start + assert end_idx == predicted_end + + +if __name__ == '__main__': + test_whisper_pipeline_distribution() + test_whisper_pipeline_layers() diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py new file mode 100644 index 000000000000..2ac31c8ca0d1 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -0,0 +1,161 @@ +import copy +from functools import partial +from types import MethodType + +import pytest +import torch +import torch.nn as nn + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import OptimizerWrapper +from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all + + +class MlpModel(nn.Module): + + def __init__(self): + super(MlpModel, self).__init__() + self.linear1 = nn.Linear(4, 8) + self.linear2 = nn.Linear(8, 8) + self.linear3 = nn.Linear(8, 8) + self.linear4 = nn.Linear(8, 8) + self.linear5 = nn.Linear(8, 8) + self.linear6 = nn.Linear(8, 8) + self.linear7 = nn.Linear(8, 8) + self.linear8 = nn.Linear(8, 4) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.linear5(x) + x = self.linear6(x) + x = self.linear7(x) + x = self.linear8(x) + return x + + +def pp_linear_fwd(forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + num_chunks: int = None, + model_chunk_id: int = None): + + if stage_mgr.is_first_stage() and model_chunk_id == 0: + return {'input_obj': forward(data)} + elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1: + return forward(input_obj) + else: + return {'input_obj': forward(input_obj)} + + +@parameterize("num_micro_batches", [4, 8, 12]) +def examine_pp(num_micro_batches): + """ + This test is to examine the correctness of interleaved 1F1B, compared with torch. + Be aware it contains some hardcodes. + """ + world_size = torch.distributed.get_world_size() + local_rank = torch.distributed.get_rank() + seed_all(1453) + + NUM_MICRO_BATCHS = num_micro_batches + BATCH_SIZE = num_micro_batches + NUM_CHUNKS = 2 + + # create model + torch_model = MlpModel().cuda() + + pp_model = copy.deepcopy(torch_model).cuda() + + DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 + pg_mesh = ProcessGroupMesh(1, world_size, 1) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True) + schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager) + + sharded_model = torch.nn.ModuleList() + for idx, (_, sub_model) in enumerate(pp_model.named_children()): + if idx % (world_size) == local_rank: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, + stage_mgr=stage_manager, + num_chunks=NUM_CHUNKS, + model_chunk_id=len(sharded_model)), sub_model._forward) + sharded_model.append(sub_model.cuda()) + + # create optimizer + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + + # create + seed_all(1453) + if local_rank == 0: + input_list = [torch.rand(BATCH_SIZE, 4).cuda()] + else: + input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] + torch.distributed.all_reduce(input_list[0]) + + criterion = lambda x, y: torch.mean(x) + + # forward and backward + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output, _) + torch_loss.backward() + + pp_ret = schedule.forward_backward_step(sharded_model, + pp_optimizer, + iter(input_list), + criterion, + return_loss=True, + return_outputs=True) + + # check loss + if stage_manager.is_last_stage(): + assert torch.allclose(torch_loss, pp_ret['loss']) + + # check gradients + torch_grad = [] + for torch_p in torch_model.parameters(): + torch_grad.append(torch_p.grad.data) + + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) + else: + assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data) + + # step + torch_optimizer.step() + pp_optimizer.step() + + # check updated param + torch_param = [] + for torch_p in torch_model.parameters(): + torch_param.append(torch_p.data) + for idx, pp_p in enumerate(sharded_model.parameters()): + if idx < 2: + assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) + else: + assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + examine_pp() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_pp() diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 542116a1da75..d31eafd70e1a 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -61,7 +61,7 @@ def examine_pp(): DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 pg_mesh = ProcessGroupMesh(1, world_size, 1) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager) + schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS) for idx, (_, sub_model) in enumerate(pp_model.named_children()): if idx % (world_size) == local_rank: diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index be4591d58f74..6e0cd1998c11 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -49,15 +49,6 @@ def check_stage_manager(): next_rank = ranks_in_group[ranks_in_group.index(rank) + 1] assert stage_manager.get_next_rank() == next_rank - # check virtual stage - stage_manager.set_num_virtual_stages(PP_SIZE * 2) - assert stage_manager.num_virtual_stages == PP_SIZE * 2 - stage_manager.set_virtual_stage(stage_manager.stage * 2) - assert stage_manager.virtual_stage == stage_manager.stage * 2 - with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1): - assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1 - assert stage_manager.virtual_stage == stage_manager.stage * 2 - # check p2p groups for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]): if rank in [prev, cur]: diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index b45cd172c3ca..4c0f884a7ed5 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_col(lazy_init: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: @@ -62,7 +61,9 @@ def check_linear_conv_1d_col(lazy_init: bool): linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True, - n_fused=3) + seq_parallel=seq_parallel, + n_fused=3, + overlap=overlap) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -76,10 +77,11 @@ def check_linear_conv_1d_col(lazy_init: bool): linear.load_state_dict(linear_conv_col.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 48).cuda() out = linear(x) - gather_out = linear_conv_col(x) - assert_close(rearrange(out, 1), gather_out) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + gather_out = linear_conv_col(x_for_shard) + assert_close(rearrange(out, -1), gather_out) # check backward correctness out.sum().backward() @@ -89,14 +91,16 @@ def check_linear_conv_1d_col(lazy_init: bool): assert_close(target_grad, linear_conv_col.weight.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_row(lazy_init: bool): +def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) @@ -109,10 +113,11 @@ def check_linear_conv_1d_row(lazy_init: bool): linear.load_state_dict(linear_row.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 48).cuda() out = linear(x) gather_out = linear_row(x) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -123,12 +128,19 @@ def check_linear_conv_1d_row(lazy_init: bool): assert_close(target_grad, linear_row.weight.grad) +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +@parameterize('overlap', [True]) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) + check_linear_conv_1d_row(lazy_init, seq_parallel) + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # test for linear conv - check_linear_conv_1d_col() - check_linear_conv_1d_row() + check_gpt2_qkv_fused_linear_1d() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index aa75879e0313..e6d86d533ed6 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -12,13 +12,16 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) -def check_linear_1d_col(lazy_init: bool): +def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True) + linear_col = Linear1D_Col.from_native_module(linear_copy, + process_group=None, + gather_output=True, + seq_parallel=seq_parallel, + overlap=overlap) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) @@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool): linear_col.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) out = linear(x_for_unshard) @@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_1d_row(lazy_init: bool): +def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = Linear1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) @@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool): linear_row.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) x_for_shard = x.expand_as(x.clone()) @@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool): # run forward out = linear(x_for_unshard) gather_out = linear_row(x_for_shard) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool): assert_close(x_for_unshard.grad, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_col_plus_row(lazy_init: bool): +def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear_1 = nn.Linear(32, 128).cuda() @@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool): with ctx: linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False) - linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True) + linear_col = Linear1D_Col.from_native_module(linear_1_copy, + process_group=None, + gather_output=False, + seq_parallel=seq_parallel, + overlap=overlap) + linear_row = Linear1D_Row.from_native_module(linear_2_copy, + process_group=None, + parallel_input=True, + seq_parallel=seq_parallel) linear_1.load_state_dict(linear_col.state_dict()) linear_col.load_state_dict(linear_1.state_dict()) @@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool): linear_row.load_state_dict(linear_2.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) # run forward unshard_out = linear_2(linear_1(x_for_unshard)) shard_out = linear_row(linear_col(x_for_shard)) - assert_close(unshard_out, shard_out) + target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, shard_out) # check backward correctness unshard_out.sum().backward() @@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) + + +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +@parameterize('overlap', [True]) +def run_dist_linear_test(lazy_init, seq_parallel, overlap): + check_linear_1d_col(lazy_init, seq_parallel, overlap) + check_linear_1d_row(lazy_init, seq_parallel) + check_linear_col_plus_row(lazy_init, seq_parallel, overlap) -def run_dist(rank, world_size, port): +def check_dist_linear(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_linear_1d_col() - check_linear_1d_row() - check_linear_col_plus_row() + run_dist_linear_test() @rerun_if_address_is_in_use() def test_linear(): - spawn(run_dist, nprocs=2) + spawn(check_dist_linear, nprocs=2) if __name__ == '__main__': diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 921af2a8b1d0..f77bf7495808 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,4 +1,5 @@ import copy +import math from contextlib import nullcontext from typing import Any, Callable, Dict, List, Optional @@ -12,6 +13,7 @@ from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.lazy import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -25,6 +27,7 @@ def build_model(model_fn, enable_tensor_parallelism=True, enable_flash_attention=False, enable_jit_fused=False, + enable_sequence_parallelism=False, use_lazy_init: bool = False): # create new model ctx = LazyInitContext() if use_lazy_init else nullcontext() @@ -38,7 +41,8 @@ def build_model(model_fn, shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, enable_tensor_parallelism=enable_tensor_parallelism, enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) + enable_jit_fused=enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) @@ -135,6 +139,16 @@ def _criterion(outputs, inputs): return loss data = data_gen_fn() + + if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: + seq_len = data['input_ids'].shape[1] + lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) + times = lcm // seq_len + input_shape = data['input_ids'].shape + for k, v in data.items(): + if v.shape == input_shape: + data[k] = v.repeat(1, times) + sharded_model.train() if booster.plugin.stage_manager is not None: for k, v in data.items(): @@ -177,11 +191,10 @@ def check_output_hidden_state(org_output: Tensor, org_hidden_state = org_output.last_hidden_state - if stage_manager is None: - sharded_hidden_state = sharded_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim) + sharded_hidden_state = sharded_output['outputs']['last_hidden_state'] + else: + sharded_hidden_state = sharded_output.last_hidden_state assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" @@ -219,6 +232,43 @@ def check_weight(org_model: Module, f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" +def get_grad_tensors_for_check(org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, + name: str = None): + + grad_to_check = {} + for suffix in layer_suffix: + org_grad = getattr_(org_model, suffix).weight.grad + shard_grad = getattr_(sharded_model, suffix).weight.grad + shard_weight = getattr_(sharded_model, suffix).weight + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] + dist.all_gather(shard_grad_list, shard_grad, tp_group) + shard_grad = torch.cat(shard_grad_list, dim=dim) + + # embedding may be resized when using tensor parallel + if shard_grad.shape[0] > org_grad.shape[0]: + shard_grad = shard_grad[:org_grad.shape[0], :] + if verbose and dist.get_rank() == 0: + print(f"'{suffix}' grad: {org_grad}, {shard_grad}") + + grad_to_check[suffix] = { + "org_grad": org_grad.float(), + "shard_grad": shard_grad.float(), + "rtol": rtol, + "atol": atol + } + + return grad_to_check + + +# used by sam/blip2 def check_grad(org_model: Module, sharded_model: Module, layer_suffix: List[str], @@ -231,7 +281,6 @@ def check_grad(org_model: Module, org_grad = getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) @@ -246,3 +295,30 @@ def check_grad(org_model: Module, assert torch.allclose( org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + + +def unwrap_model(module: Module, + base_model_class_name: Optional[str] = None, + base_model_attribute_name: Optional[str] = None): + if isinstance(module, HybridParallelModule): + module = module.unwrap() + if base_model_class_name is None: + return module + if module.__class__.__name__ == base_model_class_name: + return module + return getattr(module, base_model_attribute_name, None) + + +def check_all_grad_tensors(check_tensors): + """ + "org_grad": tensor to be compared from the original model + "shard_grad": tensor to be compared from the sharded model + """ + for suffix, check_info in check_tensors.items(): + org_grad = check_info["org_grad"] + shard_grad = check_info["shard_grad"] + rtol = check_info["rtol"] + atol = check_info["atol"] + assert torch.allclose( + org_grad, shard_grad, atol=atol, rtol=rtol + ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 0a24e46d28f2..c779e417052b 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -10,11 +10,13 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -32,42 +34,58 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, output_transform_fn, criterion, booster) + stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'BertModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - if org_model.__class__.__name__ == 'BertModel': - bert = org_model - sharded_bert = sharded_model.unwrap() - else: - bert = org_model.bert - sharded_bert = sharded_model.unwrap().bert + bert = unwrap_model(org_model, 'BertModel', 'bert') + sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert') col_layer_for_check = ['encoder.layer[0].output.dense'] row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if stage_manager is None or stage_manager.is_first_stage(): - #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3) - #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3) - check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) - check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) - - # check weights after optimizer.step() + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + col_layer_grads = get_grad_tensors_for_check(bert, + sharded_bert, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + row_layer_grads = get_grad_tensors_for_check(bert, + sharded_bert, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if org_model.__class__.__name__ == 'BertModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if test_config['precision'] == 'fp32': atol, rtol = 5e-3, 1e-3 else: @@ -75,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if stage_manager is None or stage_manager.is_first_stage(): check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() @@ -98,6 +119,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bert_test(test_config): @@ -111,12 +155,50 @@ def run_bert_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, +]) +def run_bert_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + def check_bert(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_bert_test() +def check_bert_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -124,5 +206,13 @@ def test_bert(): spawn(check_bert, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert_3d(): + spawn(check_bert_3d, 8) + + if __name__ == "__main__": test_bert() + test_bert_3d() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index ed0d1d8e401d..c9ee690c86dc 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -3,16 +3,19 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -34,6 +37,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + bloom = unwrap_model(org_model, 'BloomModel', 'transformer') + sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer') + + row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] + col_layer_for_check = ['h[0].self_attention.dense'] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-5 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check(bloom, + sharded_bloom, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(bloom, + sharded_bloom, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -45,28 +85,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - if org_model.__class__.__name__ == 'BloomModel': - bloom = org_model - sharded_bloom = sharded_model.unwrap() - else: - bloom = org_model.transformer - sharded_bloom = sharded_model.unwrap().transformer - - # check grad - row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] - col_layer_for_check = ['h[0].self_attention.dense'] - if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-6, 1e-5 - else: - atol, rtol = 5e-3, 5e-3 - check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) - check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 @@ -74,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() @@ -97,18 +118,72 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_bloom_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it + sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, +]) +def run_bloom_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -118,6 +193,12 @@ def check_bloom(rank, world_size, port): run_bloom_test() +def check_bloom_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bloom_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -125,5 +206,13 @@ def test_bloom(): spawn(check_bloom, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom_3d(): + spawn(check_bloom_3d, 8) + + if __name__ == "__main__": test_bloom() + test_bloom_3d() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py similarity index 52% rename from tests/test_shardformer/test_model/test_shard_chatglm.py rename to tests/test_shardformer/test_model/test_shard_chatglm2.py index bb77759048b3..48f651c727f4 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -4,16 +4,19 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -35,6 +38,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer') + shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer') + + row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] + col_layer_for_check = ['encoder.layers[0].self_attention.dense'] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check(chatglm_model, + shard_chatglm_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + + col_layer_grads = get_grad_tensors_for_check(chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -47,43 +88,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - if org_model.__class__.__name__ == 'ChatGLMModel': - chatglm_model = org_model - shard_chatglm_model = sharded_model.unwrap() - else: - chatglm_model = org_model.transformer - shard_chatglm_model = sharded_model.unwrap().transformer - - # check grad - row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] - col_layer_for_check = ['encoder.layers[0].self_attention.dense'] - if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-6, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - check_grad(chatglm_model, - shard_chatglm_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - - check_grad(chatglm_model, - shard_chatglm_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 @@ -98,6 +103,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + + Randomizer.reset_index() torch.cuda.empty_cache() @@ -121,12 +130,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_chatglm_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it + sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + clear_layout_converter() + torch.cuda.empty_cache() + + +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, +]) +def run_chatglm_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): @@ -142,6 +194,12 @@ def check_chatglm(rank, world_size, port): run_chatglm_test() +def check_chatglm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_chatglm_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -149,5 +207,13 @@ def test_chatglm(): spawn(check_chatglm, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm_3d(): + spawn(check_chatglm_3d, 8) + + if __name__ == "__main__": test_chatglm() + test_chatglm_3d() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 1a81b3360655..768063e537c7 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -3,18 +3,20 @@ from torch import distributed as dist import colossalai -from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -36,6 +38,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer') + sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer') + + col_layer_for_check = ['h[0].mlp.c_fc'] + row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + col_layer_grads = get_grad_tensors_for_check(gpt2, + sharded_gpt2, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + row_layer_grads = get_grad_tensors_for_check(gpt2, + sharded_gpt2, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -48,32 +87,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - def unwrap(module): - if isinstance(module, HybridParallelModule): - module = module.unwrap() - if module.__class__.__name__ == 'GPT2Model': - return module - return module.transformer - - # unwrap model - gpt2 = unwrap(org_model) - sharded_gpt2 = unwrap(sharded_model) - - col_layer_for_check = ['h[0].mlp.c_fc'] - row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] - - # check grad - if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-4, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) - check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 5e-3, 1e-3 @@ -81,6 +95,10 @@ def unwrap(module): atol, rtol = 5e-3, 5e-3 check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + + Randomizer.reset_index() torch.cuda.empty_cache() @@ -106,12 +124,80 @@ def unwrap(module): 'enable_all_optimization': True, 'use_lazy_init': False, 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'enable_sequence_parallelism': True, + 'precision': 'fp32', +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'enable_sequence_parallelism': True, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_gpt2_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, +]) +@clear_cache_before_run() +def run_gpt2_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): @@ -127,10 +213,13 @@ def check_gpt2(rank, world_size, port): run_gpt2_test() -# TODO(ver217): fix this +def check_gpt2_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gpt2_3d_test() -@pytest.mark.skip("this will stuck in CI") +@pytest.mark.skip(reason="This test will hang in CI") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -138,5 +227,13 @@ def test_gpt2(): spawn(check_gpt2, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2_3d(): + spawn(check_gpt2_3d, 8) + + if __name__ == "__main__": test_gpt2() + test_gpt2_3d() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 30ebdfbe5cd9..a60150e3cd72 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -6,16 +6,19 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -39,6 +42,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + llama_model = unwrap_model(org_model, 'LlamaModel', 'model') + shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') + + row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] + col_layer_for_check = ['layers[0].self_attn.o_proj'] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check(llama_model, + shard_llama_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(llama_model, + shard_llama_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -51,42 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - if org_model.__class__.__name__ == 'LlamaModel': - llama_model = org_model - shard_llama_model = sharded_model.unwrap() - else: - llama_model = org_model.model - shard_llama_model = sharded_model.unwrap().model - - # check grad - row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] - col_layer_for_check = ['layers[0].self_attn.o_proj'] - if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-6, 1e-4 - else: - atol, rtol = 5e-3, 5e-3 - check_grad(llama_model, - shard_llama_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - check_grad(llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-4, 1e-3 @@ -101,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() @@ -128,19 +136,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, + 'enable_all_optimization': False, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_llama_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, +]) +def run_llama_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -150,6 +213,12 @@ def check_llama(rank, world_size, port): run_llama_test() +def check_llama_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -157,5 +226,13 @@ def test_llama(): spawn(check_llama, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama_3d(): + spawn(check_llama_3d, 8) + + if __name__ == "__main__": test_llama() + test_llama_3d() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 8d1154d82638..3e74859ad1a8 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,16 +6,19 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -39,6 +42,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + opt_model = unwrap_model(org_model, 'OPTModel', 'model') + shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model') + + row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' + col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config['precision'] == 'fp32': + atol, rtol = 1e-6, 1e-3 + else: + atol, rtol = 4e-2, 4e-2 + row_layer_grads = get_grad_tensors_for_check(opt_model, + shard_opt_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -50,42 +90,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - if org_model.__class__.__name__ == 'OPTModel': - opt_model = org_model - shard_opt_model = sharded_model.unwrap() - else: - opt_model = org_model.model - shard_opt_model = sharded_model.unwrap().model - - # check grad - row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' - col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] - if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-6, 1e-3 - else: - atol, rtol = 3e-2, 3e-2 - check_grad(opt_model, - shard_opt_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - check_grad(opt_model, - shard_opt_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 1e-3, 1e-3 @@ -100,6 +105,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + + Randomizer.reset_index() torch.cuda.empty_cache() @@ -123,12 +132,62 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_opt_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it + clear_layout_converter() + torch.cuda.empty_cache() + +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, +]) +def run_opt_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): @@ -144,6 +203,12 @@ def check_OPTModel(rank, world_size, port): run_opt_test() +def check_opt_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_opt_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -151,5 +216,13 @@ def test_OPTModel(): spawn(check_OPTModel, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_opt_3d(): + spawn(check_opt_3d, 8) + + if __name__ == '__main__': test_OPTModel() + test_opt_3d() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 066f7ee815b4..768cae0a6734 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,5 +1,6 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.logging import disable_existing_loggers @@ -9,11 +10,13 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -35,6 +38,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group + # unwrap model + t5 = unwrap_model(org_model) + sharded_t5 = unwrap_model(sharded_model) + + row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + row_layer_grads = get_grad_tensors_for_check(t5, + sharded_t5, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config['precision'] == 'fp32': @@ -47,30 +76,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - t5 = org_model - sharded_t5 = sharded_model.unwrap() - - row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] - - # check weights and gradients + # check weights if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - if stage_manager is None or stage_manager.is_first_stage(): - check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) - - # check weights after optimizer.step() - org_optimizer.step() - sharded_optimizer.step() - if test_config['precision'] == 'fp32': - atol, rtol = 1e-4, 1e-3 + atol, rtol = 5e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() @@ -99,17 +115,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'tp_size': 1, 'pp_size': 4, 'num_microbatches': 4, + 'enable_all_optimization': False, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) @clear_cache_before_run() def run_t5_test(test_config): - # TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it - # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - - # TODO(baizhou): add test_config for flash attention & jit operator after supporting - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): @@ -125,12 +160,49 @@ def run_t5_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp16', + 'zero_stage': 1, + 'initial_scale': 1, + }, +]) +def run_t5_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_t5(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_t5_test() +def check_t5_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_t5_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -138,5 +210,13 @@ def test_t5(): spawn(check_t5, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_t5_3d(): + spawn(check_t5_3d, 8) + + if __name__ == "__main__": test_t5() + test_t5_3d() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 18df8ef555f2..15db63bfd9da 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -9,11 +9,13 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, - check_grad, + check_all_grad_tensors, check_loss, check_output_hidden_state, check_weight, + get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, + unwrap_model, ) @@ -35,54 +37,56 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': - atol, rtol = 1e-5, 1e-3 - else: - atol, rtol = 5e-3, 5e-3 - - if org_model.__class__.__name__ == 'ViTModel': - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - - check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) - # unwrap model - if org_model.__class__.__name__ == 'ViTModel': - vit_model = org_model - shard_vit_model = sharded_model.unwrap() - else: - vit_model = org_model.vit - shard_vit_model = sharded_model.unwrap().vit + vit_model = unwrap_model(org_model, 'ViTModel', 'vit') + shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit') # check grad row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] col_layer_for_check = ['encoder.layer[0].attention.output.dense'] - if stage_manager is None or stage_manager.is_first_stage(): + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config['precision'] == 'fp32': atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_grad(vit_model, - shard_vit_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - check_grad(vit_model, - shard_vit_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - - # check weights after optimizer.step() + row_layer_grads = get_grad_tensors_for_check(vit_model, + shard_vit_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + col_layer_grads = get_grad_tensors_for_check(vit_model, + shard_vit_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step org_optimizer.step() sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'ViTModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights if stage_manager is None or stage_manager.is_first_stage(): if test_config['precision'] == 'fp32': atol, rtol = 5e-3, 1e-3 @@ -97,9 +101,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, dim=1, verbose=False) + # check grads + check_all_grad_tensors(grads_to_check) + torch.cuda.empty_cache() +#TODO: num_microbatch size = 2 inf loss @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, @@ -120,15 +128,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_all_optimization': True, 'use_lazy_init': False, - 'precision': 'fp32', + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32' +}, { + 'tp_size': 2, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def run_vit_test(test_config): - # TODO(baizhou): add test_config for TP+DP after supporting & debugging it - # TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models + # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -137,12 +166,48 @@ def run_vit_test(test_config): torch.cuda.empty_cache() +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_vit_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + def check_vit(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_vit_test() +def check_vit_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_vit_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() @@ -150,5 +215,13 @@ def test_vit(): spawn(check_vit, 4) +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit_3d(): + spawn(check_vit_3d, 8) + + if __name__ == "__main__": test_vit() + test_vit_3d() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 9b38ae07b1d6..d0c04c98f80a 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -3,6 +3,8 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -11,55 +13,205 @@ spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, +) -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5) - - # do backward - org_loss.backward() - shard_loss.backward() - - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group # unwarp the model if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': whisper = org_model.model - sharded_whisper = sharded_model.model + sharded_whisper = sharded_model.unwrap().model else: whisper = org_model - sharded_whisper = sharded_model + sharded_whisper = sharded_model.unwrap() # check grad if org_model.__class__.__name__ == 'WhisperForAudioClassification': col_layer_for_check = ['encoder.layers[0].self_attn.q_proj'] row_layer_for_check = ['encoder.layers[0].self_attn.out_proj'] else: - col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj'] - row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj'] - check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + col_layer_for_check = [ + 'encoder.layers[0].self_attn.q_proj', + # 'decoder.layers[0].self_attn.q_proj' + ] + row_layer_for_check = [ + 'encoder.layers[0].self_attn.out_proj', + #'decoder.layers[0].self_attn.out_proj' + ] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if test_config['precision'] == 'fp32': + atol, rtol = 2e-4, 2e-4 + else: + atol, rtol = 5e-3, 5e-3 + + if stage_manager is None or stage_manager.is_first_stage(): + row_layer_grads = get_grad_tensors_for_check(whisper, + sharded_whisper, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1) + col_layer_grads = get_grad_tensors_for_check(whisper, + sharded_whisper, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 2e-4, 2e-4 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'WhisperModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(whisper, + sharded_whisper, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False) + check_weight(whisper, + sharded_whisper, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): +#TODO fix WhisperForConditionalGeneration enable jit fused operato +# TODO(jianghai) fix fp16 +@parameterize( + 'test_config', + [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', + }, + { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', + }, + # whisper is not supported fp16 for now. + ]) +def run_whisper_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, - enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification': + continue + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize('test_config', [ + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, + { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': False, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, + }, +]) +def run_whisper_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() @@ -69,12 +221,26 @@ def check_whisper(rank, world_size, port): run_whisper_test() +def check_whisper_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_whisper_3d_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_whisper(): - spawn(check_whisper, 2) + spawn(check_whisper, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_whisper_3d(): + spawn(check_whisper_3d, 8) if __name__ == "__main__": test_whisper() + test_whisper_3d() diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index 2930552cc4e7..b7764c2f4371 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -40,7 +40,6 @@ def forward_inplace(x, weight): return out -@pytest.mark.gpu @clear_cache_before_run() @parameterize("use_reentrant", [True, False]) @parameterize("cpu_offload", [True, False])