diff --git a/LICENSE b/LICENSE index 59d456c5b8a1..b3eb43520a6f 100644 --- a/LICENSE +++ b/LICENSE @@ -477,3 +477,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + ---------------- LICENSE FOR torch-int ---------------- + + MIT License + + Copyright (c) 2022 Guangxuan Xiao + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + + ---------------- LICENSE FOR smoothquant ---------------- + + MIT License + + Copyright (c) 2022 MIT HAN Lab + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/README.md b/README.md index b2efb7910489..1898d255e31c 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,8 @@ distributed training and inference in a few lines. - One half-day of training using a few hundred dollars yields similar results to mainstream large models, open-source and commercial-free domain-specific LLM solution. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) [[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) -[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) +[[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) +[[Modelscope model weights]](https://www.modelscope.cn/models/colossalai/Colossal-LLaMA-2-7b-base/summary) | | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval | | :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: | diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md index 34967c04360c..ae2e0c6bb2db 100644 --- a/applications/Colossal-LLaMA-2/README.md +++ b/applications/Colossal-LLaMA-2/README.md @@ -25,7 +25,9 @@ * [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) [[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) -[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) +[[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) +[[Modelscope model weights]](https://www.modelscope.cn/models/colossalai/Colossal-LLaMA-2-7b-base/summary) + ## Colossal-LLaMA-2-7B The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced the open-source model **Colossal-LLaMA-2-7B-base**. This model, a derivation of LLaMA-2, has undergone continual pre-training involving approximately 8.5 billion tokens over a duration of 15 hours with 64 A800 GPUs. At a cost of **less than $1,000**, you can achieve results **similar to those that cost millions of dollars to pretrain from scratch**. It is licensed under the LLaMA-2 license and [Apache 2.0 License](https://github.com/hpcaitech/ColossalAI/blob/main/LICENSE) **without any additional commercial use restrictions**. This solution can also be used to build models of specific domain knowledge or tasks. @@ -122,7 +124,23 @@ pred = model.generate(**inputs, print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):]) ``` -You can also download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base). +You can also load our model using modelscope, use the following code: +```Python +from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download +model_dir = snapshot_download('colossalai/Colossal-LLaMA-2-7b-base', revision='v1.0.1') +tokenizer = AutoTokenizer.from_pretrained(model_dir, device_map="auto", trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True).eval() +generation_kwargs = {"max_new_tokens": 256, + "top_p": 0.95, + "temperature": 0.3 + } +input = '离离原上草,' +inputs = tokenizer(input, return_token_type_ids=False, return_tensors='pt') +inputs = inputs.to('cuda:0') +output = model.generate(**inputs, **generation_kwargs) +print(tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input):]) +``` +You can download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) or [👾Modelscope](https://modelscope.cn/models/colossalai/Colossal-LLaMA-2-7b-base/summary). ## Usage ### Install diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py index 6c58c59307a6..1926ec78aba8 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py @@ -6,25 +6,20 @@ import torch import torch.nn.functional as F +from einops import rearrange +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func +from flash_attn.ops.rms_norm import rms_norm from transformers.models.llama.modeling_llama import ( - LlamaRMSNorm, LlamaAttention, - LlamaModel, LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, apply_rotary_pos_emb, repeat_kv, ) from colossalai.logging import get_dist_logger -from einops import rearrange - -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import ( - flash_attn_func, - flash_attn_varlen_kvpacked_func, -) -from flash_attn.ops.rms_norm import rms_norm - logger = get_dist_logger() @@ -65,6 +60,7 @@ def attention_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 501a843f6992..9e07bdebf8fa 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -1,7 +1,7 @@ -from typing import Dict, List +from typing import Dict, List, Tuple import torch -from torch import Tensor +from torch import Tensor, inf from torch.nn import Module, Parameter from torch.optim import Optimizer @@ -68,8 +68,6 @@ def __init__( self.mixed_precision = BF16MixedPrecisionMixin() else: raise ValueError(f"Unsupported precision: {precision}") - if max_norm > 0.0: - raise NotImplementedError("max_norm is not supported yet.") self.max_norm = max_norm self.working_to_master_map: Dict[Parameter, Tensor] = {} self.master_to_working_map: Dict[Tensor, Parameter] = {} @@ -102,32 +100,65 @@ def zero_grad(self, *args, **kwargs): return super().zero_grad(*args, **kwargs) def _unscale_and_clip_grads(self, total_norm: float) -> None: + """ + Unscale and clip gradients before performing the optimization step. + + Args: + total_norm (float): The computed total gradient norm. + + Returns: + None + """ div_scale = 1.0 + + # If mixed-precision training is used, get the gradient division scale from the mixed-precision handler. if self.mixed_precision is not None: div_scale = self.mixed_precision.get_grad_div_scale() if self.max_norm > 0.0: - # norm is in fact norm*scale + # Calculate the scaling factor for gradient clipping + # The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm' clip = ((total_norm / div_scale) + 1e-6) / self.max_norm + + # If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping if clip > 1: div_scale = clip * div_scale + # Apply the scaling factor to gradients for group in self.param_groups: for p in group["params"]: if p.grad is None: continue p.grad.data.mul_(1.0 / div_scale) - def _compute_grad_norm(self) -> float: - if self.max_norm <= 0.0: - return 0.0 - grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None] - if len(grads) == 0: + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + + if len(param_gradient_pairs) == 0: return 0.0 - device = grads[0].device - # TODO(ver217): support tp - total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2) - return total_norm.item() + + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type + total_norm = total_norm_exponentiated ** (1.0 / norm_type) + + return total_norm def step(self, *args, **kwargs): if self.mixed_precision.should_skip_step(): @@ -142,8 +173,22 @@ def step(self, *args, **kwargs): if working_param.grad is not None: p.grad = working_param.grad.data.float() working_param.grad = None - total_norm = self._compute_grad_norm() + + # gradient unscale and clip. + if self.max_norm <= 0: + # no need to compute gradient norm. + total_norm = 0.0 + else: + # compute the total norm. + param_gradient_pairs = [ + (self.master_to_working_map[p], p.grad) + for group in self.param_groups + for p in group["params"] + if p.grad is not None + ] + total_norm = self._compute_grad_norm(param_gradient_pairs) self._unscale_and_clip_grads(total_norm) + self.optim.step(*args, **kwargs) # update working params for group in self.optim.param_groups: diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ca722a0768dc..20a931b816ea 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -97,7 +97,7 @@ def save_sharded_model( Path(checkpoint_path).mkdir(parents=True, exist_ok=True) - state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) @@ -245,6 +245,7 @@ class GeminiPlugin(DPPluginBase): chunk_config_dict (dict, optional): chunk configuration dictionary. chunk_init_device (torch.device, optional): device to initialize the chunk. placement_policy (str, optional): "static" and "auto". Defaults to "static". + enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradient will be stored after doing backward pass. Defaults to False. shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement. If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0. offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement. @@ -257,6 +258,7 @@ class GeminiPlugin(DPPluginBase): warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8. steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9. precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'. + master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True. pin_memory (bool, optional): use pin memory on CPU. Defaults to False. force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. @@ -290,12 +292,14 @@ def __init__( chunk_config_dict: Optional[dict] = None, chunk_init_device: Optional[torch.device] = None, placement_policy: str = "static", + enable_gradient_accumulation: bool = False, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement warmup_non_model_data_ratio: float = 0.8, # only for auto placement steady_cuda_cap_ratio: float = 0.9, # only for auto placement precision: str = "fp16", + master_weights: bool = True, pin_memory: bool = False, force_outputs_fp32: bool = False, strict_ddp_mode: bool = False, @@ -321,6 +325,7 @@ def __init__( chunk_config_dict=chunk_config_dict, chunk_init_device=(chunk_init_device or get_current_device()), placement_policy=placement_policy, + enable_gradient_accumulation=enable_gradient_accumulation, shard_param_frac=shard_param_frac, offload_optim_frac=offload_optim_frac, offload_param_frac=offload_param_frac, @@ -334,6 +339,7 @@ def __init__( min_chunk_size_m=min_chunk_size_m, memstats=memstats, mixed_precision=PRECISION_STR_TO_DTYPE[precision], + master_weights=master_weights, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 479ccc3eb36e..72c3ec46ae75 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,3 +1,4 @@ +import ctypes import random from contextlib import nullcontext from functools import partial @@ -7,7 +8,8 @@ import numpy as np import torch import torch.distributed as dist -from torch.distributed import ProcessGroup +from torch import Tensor, inf +from torch.distributed import ProcessGroup, get_world_size from torch.nn import Module, SyncBatchNorm from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -24,6 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.d_tensor.api import is_distributed_tensor from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -160,12 +163,143 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): + def __init__( + self, + optim: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + ): self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.max_norm = max_norm + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group super().__init__(optim) + def step(self, *args, **kwargs): + r""" + Perform an optimization step. + + Args: + *args: Variable-length positional arguments to be passed to the optimizer's step function. + **kwargs: Keyword arguments to be passed to the optimizer's step function. + """ + + if self.max_norm > 0: + # Compute the total gradient norm. + param_gradient_pairs = [ + (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None + ] + total_norm = self._compute_grad_norm(param_gradient_pairs) + + # Clip the gradients to prevent exploding gradients. + self._clip_grad_norm(total_norm) + + # Perform the optimization step using the underlying optimizer. + self.optim.step(*args, **kwargs) + + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + + if len(param_gradient_pairs) == 0: + return 0.0 + + tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + norm_type = float(norm_type) + + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + if tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + total_norm = total_norm_cuda.item() + else: + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_shared_param = shared_param[self.stage_manager.stage] + if grad is stage_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) + if tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + def _clip_grad_norm(self, total_norm: float) -> None: + r""" + Clips the gradients of the model's parameters to prevent exploding gradients. + + Args: + total_norm (float): The computed total gradient norm. + + Returns: + None + """ + clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6)) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + + for group in self.optim.param_groups: + for p in group["params"]: + if p.grad is None: + continue + p.grad.data.mul_(clip_coef_clamped) + def update_master_params(self, model: Module): pass @@ -192,23 +326,108 @@ def __init__( hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp ): self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__( optim, - precision, - initial_scale, - min_scale, - growth_factor, - backoff_factor, - growth_interval, - hysteresis, - max_scale, - max_norm, + precision=precision, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + max_norm=max_norm, ) + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + if len(param_gradient_pairs) == 0: + return 0.0 + + tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + norm_type = float(norm_type) + + if norm_type == inf: + # The parent class calculates the norm of 'dp' gradients, + # so we need to calculate the norm of 'tp' and 'pp' gradients. + total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) + + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + if tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + + total_norm = total_norm_cuda.item() + + else: + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_working_shared_param = shared_param[self.stage_manager.stage] + stage_master_shared_param = self.working_to_master_map[stage_working_shared_param] + if grad is stage_master_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) + if tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): def __init__( @@ -233,32 +452,118 @@ def __init__( cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp forced_dtype: Optional[torch.dtype] = None, ): self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.dp_pg = dp_process_group + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__( - optimizer, - initial_scale, - min_scale, - growth_factor, - backoff_factor, - growth_interval, - hysteresis, - max_scale, - clip_grad_norm, - verbose, - reduce_bucket_size, - communication_dtype, - overlap_communication, - partition_grad, - cpu_offload, - dp_process_group, - tp_process_group, - forced_dtype, + optimizer=optimizer, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=clip_grad_norm, + verbose=verbose, + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, + cpu_offload=cpu_offload, + dp_process_group=dp_process_group, + forced_dtype=forced_dtype, ) + def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + gradients (List[Tensor]): A list of tensors containing gradients. + norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2. + + Returns: + float: The computed gradient norm. + """ + + # Check if the list of gradients is empty + if len(gradients) == 0: + return 0.0 + + dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1 + tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + norm_type = float(norm_type) + + if norm_type == inf: + # The parent class calculates the norm of 'dp' gradients, + # so we only need to calculate the norm 'tp' of 'pp' gradients. + total_norm = super()._compute_grad_norm(gradients, norm_type) + + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + if tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + + total_norm = total_norm_cuda.item() + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if tp_size > 1: + param_id_for_grad = self._grad_store.get_param_id_for_grad(grad) + param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value + + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_shared_param = shared_param[self.stage_manager.stage] + working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param)) + if grad is working_grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) + if dp_size > 1: + # compute norm in dp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg) + if tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # Compute the 'total_norm' from 'total_norm_exponentiated' + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + class HybridParallelPlugin(PipelinePluginBase): """ @@ -475,11 +780,19 @@ def configure( param_info=param_info, precision=self.precision, max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, **self.amp_config, ) else: optimizer = HybridParallelNaiveOptimizer( - optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, ) else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." @@ -491,6 +804,7 @@ def configure( param_info=param_info, dp_process_group=self.dp_group, tp_process_group=self.tp_group, + pp_process_group=self.pp_group, verbose=True, clip_grad_norm=self.max_norm, **self.zero_config, diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index dffa4ce164ef..dc78fe8c094c 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -262,6 +262,7 @@ def __init__( communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, cpu_offload: bool = False, + master_weights: bool = True, verbose: bool = False, ) -> None: super().__init__() @@ -272,18 +273,19 @@ def __init__( self.precision = precision self.zero_optim_kwargs = dict( initial_scale=initial_scale, + min_scale=min_scale, growth_factor=growth_factor, backoff_factor=backoff_factor, growth_interval=growth_interval, hysteresis=hysteresis, - min_scale=min_scale, max_scale=max_scale, clip_grad_norm=max_norm, reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, communication_dtype=communication_dtype, overlap_communication=overlap_communication, - cpu_offload=cpu_offload, partition_grad=(stage == 2), + cpu_offload=cpu_offload, + master_weights=master_weights, ) self.verbose = verbose @@ -333,4 +335,4 @@ def get_checkpoint_io(self) -> CheckpointIO: def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(optimizer, LowLevelZeroOptimizer) - return optimizer.optim.no_sync() + return optimizer.no_sync() diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 9a965dc982a4..d0c281e057b3 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -4,7 +4,7 @@ ## Introduction -`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. +`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. ## Design @@ -62,6 +62,12 @@ triton==2.0.0.dev20221202 vllm # for install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c flash-attention + +# install lightllm since we depend on lightllm triton kernels +git clone https://github.com/ModelTC/lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +cd lightllm +pip3 install -e . ``` ### Docker @@ -73,6 +79,17 @@ You can use docker run to use docker container to set-up environment docker pull hpcaitech/colossalai-inference:v2 docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash +# enter into docker container +cd /path/to/CollossalAI +pip install -e . + +# install lightllm +git clone https://github.com/ModelTC/lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +cd lightllm +pip3 install -e . + + ``` ### Dive into fast-inference! @@ -94,7 +111,7 @@ For various models, experiments were conducted using multiple batch sizes under ### Single GPU Performance: -Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to furthur optimize the performance of LLM models. Please stay tuned. +Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned. #### Llama diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index e69de29bb2d1..35891307e754 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -0,0 +1,3 @@ +from .pipeline import PPInferEngine + +__all__ = ["PPInferEngine"] diff --git a/colossalai/inference/pipeline/README.md b/colossalai/inference/pipeline/README.md new file mode 100644 index 000000000000..a90d5d6da182 --- /dev/null +++ b/colossalai/inference/pipeline/README.md @@ -0,0 +1,84 @@ +# 🐳 Pipeline Inference + +## Table of Contents +- [💡 Introduction](#introduction) +- [🔗 Design](#design) +- [🔨 Usage](#usage) + - [Example](#example) + - [Quick start](#quick-start) +- [📊 Performance](#performance) + +## Introduction + +`Pipeline Inference` is a module designed to make inference on a pipeline way. In inference systems, although there is no need to store intermediate information such as activations during forward propagation for backward propagation, the weights of some larger models still cannot fit on a single GPU for inference. This requires us to use model parallelism and other methods to reduce the memory occupation on a single GPU. Pipeline parallelism, as one of the traditional model parallelism approaches, has been widely used due to its reduced all-reduce communication requirements and simple layout. The main issue with pipeline parallelism, known as bubbles, can be almost eliminated in inference because the backward propagation that causes bubbles no longer exists in inference. This makes pipeline parallelism almost bubble-free in the ideal scenario where the sequence length is the same across the pipeline. + +## Design + +Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py). + +1. `PPInderEngine` is the High-Level API for users to use. It is responsible for the following tasks: + - Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`. + - Run the pipeline inference model. + +2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks: + - Record each micro-batch information, like generated new tokens and kvcache. + - Record each micro-batch inference state, like prefill, generate or done. + - Update the micro-batch information. + +3. `generate` schedule implements the simple pipeline inference layout. When pipeline size is 2, we use `torch.distributed.P2Pop` to implement the communication between stages, mainly to solve the race communication. When pipeline size is larger than 2, we use `torch.distributed.broadcast` which is faster than `torch.distributed.P2Pop`. + +## Usage + +### Example +```python +from colossalai.pipeline import PPInferEngine +# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example. +model = LlamaForCausalLM.from_pretrained('/path/to/model') +inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt") +engine = PPInferEngine( + pp_size=2, + dtype='fp16', + micro_batch_size=1, + new_length=10, + model=model, + model_policy=LlamaForCausalLMPipelinePolicy()) + +output = engine.inference([inputs]) + +``` + +### Quick start +```shell +cd benchmark +sh run.sh +``` + +## Performance + +We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G. + +### Llama Throughput(tokens/s) + +#### 7b, fp16 +| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| +| :---: | :---: | :---: | :---: | :---: | :---: | :---:| +| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM | +| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | +| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 | +| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM | + +#### 7b, fp32 +| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | +| :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 | +| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM | +| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 | +| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM | + +#### 13b, fp16 +| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | +| :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 | +| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM | +| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 | +| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM | diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py new file mode 100644 index 000000000000..41af9f3ef948 --- /dev/null +++ b/colossalai/inference/pipeline/__init__.py @@ -0,0 +1,3 @@ +from .engine import PPInferEngine + +__all__ = ["PPInferEngine"] diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py new file mode 100644 index 000000000000..9c47909f70f0 --- /dev/null +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -0,0 +1,131 @@ +import argparse +import time + +import torch +import torch.distributed as dist +import transformers + +import colossalai +from colossalai.inference import PPInferEngine +from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy + +GIGABYTE = 1024**3 +MEGABYTE = 1024 * 1024 + +colossalai.launch_from_torch(config={}) + + +def data_gen(batch_size: int = 4, seq_len: int = 512): + input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32) + attention_mask = torch.ones((1, seq_len), dtype=torch.int32) + data = dict(input_ids=input_ids, attention_mask=attention_mask) + for k, v in data.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = batch_size + data[k] = v.to("cuda").repeat(*new_shape) + return data + + +def print_details_info(timestamps, model_config, args, whole_end2end): + if dist.get_rank() == 0: + prefill = [] + encoder = [] + end2end = [] + for timestamp in timestamps: + prefill.append(timestamp[1] - timestamp[0]) + encoder.append( + sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2) + ) + end2end.append(timestamp[-1] - timestamp[0]) + print(whole_end2end) + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "w+", + ) as f: + mb_avg_end2end = sum(end2end) / len(end2end) + mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size) + whole_avg_latency = whole_end2end / (args.new_length * args.batch_size) + num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) + num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size + if args.dtype in ["fp16", "bf16"]: + num_bytes = 2 + else: + num_bytes = 4 + + f.write( + f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n" + ) + f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000)) + f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000)) + f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000)) + f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000)) + f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000)) + f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000)) + f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000)))) + f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12)) + f.write("----------------------------------------------------------\n") + + if torch.cuda.is_available(): + current_device = torch.cuda.current_device() + + # free memory and the total available memory in bytes + global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info() + memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + memory_reserved = torch.cuda.memory_reserved() + max_memory_reserved = torch.cuda.max_memory_reserved() + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "a", + ) as f: + f.write( + f"\nCurrently using GPU: {current_device}\n" + f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" + f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n" + f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n" + f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n" + f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n" + f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="toy", help="the size of model") + parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") + parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length") + parser.add_argument("--new_length", type=int, default=4, help="new tokens length") + parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size") + parser.add_argument("--pp_size", type=int, default=2, help="pipeline size") + parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log") + parser.add_argument("--dtype", type=str, default="fp16", help="data type") + args = parser.parse_args() + + if args.model == "toy": + model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8)) + elif args.model == "7b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf")) + elif args.model == "13b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf")) + else: + raise NotImplementedError + + engine = PPInferEngine( + pp_size=args.pp_size, + dtype=args.dtype, + micro_batch_size=args.mb_size, + new_length=args.new_length, + model=model, + model_policy=LlamaForCausalLMPipelinePolicy(), + verbose=True, + ) + data = data_gen(args.batch_size, args.seq_len) + + torch.cuda.synchronize() + whole_end2end = time.time() + output, timestamps = engine.inference([data]) + torch.cuda.synchronize() + whole_end2end = time.time() - whole_end2end + + print_details_info(timestamps, model.config, args, whole_end2end) diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/inference/pipeline/benchmark/run.sh new file mode 100644 index 000000000000..7d8da858692f --- /dev/null +++ b/colossalai/inference/pipeline/benchmark/run.sh @@ -0,0 +1,50 @@ +script_dir=$(cd "$(dirname "$0")" && pwd) +cd "${script_dir}" + +# 7b, fp32, 2 gpu, 1024, 128 +for BATCH_SIZE in 2 4 8 16; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="7b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=1024 \ + --new_length=128 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 7b, fp32, 2 gpu, 512, 512 +for BATCH_SIZE in 2 4 8 16 32; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="7b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=512 \ + --new_length=512 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 7b, fp32, 2 gpu, 1024, 128 +for BATCH_SIZE in 2 4 8; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="13b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=1024 \ + --new_length=128 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 13b, fp16, 2 gpu, 512, 512 +for BATCH_SIZE in 2 4 8 16; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="13b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=512 \ + --new_length=512 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py new file mode 100644 index 000000000000..4f42385caf8f --- /dev/null +++ b/colossalai/inference/pipeline/engine.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn + +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.schedule.generate import GenerateSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy + +from .microbatch_manager import MicroBatchManager + + +class PPInferEngine: + """ + PPInferEngine is a class that handles the pipeline parallel inference. + + Args: + pp_size (int): the number of pipeline stages. + pp_model (`nn.Module`): the model already in pipeline parallelism style. + model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`. + model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. + micro_batch_size (int): the micro batch size. + micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. + new_length (int): the new length of the input sequence. + early_stopping (bool): whether to stop early. + + Example: + + ```python + from colossalai.ppinference import PPInferEngine + from transformers import GPT2LMHeadModel, GPT2Tokenizer + + model = transformers.GPT2LMHeadModel.from_pretrained('gpt2') + # assume the model is infered with 4 pipeline stages + inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding}) + + input = ["Hello, my dog is cute, and I like"] + tokenized_input = tokenizer(input, return_tensors='pt') + output = engine.inference([tokenized_input]) + ``` + + """ + + def __init__( + self, + pp_size: int, + dtype: str = "fp16", + pp_model: nn.Module = None, + model: nn.Module = None, + model_policy: Policy = None, + new_length: int = 32, + micro_batch_size: int = 1, + micro_batch_buffer_size: int = None, + verbose: bool = False, + # TODO: implement early_stopping, and various gerneration options + early_stopping: bool = False, + do_sample: bool = False, + num_beams: int = 1, + ) -> None: + assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided." + self.pp_size = pp_size + self.pg_mesh = ProcessGroupMesh(pp_size) + self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) + self.mb_manager = MicroBatchManager( + self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size + ) + self.verbose = verbose + self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) + + assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" + if dtype == "fp16": + model.half() + elif dtype == "bf16": + model.to(torch.bfloat16) + self.model = pp_model or self._shardformer(model, model_policy) + + def inference(self, input_list): + out, timestamp = self.schedule.generate_step(self.model, iter(input_list)) + if self.verbose: + return out, timestamp + else: + return out + + def _shardformer(self, model, model_policy): + shardconfig = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(model, model_policy) + return shard_model.cuda() diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py new file mode 100644 index 000000000000..49d1bf3f42cb --- /dev/null +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -0,0 +1,238 @@ +from enum import Enum +from typing import Dict, Tuple + +import torch + +__all__ = "MicroBatchManager" + + +class Status(Enum): + PREFILL = 1 + GENERATE = 2 + DONE = 3 + COOLDOWN = 4 + + +class MicroBatchDescription: + """ + This is the class to record the infomation of each microbatch, and also do some update operation. + This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more + details, please refer to the doc of these two classes blow. + + Args: + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + """ + + def __init__( + self, + inputs_dict: Dict[str, torch.Tensor], + output_dict: Dict[str, torch.Tensor], + new_length: int, + ) -> None: + assert output_dict.get("hidden_states") is not None + self.mb_length = output_dict["hidden_states"].shape[-2] + self.target_length = self.mb_length + new_length + self.kv_cache = () + + def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + if output_dict is not None: + self._update_kvcache(output_dict["past_key_values"]) + + def _update_kvcache(self, kv_cache: Tuple): + assert type(kv_cache) == tuple + self.kv_cache = kv_cache + + @property + def state(self): + """ + Return the state of current micro batch, when current length is equal to target length, + the state is DONE, otherwise GENERATE + + """ + # TODO: add the condition for early stopping + if self.cur_length == self.target_length: + return Status.DONE + elif self.cur_length == self.target_length - 1: + return Status.COOLDOWN + else: + return Status.GENERATE + + @property + def cur_length(self): + """ + Return the current sequnence length of micro batch + + """ + + +class HeadMicroBatchDescription(MicroBatchDescription): + """ + This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` + and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the + information and the condition to determine the state is different from other stages. + + Args: + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + new_length (int): the new length of the input sequence. + + """ + + def __init__( + self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + ) -> None: + super().__init__(inputs_dict, output_dict, new_length) + assert inputs_dict is not None + assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None + self.input_ids = inputs_dict["input_ids"] + self.attn_mask = inputs_dict["attention_mask"] + self.new_tokens = None + + def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + super().update(output_dict, new_token) + if new_token is not None: + self._update_newtokens(new_token) + if self.state is not Status.DONE and new_token is not None: + self._update_attnmask() + + def _update_newtokens(self, new_token: torch.Tensor): + if self.new_tokens is None: + self.new_tokens = new_token + else: + self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1) + + def _update_attnmask(self): + self.attn_mask = torch.cat( + (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1 + ) + + @property + def cur_length(self): + """ + When there is no new_token, the length is mb_length, otherwise the sequence length is `mb_length` plus the length of new_token + + """ + if self.new_tokens is None: + return self.mb_length + else: + return self.mb_length + len(self.new_tokens[0]) + + +class BodyMicroBatchDescription(MicroBatchDescription): + """ + This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, + + Args: + inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + """ + + def __init__( + self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + ) -> None: + super().__init__(inputs_dict, output_dict, new_length) + + def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + super().update(output_dict, new_token) + + @property + def cur_length(self): + """ + When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 + + """ + if len(self.kv_cache) == 0: + return self.mb_length + else: + return self.kv_cache[0][0].shape[-2] + 1 + + +class MicroBatchManager: + """ + MicroBatchManager is a class that manages the micro batch. + + Args: + stage (int): stage id of current stage. + new_length (int): the new length of the input sequence. + micro_batch_size (int): the micro batch size. + micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. + + """ + + def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): + self.stage = stage + self.new_length = new_length + self.micro_batch_size = micro_batch_size + self.buffer_size = micro_batch_buffer_size + self.mb_descrption_buffer = {} + self.new_tokens_buffer = {} + self.idx = 0 + + def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + """ + Update the state if microbatch manager, 2 conditions. + 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. + 2. For other conditon, only receive the output of previous stage, and update the descrption. + + Args: + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + new_token (torch.Tensor): the new token generated by current stage. + """ + # Add descrption first if the descrption is None + if inputs_dict is None and output_dict is None and new_token is None: + return Status.PREFILL + if self.mb_descrption_buffer.get(self.idx) is None: + self._add_descrption(inputs_dict, output_dict) + self.cur_descrption.update(output_dict, new_token) + return self.cur_state + + def export_new_tokens(self): + new_tokens_list = [] + for i in self.mb_descrption_buffer.values(): + new_tokens_list.extend(i.new_tokens.tolist()) + return new_tokens_list + + def is_micro_batch_done(self): + if len(self.mb_descrption_buffer) == 0: + return False + for mb in self.mb_descrption_buffer.values(): + if mb.state != Status.DONE: + return False + return True + + def clear(self): + self.mb_descrption_buffer.clear() + + def next(self): + self.idx = (self.idx + 1) % self.buffer_size + + def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]): + if self.stage == 0: + self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length) + else: + self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length) + + def _remove_descrption(self): + self.mb_descrption_buffer.pop(self.idx) + + @property + def cur_descrption(self) -> MicroBatchDescription: + return self.mb_descrption_buffer.get(self.idx) + + @property + def cur_kv_cache(self): + if self.cur_descrption is None: + return None + return self.cur_descrption.kv_cache + + @property + def cur_state(self): + """ + Return the state of current micro batch, when current descrption is None, the state is PREFILL + + """ + if self.cur_descrption is None: + return Status.PREFILL + return self.cur_descrption.state diff --git a/colossalai/inference/pipeline/modeling/__init__.py b/colossalai/inference/pipeline/modeling/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py new file mode 100644 index 000000000000..d2bfcb8b6842 --- /dev/null +++ b/colossalai/inference/pipeline/modeling/gpt2.py @@ -0,0 +1,280 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class GPT2PipelineForwards: + """ + This class serves as a micro library for forward function substitution of GPT2 models + under pipeline setting. + """ + + @staticmethod + def gpt2_model_forward( + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: + # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. + # Please refer to original code of transformers for more details. + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if stage_manager.is_first_stage(): + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + else: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + for i, layer_past in zip(range(start_idx, end_idx), past_key_values): + block = self.h[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return {"hidden_states": hidden_states, "past_key_values": presents} + + @staticmethod + def gpt2_lmhead_model_forward( + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # If is first stage and after warmup, go throught lm_head first + if stage_manager.is_first_stage() and hidden_states is not None: + lm_logits = self.lm_head(hidden_states) + return {"logits": lm_logits} + + # Not first stage or before warmup, go through gpt2 model + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + return outputs diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py new file mode 100644 index 000000000000..f46e1fbdd7b3 --- /dev/null +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -0,0 +1,229 @@ +from typing import List, Optional + +import torch +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager + + +class LlamaPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + """ + + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + if past_key_values is None: + past_key_values = tuple([None] * (end_idx - start_idx + 1)) + + for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): + decoder_layer = self.layers[idx] + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + # always return dict for imediate stage + return {"hidden_states": hidden_states, "past_key_values": next_cache} + + def llama_for_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # If is first stage and after warmup, go throught lm_head first + if stage_manager.is_first_stage() and hidden_states is not None: + lm_logits = self.lm_head(hidden_states) + return {"logits": lm_logits} + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = LlamaPipelineForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + return outputs diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py new file mode 100644 index 000000000000..51e6425b113e --- /dev/null +++ b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py @@ -0,0 +1,74 @@ +from functools import partial +from typing import Callable, Dict, List + +from torch import Tensor, nn + +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.policies.gpt2 import GPT2Policy + +from ..modeling.gpt2 import GPT2PipelineForwards + + +class GPT2LMHeadModelPipelinePolicy(GPT2Policy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) + } + module_policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy, + ) + return module_policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + # make the tie weight lm_head and embedding in the same device to save memory + # if self.pipeline_stage_manager.is_first_stage(): + if self.pipeline_stage_manager.is_first_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """The weights of wte and lm_head are shared.""" + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [] + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "GPT2Model": + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/inference/pipeline/policy/llama_ppinfer.py b/colossalai/inference/pipeline/policy/llama_ppinfer.py new file mode 100644 index 000000000000..6e12ed61bf7b --- /dev/null +++ b/colossalai/inference/pipeline/policy/llama_ppinfer.py @@ -0,0 +1,48 @@ +from typing import List + +from torch.nn import Module + +from colossalai.shardformer.layer import Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription +from colossalai.shardformer.policies.llama import LlamaPolicy + +from ..modeling.llama import LlamaPipelineForwards + + +class LlamaForCausalLMPipelinePolicy(LlamaPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers import LlamaForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_first_stage(): + held_layers.append(self.model.lm_head) + return held_layers diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py new file mode 100644 index 000000000000..c26aa4e40b71 --- /dev/null +++ b/colossalai/inference/pipeline/utils.py @@ -0,0 +1,35 @@ +from typing import Set + +import torch.nn as nn + +from colossalai.shardformer._utils import getattr_, setattr_ + + +def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None: + """ + Set all parameters and buffers of model to None + + Args: + model (nn.Module): The model to set + """ + for module_suffix in include: + set_module = getattr_(model, module_suffix) + for n, p in set_module.named_parameters(): + setattr_(set_module, n, None) + for n, buf in set_module.named_buffers(): + setattr_(set_module, n, None) + setattr_(model, module_suffix, None) + + +def get_suffix_name(suffix: str, name: str): + """ + Get the suffix name of the module, as `suffix.name` when name is string or `suffix[name]` when name is a digit, + and 'name' when `suffix` is empty. + + Args: + suffix (str): The suffix of the suffix module + name (str): The name of the current module + """ + point = "" if suffix is "" else "." + suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}" + return suffix_name diff --git a/colossalai/inference/quant/smoothquant/__init__.py b/colossalai/inference/quant/smoothquant/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py new file mode 100644 index 000000000000..77541d8610c5 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/__init__.py @@ -0,0 +1,12 @@ +try: + import torch_int + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + raise ImportError( + "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int" + ) + +if HAS_TORCH_INT: + from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py new file mode 100644 index 000000000000..180e6c6e8fa6 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -0,0 +1,482 @@ +# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py + +import os +import warnings +from abc import abstractmethod +from functools import partial +from os.path import isdir, isfile, join +from typing import Dict, List, Optional, Union + +import accelerate +import numpy as np +import torch +import torch.nn as nn +import transformers +from safetensors.torch import save_file as safe_save +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel +from transformers.modeling_utils import no_init_weights +from transformers.utils.generic import ContextManagers +from transformers.utils.hub import PushToHubMixin, cached_file + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager + +SUPPORTED_MODELS = ["llama"] + + +class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): + layer_type: str = None + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__() + + self.model = model + self.model_type = self.model.config.model_type + self._quantized = quantized + self.config = self.model.config + self.cache_manager = None + self.max_total_token_num = 0 + + @property + def quantized(self): + return self._quantized + + def init_cache_manager(self, max_total_token_num=2048): + if self.config.model_type == "llama": + head_num = self.config.num_key_value_heads + layer_num = self.config.num_hidden_layers + head_dim = self.config.hidden_size // head_num + + self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num) + self.max_total_token_num = max_total_token_num + + def init_batch_state(self, max_output_len=256, **kwargs): + input_ids = kwargs["input_ids"] + batch_size = len(input_ids) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + max_len_in_batch = -1 + + for i in range(batch_size): + seq_len = len(input_ids[i]) + seq_lengths[i] = seq_len + seq_start_indexes[i] = start_index + start_index += seq_len + max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch + + if "max_total_token_num" in kwargs.keys(): + max_total_token_num = kwargs["max_total_token_num"] + self.init_cache_manager(max_total_token_num) + + if "max_new_tokens" in kwargs.keys(): + max_output_len = kwargs["max_new_tokens"] + + if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num: + max_total_token_num = batch_size * (max_len_in_batch + max_output_len) + warnings.warn(f"reset max tokens to {max_total_token_num}") + self.init_cache_manager(max_total_token_num) + + block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda") + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") + batch_infer_state.block_loc = block_loc + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) + batch_infer_state.cache_manager.free_all() + return batch_infer_state + + @abstractmethod + @torch.inference_mode() + def quantize( + self, + examples: List[Dict[str, Union[List[int], torch.LongTensor]]], + ): + if self.quantized: + raise EnvironmentError("can't execute quantize because the model is quantized.") + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def generate(self, **kwargs): + """shortcut for model.generate""" + + batch_infer_state = self.init_batch_state(**kwargs) + if self.config.model_type == "llama": + setattr(self.model.model, "infer_state", batch_infer_state) + + with torch.inference_mode(): + return self.model.generate(**kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + """shortcut for model.prepare_inputs_for_generation""" + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512): + for text in tqdm(dataset): + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + + def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512): + pbar = tqdm(dataset) + for text in pbar: + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + + def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = {} + + def stat_tensor(name, tensor): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float().cpu() + if name in act_scales: + act_scales[name] = torch.max(act_scales[name], comming_max) + else: + act_scales[name] = comming_max + + def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x) + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name))) + + self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len) + + for h in hooks: + h.remove() + + return act_scales + + @torch.no_grad() + def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): + if not isinstance(fcs, list): + fcs = [fcs] + for fc in fcs: + assert isinstance(fc, nn.Linear) + assert ln.weight.numel() == fc.in_features == act_scales.numel() + + device, dtype = fcs[0].weight.device, fcs[0].weight.dtype + act_scales = act_scales.to(device=device, dtype=dtype) + weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) + weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) + + scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) + + ln.weight.div_(scales) + if hasattr(ln, "bias"): + ln.bias.div_(scales) + + for fc in fcs: + fc.weight.mul_(scales.view(1, -1)) + + @classmethod + def create_quantized_model(model): + raise NotImplementedError("Not implement create_quantized_model method") + + def save_quantized( + self, + save_dir: str, + model_basename: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + ): + """save quantized model and configs to local disk""" + os.makedirs(save_dir, exist_ok=True) + + if not self.quantized: + raise EnvironmentError("can only save quantized model, please execute .quantize first.") + + self.model.to("cpu") + + model_base_name = model_basename # or f"smooth-" + if use_safetensors: + model_save_name = model_base_name + ".safetensors" + state_dict = self.model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + if safetensors_metadata is None: + safetensors_metadata = {} + elif not isinstance(safetensors_metadata, dict): + raise TypeError("safetensors_metadata must be a dictionary.") + else: + print(f"Received safetensors_metadata: {safetensors_metadata}") + new_safetensors_metadata = {} + converted_keys = False + for key, value in safetensors_metadata.items(): + if not isinstance(key, str) or not isinstance(value, str): + converted_keys = True + try: + new_key = str(key) + new_value = str(value) + except Exception as e: + raise TypeError( + f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}" + ) + if new_key in new_safetensors_metadata: + print( + f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting." + ) + new_safetensors_metadata[new_key] = new_value + safetensors_metadata = new_safetensors_metadata + if converted_keys: + print( + f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}" + ) + + # Format is required to enable Accelerate to load the metadata + # otherwise it raises an OSError + safetensors_metadata["format"] = "pt" + + safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata) + else: + model_save_name = model_base_name + ".bin" + torch.save(self.model.state_dict(), join(save_dir, model_save_name)) + + self.model.config.save_pretrained(save_dir) + + def save_pretrained( + self, + save_dir: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + **kwargs, + ): + """alias of save_quantized""" + warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.") + self.save_quantized(save_dir, use_safetensors, safetensors_metadata) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + max_memory: Optional[dict] = None, + trust_remote_code: bool = False, + torch_dtype: torch.dtype = torch.float16, + **model_init_kwargs, + ): + if not torch.cuda.is_available(): + raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + # Parameters related to loading from Hugging Face Hub + cache_dir = model_init_kwargs.pop("cache_dir", None) + force_download = model_init_kwargs.pop("force_download", False) + resume_download = model_init_kwargs.pop("resume_download", False) + proxies = model_init_kwargs.pop("proxies", None) + local_files_only = model_init_kwargs.pop("local_files_only", False) + use_auth_token = model_init_kwargs.pop("use_auth_token", None) + revision = model_init_kwargs.pop("revision", None) + subfolder = model_init_kwargs.pop("subfolder", "") + model_init_kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + } + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs) + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + # enforce some values despite user specified + model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["trust_remote_code"] = trust_remote_code + if max_memory: + if "disk" in max_memory: + raise NotImplementedError("disk offload not support yet.") + with accelerate.init_empty_weights(): + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model.tie_weights() + + max_memory = accelerate.utils.get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + low_zero=False, + ) + model_init_kwargs["device_map"] = accelerate.infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + ) + model_init_kwargs["low_cpu_mem_usage"] = True + + del model + else: + model_init_kwargs["device_map"] = None + model_init_kwargs["low_cpu_mem_usage"] = False + + torch.cuda.empty_cache() + + merged_kwargs = {**model_init_kwargs, **cached_file_kwargs} + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs) + + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + model.eval() + + return cls(model, False) + + @classmethod + def from_quantized( + cls, + model_name_or_path: Optional[str], + model_basename: Optional[str] = None, + device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, + max_memory: Optional[dict] = None, + device: Optional[Union[str, int]] = None, + low_cpu_mem_usage: bool = False, + torch_dtype: Optional[torch.dtype] = None, + use_safetensors: bool = False, + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + + # Parameters related to loading from Hugging Face Hub + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + # == step1: prepare configs and file names == # + config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs + ) + + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + extensions = [] + if use_safetensors: + extensions.append(".safetensors") + else: + extensions += [".bin", ".pt"] + + model_name_or_path = str(model_name_or_path) + is_local = isdir(model_name_or_path) + + resolved_archive_file = None + if is_local: + model_save_name = join(model_name_or_path, model_basename) + for ext in extensions: + if isfile(model_save_name + ext): + resolved_archive_file = model_save_name + ext + break + else: # remote + for ext in extensions: + resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs) + if resolved_archive_file is not None: + break + + if resolved_archive_file is None: # Could not find a model file to use + raise FileNotFoundError(f"Could not find model in {model_name_or_path}") + + model_save_name = resolved_archive_file + + # == step2: convert model to quantized-model (replace Linear) == # + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + transformers.modeling_utils._init_weights = False + + init_contexts = [no_init_weights()] + if low_cpu_mem_usage: + init_contexts.append(accelerate.init_empty_weights(include_buffers=True)) + + with ContextManagers(init_contexts): + model = AutoModelForCausalLM.from_config( + config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype + ) + cls.create_quantized_model(model) + model.tie_weights() + + # == step3: load checkpoint to quantized-model == # + accelerate.utils.modeling.load_checkpoint_in_model( + model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True + ) + + # == step4: set seqlen == # + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + + return cls( + model, + True, + ) + + def __getattr__(self, item): + try: + return super().__getattr__(item) + except: + return getattr(self.model, item) + + +__all__ = ["BaseSmoothForCausalLM"] diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py new file mode 100644 index 000000000000..048565bfbf5e --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/linear.py @@ -0,0 +1,177 @@ +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py + +import torch +from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32 +from torch_int.functional.quantization import quantize_per_tensor_absmax + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except ImportError: + HAS_SMOOTHQUANT_CUDA = False + raise ImportError("CUDA smoothquant linear is not installed") + + +class W8A8BFP32O32LinearSiLU(torch.nn.Module): + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + if module.bias is not None: + int8_module.bias.data.copy_(module.bias.to(torch.float)) + int8_module.a = alpha + return int8_module + + +class W8A8B8O8Linear(torch.nn.Module): + # For qkv_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + self.register_buffer("b", torch.tensor(beta)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + int8_module = W8A8B8O8Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale / output_scale + int8_module.weight = int8_weight + int8_module.a = alpha + + if module.bias is not None: + int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias) + int8_module.bias = int8_bias + beta = bias_scale / output_scale + int8_module.b = beta + + return int8_module + + +class W8A8BFP32OFP32Linear(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + int8_module.a = alpha + int8_module.input_scale = input_scale + int8_module.weight_scale = weight_scale + + if module.bias is not None: + int8_module.bias = module.bias.to(torch.float32) + + return int8_module diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py new file mode 100644 index 000000000000..9c77feeb346e --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -0,0 +1,846 @@ +import math +import os +import types +from collections import defaultdict +from functools import partial +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T +from transformers import PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LLAMA_INPUTS_DOCSTRING, + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaRotaryEmbedding, + repeat_kv, + rotate_half, +) +from transformers.utils import add_start_docstrings_to_model_forward + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import ( + copy_kv_cache_to_dest, + int8_rotary_embedding_fwd, + smooth_llama_context_attn_fwd, + smooth_token_attention_fwd, +) + +from .base_model import BaseSmoothForCausalLM +from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear + + +class LLamaSmoothquantAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." + ) + + self.qk_bmm = BMM_S8T_S8N_F32T(1.0) + self.pv_bmm = BMM_S8T_S8N_S8T(1.0) + + self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) + + self.register_buffer("q_output_scale", torch.tensor([1.0])) + self.register_buffer("k_output_scale", torch.tensor([1.0])) + self.register_buffer("v_output_scale", torch.tensor([1.0])) + self.register_buffer("q_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("k_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("out_input_scale", torch.tensor([1.0])) + self.register_buffer("attn_input_scale", torch.tensor([1.0])) + + self._init_rope() + self.num_key_value_heads = num_heads + + def _init_rope(self): + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=2048, + base=10000.0, + ) + + @staticmethod + def pack( + module: LlamaAttention, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + ): + int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads) + + int8_module.attn_input_scale = torch.tensor([attn_input_scale]) + + int8_module.q_output_scale = torch.tensor([q_output_scale]) + int8_module.k_output_scale = torch.tensor([k_output_scale]) + int8_module.v_output_scale = torch.tensor([v_output_scale]) + + int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale]) + int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale]) + + int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale) + int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) + int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) + int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) + + int8_module.out_input_scale = torch.tensor([out_input_scale]) + + return int8_module + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + cos = rotary_emb[0] + sin = rotary_emb[1] + + int8_rotary_embedding_fwd( + query_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.q_output_scale.item(), + self.q_rotary_output_scale.item(), + ) + int8_rotary_embedding_fwd( + key_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.k_output_scale.item(), + self.k_rotary_output_scale.item(), + ) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len + + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + + query_states = query_states.view(-1, self.num_heads, self.head_dim) + key_states = key_states.view(-1, self.num_heads, self.head_dim) + value_states = value_states.view(-1, self.num_heads, self.head_dim) + + if infer_state.is_context_stage: + # first token generation + + # copy key and value calculated in current step to memory manager + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_states) + + smooth_llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.start_loc, + infer_state.seq_len, + q_len, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + smooth_token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, None, None + + +class LlamaLayerNormQ(torch.nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.input_scale = 1.0 + self.variance_epsilon = eps + self.register_buffer("weight", torch.ones(dim, dtype=torch.float32)) + + def forward(self, x): + ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon) + ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8) + return ln_output_int8 + + @staticmethod + def from_float(module: torch.nn.LayerNorm, output_scale: float): + assert module.weight.shape[0] == module.weight.numel() + q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon) + q_module.weight = module.weight / output_scale + return q_module + + +class LlamaSmoothquantMLP(nn.Module): + def __init__(self, intermediate_size, hidden_size): + super().__init__() + self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size) + self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size) + self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size) + self.register_buffer("down_proj_input_scale", torch.tensor([1.0])) + + @staticmethod + def pack( + mlp_module: LlamaMLP, + gate_proj_input_scale: float, + up_proj_input_scale: float, + down_proj_input_scale: float, + ): + int8_module = LlamaSmoothquantMLP( + mlp_module.intermediate_size, + mlp_module.hidden_size, + ) + + int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale) + int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale) + int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale) + int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale]) + return int8_module + + def forward( + self, + hidden_states: torch.Tensor, + ): + x_shape = hidden_states.shape + gate_out = self.gate_proj(hidden_states) + up_out = self.up_proj(hidden_states) + inter_out = gate_out * up_out + inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8) + down_out = self.down_proj(inter_out) + down_out = down_out.view(*x_shape[:-1], -1) + return down_out + + +class LlamaSmoothquantDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads) + + self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size) + self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + @staticmethod + def pack( + module: LlamaDecoderLayer, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + gate_input_scale: float, + up_input_scale: float, + down_input_scale: float, + ): + config = module.self_attn.config + int8_decoder_layer = LlamaSmoothquantDecoderLayer(config) + + int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale) + int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack( + module.self_attn, + attn_input_scale, + q_output_scale, + k_output_scale, + v_output_scale, + q_rotary_output_scale, + k_rotary_output_scale, + out_input_scale, + ) + + int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float( + module.post_attention_layernorm, gate_input_scale + ) + + int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack( + module.mlp, + gate_input_scale, + up_input_scale, + down_input_scale, + ) + + return int8_decoder_layer + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + rotary_emb=rotary_emb, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None, None + + +class LlamaApplyRotary(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (rotate_half(x) * sin) + + return x_embed + + +def llama_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states = self.q_apply_rotary(query_states, cos, sin, position_ids) + key_states = self.k_apply_rotary(key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def init_to_get_rotary(config, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + base : calculation arg + use_elem : activated when using chatglm-based models + """ + config.head_dim_ = config.hidden_size // config.num_attention_heads + if not hasattr(config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0 + + if hasattr(config, "max_sequence_length"): + max_seq_len = config.max_sequence_length + elif hasattr(config, "max_position_embeddings"): + max_seq_len = config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + try: + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula + except: + pass + + n_elem = config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + _cos_cached = torch.cos(freqs).to(torch.float) + _sin_cached = torch.sin(freqs).to(torch.float) + return _cos_cached, _sin_cached + + +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + infer_state = self.infer_state + + if past_key_values is not None: + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if infer_state.is_context_stage: + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + raise NotImplementedError("not implement gradient_checkpointing and training options ") + + if past_key_values_length == 0: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + infer_state.decode_layer_id = 0 + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + rotary_emb=(position_cos, position_sin), + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + + hidden_states = layer_outputs[0] + infer_state.decode_layer_id += 1 + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + infer_state.is_context_stage = False + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): + layer_type = "LlamaDecoderLayer" + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__(model, quantized) + + def get_act_dict( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + ): + llama_model = self.model + + llama_model.eval() + device = next(llama_model.parameters()).device + # print("model:", llama_model) + act_dict = defaultdict(dict) + + def stat_io_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + if name not in act_dict or "input" not in act_dict[name]: + act_dict[name]["input"] = x.detach().abs().max().item() + else: + act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) + if isinstance(y, tuple): + y = y[0] + if name not in act_dict or "output" not in act_dict[name]: + act_dict[name]["output"] = y.detach().abs().max().item() + else: + act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) + + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaAttention): + setattr(m, "q_apply_rotary", LlamaApplyRotary()) + setattr(m, "k_apply_rotary", LlamaApplyRotary()) + m.forward = types.MethodType(llama_decoder_layer_forward, m) + + hooks = [] + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaApplyRotary): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + if isinstance(m, torch.nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + + self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len) + + for hook in hooks: + hook.remove() + return act_dict + + def smooth_fn(self, scales, alpha=0.5): + model = self.model + for name, module in model.named_modules(): + if isinstance(module, LlamaDecoderLayer): + attn_ln = module.input_layernorm + qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + ".self_attn.q_proj"] + self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) + + def create_quantized_model(model): + llama_config = model.config + for i, layer in enumerate(model.model.layers): + model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config) + + model.model.forward = types.MethodType(llama_model_forward, model.model) + cos, sin = init_to_get_rotary(llama_config) + model.model.register_buffer("_cos_cached", cos) + model.model.register_buffer("_sin_cached", sin) + + def quantized( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + alpha=0.5, + ): + llama_model = self.model + llama_config = llama_model.config + + act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len) + + self.smooth_fn(act_scales, alpha) + + act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len) + decoder_layer_scales = [] + + for idx in range(llama_config.num_hidden_layers): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 + scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 + scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 + scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 + + scale_dict["q_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 + ) + scale_dict["k_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 + ) + + scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 + + scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 + scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 + scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + + decoder_layer_scales.append(scale_dict) + + for i, layer in enumerate(llama_model.model.layers): + orig_layer = layer + llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) + + llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) + + cos, sin = init_to_get_rotary(llama_config) + llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device)) + llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device)) diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index ac185f1b6529..de150311cc08 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -5,7 +5,7 @@ from .kvcache_manager import MemoryManager - +# adapted from: lightllm/server/router/model_infer/infer_batch.py @dataclass class BatchInferState: r""" @@ -41,6 +41,7 @@ def total_token_num(self): def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager + # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1 @staticmethod def init_block_loc( b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index e75004d506a3..c97134d1fa96 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -74,9 +74,14 @@ def __init__( model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers ) self.layer_num = num_hidden_layers - self.multi_query_group_num = ( - model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0 - ) + + self.multi_query_group_num = 0 + + if hasattr(model.config, "multi_query_group_num"): + self.multi_query_group_num = model.config.multi_query_group_num + + if hasattr(model.config, "num_key_value_heads"): + self.multi_query_group_num = model.config.num_key_value_heads self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None @@ -99,6 +104,7 @@ def _init_manager(self) -> None: assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" self.head_num //= self.tp_size # update sharded number of heads + if self.multi_query_group_num: # NOTE the logic of MQA tensor parallelism should be specified. assert ( diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index e74a3a491a7b..c9e7aaae0844 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -1,7 +1,9 @@ -# Adapted from lightllm/common/mem_manager.py -# of the ModelTC/lightllm GitHub repository -# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py - +""" +Refered/Modified from lightllm/common/mem_manager.py +of the ModelTC/lightllm GitHub repository +https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py +we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design. +""" import torch from transformers.utils import logging diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 4b1bc601f436..b8274d3c660f 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -6,8 +6,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd -from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -20,6 +18,14 @@ from ._utils import copy_kv_to_mem_cache +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd + from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + # This func is same as Llama model init_to_get_rotary, we should move them into _utils.py def _init_to_get_rotary(self, base=10000): @@ -433,17 +439,17 @@ def chatglm_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin ) if self.multi_query_attention: - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), cos, sin, ) else: - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin, @@ -474,7 +480,7 @@ def chatglm_flash_attn_kvcache_forward( attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) # NOTE: no bug in context attn fwd (del it ) - llama2_context_attn_fwd( + lightllm_llama2_context_attention_fwd( query_layer, key_layer, value_layer, diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 7e6978ad815b..1d61251857ad 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -5,7 +5,8 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd +from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd +from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from ._utils import copy_kv_to_mem_cache @@ -23,6 +24,17 @@ ) HAS_VLLM_KERNERL = False +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -134,6 +146,7 @@ def llama_model_forward( seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -255,8 +268,8 @@ def llama_flash_attn_kvcache_forward( # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) # NOTE might want to revise # need some way to record the length of past key values cache @@ -264,12 +277,12 @@ def llama_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin - rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) - rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) query_states = query_states.reshape(-1, self.num_heads, self.head_dim) - key_states = key_states.reshape(-1, self.num_heads, self.head_dim) - value_states = value_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim) if infer_state.is_context_stage: # first token generation @@ -282,15 +295,27 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager, ) attn_output = torch.empty_like(query_states) - llama_context_attn_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) + + if self.num_key_value_groups == 1: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + else: + lightllm_llama2_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -318,16 +343,29 @@ def llama_flash_attn_kvcache_forward( # (batch_size, seqlen, nheads, headdim) attn_output = torch.empty_like(query_states) - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) + if self.num_key_value_groups == 1: + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + else: + Llama2TokenAttentionForwards.token_attn( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + infer_state.other_kv_index, + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 507c1203dd6b..7e163efe0173 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -12,8 +12,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward try: - from colossalai.kernel.triton import rmsnorm_forward - + from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") @@ -22,9 +21,8 @@ def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: - def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) return _triton_rmsnorm_forward else: diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp index 0ab250218da3..be9300c545c2 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp @@ -35,23 +35,19 @@ SOFTWARE void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { - size_t rounded_size = 0; + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); float betta1_minus1 = 1 - _betta1; float betta2_minus1 = 1 - _betta2; float step_size = -1 * _alpha / _bias_correction1; float w_decay = -1 * _alpha * _weight_decay; - __half *params_cast_h = NULL; - __half *grads_cast_h = NULL; - - if (param_half_precision) { - params_cast_h = reinterpret_cast<__half *>(_params); - } - if (grad_half_precision) { - grads_cast_h = reinterpret_cast<__half *>(grads); - } + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; @@ -77,7 +73,6 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, if (_weight_decay > 0) weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; @@ -87,28 +82,23 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH) { AVX_Data grad_4; - if (grad_half_precision) { - grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i); - } else { - grad_4.data = SIMD_LOAD(grads + i); - } + this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4); if (loss_scale > 0) { AVX_Data loss_scale_vec; loss_scale_vec.data = SIMD_SET(loss_scale); grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data); } AVX_Data momentum_4; - momentum_4.data = SIMD_LOAD(_exp_avg + i); + this->simd_load(momentum_half_precision, _exp_avg + i, + momentum_cast_h + i, momentum_4); AVX_Data variance_4; - variance_4.data = SIMD_LOAD(_exp_avg_sq + i); + this->simd_load(variance_half_precision, _exp_avg_sq + i, + variance_cast_h + i, variance_4); AVX_Data param_4; - if (param_half_precision) { - param_4.data = SIMD_LOAD_HALF(params_cast_h + i); - } else { - param_4.data = SIMD_LOAD(_params + i); - } + this->simd_load(param_half_precision, _params + i, params_cast_h + i, + param_4); if (_weight_decay > 0 && !_adamw_mode) { grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data); @@ -130,13 +120,12 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, } param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data); - if (param_half_precision) { - SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data); - } else { - SIMD_STORE(_params + i, param_4.data); - } - SIMD_STORE(_exp_avg + i, momentum_4.data); - SIMD_STORE(_exp_avg_sq + i, variance_4.data); + this->simd_store(param_half_precision, _params + i, params_cast_h + i, + param_4); + this->simd_store(momentum_half_precision, _exp_avg + i, + momentum_cast_h + i, momentum_4); + this->simd_store(variance_half_precision, _exp_avg_sq + i, + variance_cast_h + i, variance_4); } } #endif @@ -154,8 +143,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, } float param = param_half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = _exp_avg[k]; - float variance = _exp_avg_sq[k]; + float momentum = + momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k]; + float variance = variance_half_precision ? (float)variance_cast_h[k] + : _exp_avg_sq[k]; if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } @@ -178,8 +169,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, params_cast_h[k] = (__half)param; else _params[k] = param; - _exp_avg[k] = momentum; - _exp_avg_sq[k] = variance; + if (momentum_half_precision) + momentum_cast_h[k] = (__half)(momentum); + else + _exp_avg[k] = momentum; + if (variance_half_precision) + variance_cast_h[k] = (__half)(variance); + else + _exp_avg_sq[k] = variance; } } } @@ -188,17 +185,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { - size_t rounded_size = 0; + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); - __half *params_cast_h = NULL; - __half *grads_cast_h = NULL; - if (param_half_precision) { - params_cast_h = reinterpret_cast<__half *>(_params); - } - if (grad_half_precision) { - grads_cast_h = reinterpret_cast<__half *>(grads); - } + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; @@ -228,7 +222,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, if (_weight_decay > 0) weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; @@ -243,26 +236,21 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, AVX_Data param_4[4]; #pragma unroll 4 for (int j = 0; j < 4; j++) { - if (grad_half_precision) { - grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); - } else { - grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); - } + this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j, + grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]); if (loss_scale > 0) { AVX_Data loss_scale_vec; loss_scale_vec.data = SIMD_SET(loss_scale); grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); } - - momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); - variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); - - if (param_half_precision) { - param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); - } else { - param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); - } + this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_load(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); if (_weight_decay > 0 && !_adamw_mode) { grad_4[j].data = @@ -285,14 +273,13 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, } param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); - if (param_half_precision) { - SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), - param_4[j].data); - } else { - SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); - } - SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data); - SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data); + this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_store(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); } } } @@ -302,24 +289,26 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, : _params + rounded_size), (grad_half_precision ? (float *)(grads_cast_h + rounded_size) : grads + rounded_size), - (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), (_param_size - rounded_size), param_half_precision, - grad_half_precision, loss_scale); + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); } void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { - size_t rounded_size = 0; - __half *params_cast_h = NULL; - __half *grads_cast_h = NULL; - if (param_half_precision) { - params_cast_h = reinterpret_cast<__half *>(_params); - } - if (grad_half_precision) { - grads_cast_h = reinterpret_cast<__half *>(grads); - } + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); + #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; betta1_4.data = SIMD_SET(_betta1); @@ -348,7 +337,6 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, if (_weight_decay > 0) weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; @@ -363,26 +351,21 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, AVX_Data param_4[8]; #pragma unroll 8 for (int j = 0; j < 8; j++) { - if (grad_half_precision) { - grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); - } else { - grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); - } + this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j, + grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]); if (loss_scale > 0) { AVX_Data loss_scale_vec; loss_scale_vec.data = SIMD_SET(loss_scale); grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); } - - momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); - variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); - - if (param_half_precision) { - param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); - } else { - param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); - } + this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_load(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); if (_weight_decay > 0 && !_adamw_mode) { grad_4[j].data = @@ -405,15 +388,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); - if (param_half_precision) { - SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), - param_4[j].data); - } else { - SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); - } - - SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data); - SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data); + this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_store(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); } } } @@ -423,9 +404,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, : _params + rounded_size), (grad_half_precision ? (float *)(grads_cast_h + rounded_size) : grads + rounded_size), - (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), (_param_size - rounded_size), param_half_precision, - grad_half_precision, loss_scale); + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); } void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, @@ -447,7 +432,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, this->update_state(lr, epsilon, weight_decay, bias_correction); this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), (params.options().dtype() == at::kHalf), - (grads.options().dtype() == at::kHalf), loss_scale); + (grads.options().dtype() == at::kHalf), + (exp_avg.options().dtype() == at::kHalf), + (exp_avg_sq.options().dtype() == at::kHalf), loss_scale); } namespace py = pybind11; diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h index 4247da942775..bf9b85997c78 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h @@ -50,9 +50,9 @@ SOFTWARE #define SIMD_DIV(x, y) _mm512_div_ps(x, y) #define SIMD_LOAD_HALF(x) \ _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) -#define SIMD_STORE_HALF(x, d) \ - _mm256_store_ps( \ - x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) +#define SIMD_STORE_HALF(x, d) \ + _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \ + d, _MM_FROUND_TO_NEAREST_INT))) #elif defined(__AVX256__) or defined(__AVX2__) #define SIMD_WIDTH 8 @@ -66,9 +66,9 @@ SOFTWARE #define SIMD_SQRT(x) _mm256_sqrt_ps(x) #define SIMD_DIV(x, y) _mm256_div_ps(x, y) #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) -#define SIMD_STORE_HALF(x, d) \ - _mm_store_ps( \ - x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) +#define SIMD_STORE_HALF(x, d) \ + _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \ + d, _MM_FROUND_TO_NEAREST_INT))) #endif @@ -83,11 +83,12 @@ union AVX_Data { #endif -#define STEP(SPAN) \ - void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \ - float *_exp_avg_sq, size_t _param_size, \ - bool param_half_precision = false, \ - bool grad_half_precision = false, float loss_scale = -1); +#define STEP(SPAN) \ + void Step_##SPAN( \ + float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \ + size_t _param_size, bool param_half_precision = false, \ + bool grad_half_precision = false, bool momentum_half_precision = false, \ + bool variance_half_precision = false, float loss_scale = -1); class Adam_Optimizer { public: @@ -141,6 +142,24 @@ class Adam_Optimizer { } } + inline void simd_load(bool is_half, float *ptr, __half *h_ptr, + AVX_Data &data) { + if (is_half) { + data.data = SIMD_LOAD_HALF(h_ptr); + } else { + data.data = SIMD_LOAD(ptr); + } + } + + inline void simd_store(bool is_half, float *ptr, __half *h_ptr, + AVX_Data &data) { + if (is_half) { + SIMD_STORE_HALF(h_ptr, data.data); + } else { + SIMD_STORE(ptr, data.data); + } + } + void step(size_t step, float lr, float beta1, float beta2, float epsilon, float weight_decay, bool bias_correction, torch::Tensor ¶ms, torch::Tensor &grads, torch::Tensor &exp_avg, diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp new file mode 100644 index 000000000000..8444272940b4 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp @@ -0,0 +1,8 @@ +#include + +#include "linear.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32, + "Linear SiLU (INT8)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu new file mode 100644 index 000000000000..a30d02a4cf42 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu @@ -0,0 +1,162 @@ +// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu + +#include "linear.h" +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = float; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + +#if CUDA_ARCH >= 800 + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; +#elif CUDA_ARCH >= 750 + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + + using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, + DefaultGemmCfg::InstructionShape, + EpilogueOp>; +#elif CUDA_ARCH >= 700 + #define USE_TORCH_SILU + using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, + DefaultGemmCfg::InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>; +#else + #error "Unsupported cuda arch" +#endif + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run"); + } +#ifdef USE_TORCH_SILU +#undef USE_TORCH_SILU + out = torch::silu(out); +#endif + return out; +} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h new file mode 100644 index 000000000000..b62a27f3f8f3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h @@ -0,0 +1,12 @@ +#include +#include + +#include +#include + +torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +); diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 070ebe45f659..20da71d394bd 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -12,8 +12,8 @@ from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm from .gptq_triton import gptq_fused_linear_triton - from .rms_norm import rmsnorm_forward - from .rotary_embedding_kernel import rotary_embedding_fwd + from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd + from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd from .softmax import softmax from .token_attention_kernel import token_attention_fwd @@ -22,9 +22,10 @@ "bloom_context_attn_fwd", "softmax", "layer_norm", - "rmsnorm_forward", "copy_kv_cache_to_dest", - "rotary_embedding_fwd", "token_attention_fwd", "gptq_fused_linear_triton", + "int8_rotary_embedding_fwd", + "smooth_llama_context_attn_fwd", + "smooth_token_attention_fwd", ] diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 01d54566483a..1b4f6e44b0f2 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -238,329 +238,5 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_warps=num_warps, num_stages=1, ) - return - - @triton.jit - def _fwd_kernel_latest( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @triton.jit - def _fwd_kernel_old( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - # t_ptrs = TMP + offs_m - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - - return - - @torch.no_grad() - def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel_latest[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - elif triton.__version__ == "2.0.0": - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - # num_warps = 4 - _fwd_kernel_old[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return + + return \ No newline at end of file diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 0520bc111384..dee2566a1659 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -11,6 +11,7 @@ if HAS_TRITON: + # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @triton.jit def _fwd_copy_kv_cache_dest( kv_cache_ptr, @@ -42,6 +43,7 @@ def _fwd_copy_kv_cache_dest( tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) return + # adepted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @torch.no_grad() def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): seq_len = dest_index_ptr.shape[0] diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py new file mode 100644 index 000000000000..537dd164d1ab --- /dev/null +++ b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py @@ -0,0 +1,117 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rotary_kernel( + q, + input_scale, + output_scale, + Cos, + Sin, + q_bs_stride, + q_h_stride, + q_d_stride, + cos_bs_stride, + cos_d_stride, + total_len, + HEAD_NUM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + current_head_index = tl.program_id(0) + current_seq_index = tl.program_id(1) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + off_q0 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range0[None, None, :] * q_d_stride + ) + off_q1 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range1[None, None, :] * q_d_stride + ) + + off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride + + q0 = tl.load( + q + off_q0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + q1 = tl.load( + q + off_q1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + + cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + + q0 = q0.to(tl.float32) * input_scale + q1 = q1.to(tl.float32) * input_scale + + out0 = (q0 * cos - q1 * sin) / output_scale + out1 = (q0 * sin + q1 * cos) / output_scale + + out0 = out0.to(tl.int8) + out1 = out1.to(tl.int8) + + tl.store( + q + off_q0, + out0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + tl.store( + q + off_q1, + out1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + + return + + +@torch.no_grad() +def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + _rotary_kernel[grid]( + q, + input_scale, + output_scale, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + total_len, + HEAD_NUM=head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + HEAD_DIM=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py deleted file mode 100644 index d5d6f9d85df1..000000000000 --- a/colossalai/kernel/triton/rms_norm.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - - -if HAS_TRITON: - """ - this kernel function is modified from - https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py - """ - - @triton.jit - def _rms_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, - ): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - Y += row * stride - X += row * stride - # Compute variance - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - # Normalize and apply linear transformation - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = x * rstd - y = x_hat * w - # Write output - tl.store(Y + cols, y.to(tl.float16), mask=mask) - - def rmsnorm_forward(x, weight, eps): - # allocate output - y = torch.empty_like(x) - # reshape input data into 2D tensor - x_arg = x.view(-1, x.shape[-1]) - M, N = x_arg.shape - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - # print("BLOCK_SIZE:", BLOCK_SIZE) - if N > BLOCK_SIZE: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - # print(BLOCK_SIZE, num_warps, "block_size, numwarps") - BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2 - num_warps = 8 - # enqueue kernel - _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) - return y diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py deleted file mode 100644 index fd74ba817551..000000000000 --- a/colossalai/kernel/triton/rotary_embedding_kernel.py +++ /dev/null @@ -1,212 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - q, - Cos, - Sin, - q_bs_stride, - q_h_stride, - q_d_stride, - cos_bs_stride, - cos_d_stride, - total_len, - HEAD_NUM: tl.constexpr, - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - current_head_index = tl.program_id(0) - current_seq_index = tl.program_id(1) - - current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, HEAD_DIM // 2) - dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - - off_q0 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range0[None, None, :] * q_d_stride - ) - off_q1 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range1[None, None, :] * q_d_stride - ) - - off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride - - q0 = tl.load( - q + off_q0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - q1 = tl.load( - q + off_q1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - q + off_q0, - out0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - tl.store( - q + off_q1, - out1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - - return - - -@torch.no_grad() -def rotary_embedding_fwd(q, cos, sin): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - _rotary_kernel[grid]( - q, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - total_len, - HEAD_NUM=head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - HEAD_DIM=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -class Llama2Forwards: - @staticmethod - @triton.jit - def _rotary_kernel( - Q, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - H, # N_CTX - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = dim_range0 + 1 - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - cos_range = tl.arange(0, BLOCK_DMODEL // 2) - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) - ) - - return - - @staticmethod - @torch.no_grad() - def rotary_emb_fwd(q, cos, sin): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] // 2 - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - Llama2Forwards._rotary_kernel[grid]( - q, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py index 4b56c8afd67f..50d6786bd940 100644 --- a/colossalai/kernel/triton/self_attention_nofusion.py +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -12,6 +12,7 @@ from .qkv_matmul_kernel import qkv_gemm_4d_kernel from .softmax import softmax_kernel + # adpeted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py#L312 def self_attention_forward_without_fusion( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float ): @@ -141,6 +142,7 @@ def self_attention_forward_without_fusion( ) return output.view(batches, -1, d_model) + # modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/attention.py#L212 def self_attention_compute_using_triton( qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False ): diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py new file mode 100644 index 000000000000..ee0df6a74eaa --- /dev/null +++ b/colossalai/kernel/triton/smooth_attention.py @@ -0,0 +1,652 @@ +import math + +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + """ + this function is modified from + https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + """ + + @triton.jit + def _context_flash_attention_kernel( + Q, + K, + V, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, + alibi_ptr, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + batch_id = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # get batch info + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m + + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if alibi_ptr is not None: + alibi_m = tl.load(alibi_ptr + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if alibi_ptr is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + v = v.to(tl.float16) * v_input_scale.to(tl.float16) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + + + @torch.no_grad() + def smooth_llama_context_attn_fwd( + q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len + ): + + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + BLOCK_N = 128 + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _context_flash_attention_kernel[grid]( + q, + k, + v, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + None, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_1_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @triton.jit + def _token_attn_1_alibi_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + alibi_m = tl.load(alibi + current_head) + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @torch.no_grad() + def token_attn_fwd_1( + q, + k, + attn_out, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + alibi=None, + ): + BLOCK = 32 + # shape constraints + q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] + assert q_head_dim == k_head_dim + assert k_head_dim in {16, 32, 64, 128} + sm_scale = 1.0 / (k_head_dim**0.5) + + batch, head_num = kv_cache_loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + + num_warps = 4 if k_head_dim <= 64 else 8 + num_warps = 2 + + if alibi is not None: + _token_attn_1_alibi_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + _token_attn_1_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_softmax_fwd( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + logics_head_dim_stride, + logics_batch_stride, + prob_head_dim_stride, + prob_batch_stride, + BLOCK_SIZE: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + row = tl.load( + softmax_logics + + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store( + softmax_prob_out + + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len, + ) + return + + @torch.no_grad() + def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): + BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) + batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + _token_attn_softmax_fwd[(batch, head_num)]( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + softmax_logics.stride(0), + softmax_logics.stride(1), + softmax_prob_out.stride(0), + softmax_prob_out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + @triton.jit + def _token_attn_2_kernel( + Prob, + V, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + prob_head_dim_stride, + prob_batch_stride, + v_batch_stride, + v_head_stride, + v_head_dim_stride, + attn_out_batch_stride, + attn_out_head_stride, + attn_out_head_dim_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride + p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride + v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + for start_n in range(0, current_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load( + Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_loc = tl.load( + kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0, + ) + v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) + off_o = ( + current_batch * attn_out_batch_stride + + current_head * attn_out_head_stride + + offs_d * attn_out_head_dim_stride + ) + out_ptrs = attn_out + off_o + tl.store(out_ptrs, acc) + return + + @torch.no_grad() + def token_attn_fwd_2( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + ): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = kv_cache_loc.shape[0], v.shape[1] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + _token_attn_2_kernel[grid]( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + attn_out.stride(0), + attn_out.stride(1), + attn_out.stride(2), + HEAD_DIM=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @torch.no_grad() + def smooth_token_attention_fwd( + q, + k, + v, + attn_out, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=None, + ): + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] + + att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda") + + token_attn_fwd_1( + q.view(calcu_shape1), + k, + att_m_tensor, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi, + ) + + prob = torch.empty_like(att_m_tensor) + + token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + att_m_tensor = None + token_attn_fwd_2( + prob, + v, + attn_out.view(calcu_shape1), + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + prob = None + + return diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index c27394f0f9cf..8dc919bad125 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -12,401 +12,78 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -if HAS_TRITON: - - @triton.jit - def _token_attn_1_kernel( - Q, - K, - sm_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - @triton.jit - def _token_attn_1_alibi_kernel( - Q, - K, - sm_scale, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) +try: + from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( + token_att_fwd as lightllm_llama2_token_att_fwd, + ) + from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import ( + token_att_fwd2 as lightllm_llama2_token_att_fwd2, + ) + from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( + token_softmax_fwd as lightllm_llama2_token_softmax_fwd, + ) + + from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2 + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd + from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd + from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd + + HAS_TRITON_TOKEN_ATTENTION = True +except ImportError: + print("unable to import lightllm kernels") + HAS_TRITON_TOKEN_ATTENTION = False - for start_mark in range(0, block_mask, 1): - alibi_m = tl.load(alibi + current_head) - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return +if HAS_TRITON: @torch.no_grad() - def token_attn_fwd_1( - q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None + def token_attention_fwd( + q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None ): - BLOCK = 32 - # shape constraints - q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] - assert q_head_dim == k_head_dim - assert k_head_dim in {16, 32, 64, 128} - sm_scale = 1.0 / (k_head_dim**0.5) - - batch, head_num = kv_cache_loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] - num_warps = 4 if k_head_dim <= 64 else 8 - num_warps = 2 + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - if alibi is not None: - _token_attn_1_alibi_kernel[grid]( - q, + if alibi is None: + lightllm_llama_token_att_fwd( + q.view(calcu_shape1), k, - sm_scale, - alibi, + att_m_tensor, kv_cache_loc, kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, + kv_cache_seq_len, + max_len_in_batch, ) else: - _token_attn_1_kernel[grid]( - q, + lightllm_bloom_token_att_fwd( + q.view(calcu_shape1), k, - sm_scale, + att_m_tensor, + alibi, kv_cache_loc, kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @triton.jit - def _token_attn_softmax_fwd( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - logics_head_dim_stride, - logics_batch_stride, - prob_head_dim_stride, - prob_batch_stride, - BLOCK_SIZE: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - row = tl.load( - softmax_logics - + current_head * logics_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, - mask=col_offsets < current_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - softmax_prob_out - + current_head * prob_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, - softmax_output, - mask=col_offsets < current_batch_seq_len, - ) - return - - @torch.no_grad() - def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): - BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) - batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - _token_attn_softmax_fwd[(batch, head_num)]( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - softmax_logics.stride(0), - softmax_logics.stride(1), - softmax_prob_out.stride(0), - softmax_prob_out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - @triton.jit - def _token_attn_2_kernel( - Prob, - V, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - prob_head_dim_stride, - prob_batch_stride, - v_batch_stride, - v_head_stride, - v_head_dim_stride, - attn_out_batch_stride, - attn_out_head_stride, - attn_out_head_dim_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride - p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride - v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride - - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - for start_n in range(0, current_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_loc = tl.load( - kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * v_batch_stride, - mask=(start_n + offs_n[:, None]) < current_batch_seq_len, - other=0.0, + kv_cache_seq_len, + max_len_in_batch, ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(tl.float16) - off_o = ( - current_batch * attn_out_batch_stride - + current_head * attn_out_head_stride - + offs_d * attn_out_head_dim_stride - ) - out_ptrs = attn_out + off_o - tl.store(out_ptrs, acc) - return - - @torch.no_grad() - def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = kv_cache_loc.shape[0], v.shape[1] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - _token_attn_2_kernel[grid]( - prob, - v, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - attn_out.stride(0), - attn_out.stride(1), - attn_out.stride(2), - HEAD_DIM=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @torch.no_grad() - def token_attention_fwd( - q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None - ): - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - total_token_num = k.shape[0] - - att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - - token_attn_fwd_1( - q.view(calcu_shape1), - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=alibi, - ) prob = torch.empty_like(att_m_tensor) - token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) att_m_tensor = None - token_attn_fwd_2( + lightllm_llama_token_att_fw2( prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch ) - prob = None - return class Llama2TokenAttentionForwards: @staticmethod @triton.jit + + # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L8 def _fwd_kernel( Logics, V, @@ -478,6 +155,7 @@ def _fwd_kernel( tl.store(out_ptrs, acc) return + # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L36 @staticmethod @torch.no_grad() def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index): @@ -514,277 +192,6 @@ def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_i ) return - @staticmethod - @triton.jit - def _fwd_kernel_token_softmax( - Logics, - B_Start_Loc, - B_Seqlen, - Prob_Out, - stride_logic_h, - stride_logic_bs, - stride_prob_h, - stride_prob_bs, - BLOCK_SIZE: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - row = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, - softmax_output, - mask=col_offsets < cur_batch_seq_len, - ) - return - - @staticmethod - @torch.no_grad() - def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): - BLOCK_SIZE = triton.next_power_of_2(max_input_len) - batch, head_num = B_Start_Loc.shape[0], Logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)]( - Logics, - B_Start_Loc, - B_Seqlen, - Prob_Out, - Logics.stride(0), - Logics.stride(1), - Prob_Out.stride(0), - Prob_Out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - @staticmethod - @triton.jit - def _fwd_kernel_token_att1( - Q, - K, - sm_scale, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - Att_Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - att_stride_h, - att_stride_bs, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_n = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - cur_batch_start_index = max_input_len - cur_batch_seq_len - cur_batch_end_index = max_input_len - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load( - B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd - k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs - tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) - return - - @staticmethod - @torch.no_grad() - def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): - BLOCK = 32 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lk**0.5) - - batch, head_num = B_Loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK)) - kv_group_num = q.shape[1] // k.shape[1] - - num_warps = 4 if Lk <= 64 else 8 - num_warps = 2 - - Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid]( - q, - k, - sm_scale, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - att_out, - B_Loc.stride(0), - B_Loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - att_out.stride(0), - att_out.stride(1), - kv_group_num=kv_group_num, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @staticmethod - @triton.jit - def _fwd_kernel_token_att2( - Prob, - V, - Out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, # B_Start_Loc cumsum of input lens if continuous - stride_b_loc_b, - stride_b_loc_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = max_input_len - cur_batch_seq_len - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 - ) - v_loc = tl.load( - B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(tl.float16) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - @staticmethod - @torch.no_grad() - def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = B_Loc.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - kv_group_num = prob.shape[0] // v.shape[1] - - Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid]( - prob, - v, - out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - B_Loc.stride(0), - B_Loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - # this is the interface of llama2 attn forward @staticmethod @torch.no_grad() @@ -796,7 +203,7 @@ def token_attn( calcu_shape1 = (batch_size, head_num, head_dim) att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - Llama2TokenAttentionForwards.token_att_fwd( + lightllm_llama2_token_att_fwd( q, k, att_m_tensor, @@ -808,12 +215,12 @@ def token_attn( if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) - Llama2TokenAttentionForwards.token_softmax_fwd( + lightllm_llama2_token_softmax_fwd( att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch ) att_m_tensor = None - Llama2TokenAttentionForwards.token_att_fwd2( + lightllm_llama2_token_att_fwd2( prob, v, attn_out.view(calcu_shape1), diff --git a/colossalai/legacy/context/parallel_context.py b/colossalai/legacy/context/parallel_context.py index 48bf8ab279e8..b95405a33092 100644 --- a/colossalai/legacy/context/parallel_context.py +++ b/colossalai/legacy/context/parallel_context.py @@ -54,7 +54,7 @@ def __init__(self): # logging self._verbose = False - self._logger = get_dist_logger() + self._logger = None @property def config(self): @@ -68,6 +68,12 @@ def verbose(self): def verbose(self, verbose_: bool): self._verbose = verbose_ + @property + def logger(self): + if self._logger is None: + self._logger = get_dist_logger() + return self._logger + def load_config(self, config: Union[dict, str]): """Loads the configuration from either a dict or a file. @@ -527,7 +533,7 @@ def set_device(self, device_ordinal: int = None): torch.cuda.set_device(device_ordinal) if self._verbose: - self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}") + self.logger.info(f"process rank {global_rank} is bound to device {device_ordinal}") def set_seed(self, seed: int): """Sets seeds for all random libraries. @@ -563,19 +569,19 @@ def set_seed(self, seed: int): seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()]) if self._verbose: - self._logger.info( + self.logger.info( f"initialized seed on rank {global_rank}, " f"numpy: {seed}, python random: {seed}, {seed_str}," f"the default parallel seed is {ParallelMode.DATA}." ) else: if self._verbose: - self._logger.info( + self.logger.info( f"initialized seed on rank {global_rank}, " f"numpy: {seed}, python random: {seed}, pytorch: {seed}", ranks=[0], ) - self._logger.info( + self.logger.info( "WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states", ranks=[0], ) diff --git a/colossalai/legacy/tensor/process_group.py b/colossalai/legacy/tensor/process_group.py index ec6043163336..230849f17576 100644 --- a/colossalai/legacy/tensor/process_group.py +++ b/colossalai/legacy/tensor/process_group.py @@ -31,7 +31,7 @@ def get(self, rank_list: List[int], backend: str = "nccl"): return self.dict[processgroup_key] -PYTORCHPGDICT_ = PyTorchProcessGroupDict() +PYTORCHPGDICT_ = None class ProcessGroup: @@ -59,6 +59,9 @@ def __init__( if not torch.distributed.is_initialized(): self.is_init = False return + global PYTORCHPGDICT_ + if PYTORCHPGDICT_ is None: + PYTORCHPGDICT_ = PyTorchProcessGroupDict() assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized" diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py index ce7f126d6101..9d1d8f01dd2d 100644 --- a/colossalai/nn/lr_scheduler/delayed.py +++ b/colossalai/nn/lr_scheduler/delayed.py @@ -1,4 +1,10 @@ -from torch.optim.lr_scheduler import _LRScheduler +import torch +from packaging.version import Version + +if Version(torch.__version__) >= Version("2.0.0"): + from torch.optim.lr_scheduler import LRScheduler as _LRScheduler +else: + from torch.optim.lr_scheduler import _LRScheduler class _enable_get_lr_call: diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index f35dc0200237..c3c0180e8516 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -9,7 +9,8 @@ class CPUAdam(NVMeOptimizer): - """Implements Adam algorithm. + """ + Implements Adam algorithm. Supports parameters updating on both GPU and CPU, depending on the device of parameters. But the parameters and gradients should on the same device: @@ -77,6 +78,7 @@ def __init__( super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode cpu_adam = CPUAdamBuilder().load() + # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) def torch_adam_update( @@ -131,9 +133,6 @@ def step(self, closure=None, div_scale: float = -1): target_device = p.device if len(state) == 0: state["step"] = 0 - - # FIXME(ver217): CPU adam kernel only supports fp32 states now - assert p.dtype is torch.float, "CPUAdam only support fp32 parameters" # gradient momentums state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 32fc6136c4e6..c7a309b872ce 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -108,9 +108,6 @@ def step(self, closure=None, div_scale: float = -1): target_device = p.device if len(state) == 0: state["step"] = 0 - - # FIXME(ver217): CPU adam kernel only supports fp32 states now - assert p.dtype is torch.float, "HybridAdam only support fp32 parameters" # gradient momentums state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index c69bbe6e8521..f822c1819adc 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -160,6 +160,86 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: return object_list[0] +def _p2p_comm( + tensor_send_next: torch.Tensor, + recv_prev: bool, + peer: int, + group: ProcessGroup, + comm_dtype: torch.dtype = torch.float16, +): + """ + Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication. + + Agrs: + tensor_send_next (torch.Tensor): tensor to be sent to next stage + recv_prev (bool): whether to receive tensor from previous stage + peer (int): rank of the peer + group (ProcessGroup): process group + comm_dtype (torch.dtype): dtype of the tensor to be sent + + Returns: + torch.Tensor: tensor received from previous stage + """ + # send and recv shape + send_next_shape = None + recv_prev_shape = None + + if tensor_send_next is not None: + send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64) + if recv_prev: + recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + + ops = [] + if send_next_shape is not None: + send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group) + ops.append(send_next_op) + if recv_prev_shape is not None: + recv_prev_op = dist.P2POp( + dist.irecv, + recv_prev_shape, + peer=peer, + group=group, + ) + ops.append(recv_prev_op) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + if recv_prev_shape is not None: + recv_prev_shape = recv_prev_shape.tolist() + + # send and recv data + tensor_recv_prev = None + if recv_prev: + tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype) + + ops = [] + if tensor_send_next is not None: + send_next_op = dist.P2POp( + dist.isend, + tensor_send_next, + peer=peer, + group=group, + ) + ops.append(send_next_op) + + if tensor_recv_prev is not None: + recv_prev_op = dist.P2POp( + dist.irecv, + tensor_recv_prev, + peer=peer, + group=group, + ) + ops.append(recv_prev_op) + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + return tensor_recv_prev + + class PipelineP2PCommunication: def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager @@ -221,3 +301,21 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: prev_rank = self.stage_manager.get_prev_rank() cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) + + def p2p_communicate( + self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16 + ) -> None: + """ + Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if peer is None: + peer = self.stage_manager.get_next_rank() + cur_rank = self.stage_manager.get_rank() + recv_tensor = _p2p_comm( + output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype + ) + return recv_tensor diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py new file mode 100644 index 000000000000..1f4bbe9f8dad --- /dev/null +++ b/colossalai/pipeline/schedule/generate.py @@ -0,0 +1,358 @@ +import time +from functools import partial +from typing import Any, Iterable, Optional, Union + +import torch +import torch.cuda +from torch.nn import Module +from torch.utils._pytree import tree_map + +from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status +from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils.cuda import get_current_device + +from ._utils import get_batch_size, get_micro_batch, model_forward, to_device +from .base import PipelineSchedule + + +class ActionIntervalBuffer: + """ + The buffer to save the interval hidden states and new token for stage to use. + + """ + + def __int__(self): + self.hidden_states = None + self.new_token = None + + def clear(self): + self.hidden_states = None + self.new_token = None + + +class GenerateSchedule(PipelineSchedule): + """ + GenerateSchedule is a class that handles the pipeline parallel inference. + In our schedule, we place tie weight layer, embedding and lm_head in the same device to save space, so in + this schedule, the out for each encoding progress is on rank0. + + Args: + stage_manager (`PipelineStageManager`): Pipeline stage manager. + mb_manager (`MicroBatchManager`): Micro batch manager. + verbose (bool): Whether to verbose the information of the pipeline. + """ + + def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None: + super().__init__(stage_manager) + self.comm = PipelineP2PCommunication(stage_manager) + self.mb_manager = mb_manager + self.microbatch_size = mb_manager.micro_batch_size + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None + self.microbatch_offset: Optional[int] = None + self.num_microbatches: Optional[int] = None + self.action_interval_buffer = ActionIntervalBuffer() + self.verbose = verbose + self.timestamps = None + self.comm_dtype = None + + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: + """Load a batch from data iterator. + + Args: + data_iter (Iterable): Data iterator. + device (Optional[torch.device], optional): Target device. Defaults to None. + """ + batch = next(data_iter) + if device is not None: + batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch + self.batch_size = get_batch_size(batch) + self.microbatch_offset = 0 + assert ( + self.batch_size % self.microbatch_size == 0 + ), f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}" + self.num_microbatches = self.batch_size // self.microbatch_size + self.round = self.num_microbatches // self.stage_manager.num_stages + + def load_micro_batch(self) -> Any: + """Load a micro batch from the current batch. + + Returns: + Any: Micro batch. + """ + micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) + self.microbatch_offset += self.microbatch_size + return tree_map(partial(to_device, device=get_current_device()), micro_batch) + + def _prepare_inputs_for_interval_stage(self): + """ + Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values + + Returns: + dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` + """ + model_inputs = ( + {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None + ) + return model_inputs + + def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): + """ + Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values` + `input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end, + `past_key_values` is the past_key_values save in the micro batch manager + + Returns: + dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` + """ + new_mask = self.mb_manager.cur_descrption.attn_mask + past_key_values = self.mb_manager.cur_descrption.kv_cache + + return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values) + + def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: + last_hidden_state = hidden_state[:, -1] + input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1) + return input_ids + + def _recv_pre_stage(self) -> Any: + """ + Receive the output from previous stage + + Returns: + Any: The output from previous stage + """ + if self.stage_manager.num_stages == 2: + return self.comm.p2p_recv() + return self.comm.recv_forward() + + def _load_stage_action(self, model: Module) -> None: + """ + In this action, 1.load micro_batch 2.do the forward 3.step to update + """ + inputs_dict = self.load_micro_batch() + if self.verbose and self.stage_manager.is_first_stage(): + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) + output_dict = model_forward(model, inputs_dict, None) + + self.mb_manager.step(inputs_dict, output_dict, None) + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] + + def _gen_token_action(self, model: Module): + """ + In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update + """ + hidden_states = self.action_interval_buffer.hidden_states + assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" + hidden_states = {"hidden_states": hidden_states} + logits = model_forward(model, None, hidden_states) + if self.verbose and self.stage_manager.is_first_stage(): + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) + assert ( + "logits" in logits + ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits["logits"]) + + self.mb_manager.step(None, None, new_token) + self.action_interval_buffer.new_token = new_token + self.action_interval_buffer.hidden_states = None + + def _head_encoding_action(self, model: Module): + """ + In this action, 1.prepare inputs for encoding for first stage. 2.do the forward to get hidden states 3.step to update + """ + new_token = self.action_interval_buffer.new_token + assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None" + inputs_dict = self._prepare_inputs_for_new_token(new_token) + output_dict = model_forward(model, inputs_dict, None) + + self.mb_manager.step(inputs_dict, output_dict, None) + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] + + def _body_encoding_action(self, model: Module): + hidden_states = self.action_interval_buffer.hidden_states + assert hidden_states is not None, "When not first stage, the hidden states should not be None" + inputs_dict = self._prepare_inputs_for_interval_stage() + hidden_states = {"hidden_states": hidden_states} + output_dict = model_forward(model, inputs_dict, hidden_states) + + self.mb_manager.step(inputs_dict, output_dict, None) + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] + + def _comm_action(self, recv_pre: bool) -> torch.Tensor: + """ + In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage + """ + hidden_states = self.action_interval_buffer.hidden_states + ret = self.comm.p2p_communicate(hidden_states, recv_pre, comm_dtype=self.comm_dtype) + + self.action_interval_buffer.hidden_states = ret + + def _gen_action(self, model: Module): + """ + In p2p step method, we use `P2POp` asynchronous communication method, so the communication need to be done + at the begin of each microbatch, it's a more clear way to use an action list to do so. In this function, it will + generate a sequence action for current state, and do the action one by one. + + Args: + model (Module): Model to be run. + + Returns: + List[Callable]: A list of action, each action is a callable function, and it will be called in order. + """ + actions = [] + if self.stage_manager.is_first_stage(): + if self.mb_manager.cur_state is Status.PREFILL: + actions.append(partial(self._comm_action, False)) + actions.append(partial(self._load_stage_action, model)) + elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE: + actions.append(partial(self._comm_action, True)) + actions.append(partial(self._gen_token_action, model)) + actions.append(partial(self._head_encoding_action, model)) + elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN: + actions.append(partial(self._comm_action, True)) + actions.append(partial(self._gen_token_action, model)) + # other stage + else: + actions.append(partial(self._comm_action, True)) + actions.append(partial(self._body_encoding_action, model)) + + return actions + + def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: + if self.stage_manager.num_stages == 2: + return self.generate_step_p2p(model, data_iter) + else: + return self.generate_step_broadcast(model, data_iter) + + @torch.no_grad() + def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: + """ + Forward one step of the pipeline, when pipeline size is 2, the schedule is a circle, broadcast communication will be + blocked, so we use `P2POp` asynchronous communication method. + + Args: + model (Module): Model to be run. + data_iter (Iterable): Data iterator. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + output_sequence = [] + self.load_batch(data_iter) + model.eval() + self.comm_dtype = model.dtype + + whole_timestamp = [] + + # run by round + for _ in range(self.round): + self.timestamps = ( + [[] for _ in range(self.stage_manager.num_stages)] + if self.verbose and self.stage_manager.is_first_stage() + else None + ) + self.action_interval_buffer.clear() + while self.mb_manager.is_micro_batch_done() is False: + actions = self._gen_action(model) + for action in actions: + action() + self.mb_manager.next() + # All microbatch in current round is DONE + if self.stage_manager.is_first_stage(): + output_sequence.extend(self.mb_manager.export_new_tokens()) + else: + self._comm_action(False) + self.mb_manager.clear() + if self.verbose and self.stage_manager.is_first_stage(): + whole_timestamp.extend(self.timestamps) + + return output_sequence, whole_timestamp + + @torch.no_grad() + def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: + """ + Forward one step of the pipeline + + Args: + model (Module): Model to be run. + data_iter (Iterable): Data iterator. + + Returns: + Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). + """ + output_sequence = [] + self.load_batch(data_iter) + model.eval() + + whole_timestamp = [] + # run by round + for _ in range(self.round): + self.timestamps = ( + [[] for _ in range(self.stage_manager.num_stages)] + if self.verbose and self.stage_manager.is_first_stage() + else None + ) + while self.mb_manager.is_micro_batch_done() is False: + inputs_dict = None + new_token = None + output_dict = None + + # First stage and in PREFILL phase, just load the inputs + if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL: + inputs_dict = self.load_micro_batch() + if self.verbose and self.stage_manager.is_first_stage(): + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) + output_dict = model_forward(model, inputs_dict, None) + self.mb_manager.step(inputs_dict, output_dict, None) + # In GENERATE phase + else: + # Get hidden_states from previous stage + hidden_states = self.comm.recv_forward() + if self.stage_manager.is_first_stage(): + # First just generate a new token + assert ( + hidden_states is not None + ), "When first stage in GENERATE phase, the hidden states should not be None" + logits = model_forward(model, None, hidden_states) + if self.verbose and self.stage_manager.is_first_stage(): + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) + assert ( + "logits" in logits + ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits["logits"]) + self.mb_manager.step(None, None, new_token) + # If the current micro batch is not DONE, go through blocks + if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): + inputs_dict = self._prepare_inputs_for_new_token(new_token) + output_dict = model_forward(model, inputs_dict, None) + self.mb_manager.step(inputs_dict, output_dict, None) + else: + assert hidden_states is not None, "When not first stage, the hidden states should not be None" + inputs_dict = self._prepare_inputs_for_interval_stage() + output_dict = model_forward(model, inputs_dict, hidden_states) + self.mb_manager.step(inputs_dict, output_dict, None) + + # Current microbatch is not DONE, send hidden_state to next stage + if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in ( + Status.GENERATE, + Status.COOLDOWN, + ): + self.comm.send_forward({"hidden_states": output_dict["hidden_states"]}) + + self.mb_manager.next() + + # All microbatch in current round is DONE + if self.stage_manager.is_first_stage(): + output_sequence.extend(self.mb_manager.export_new_tokens()) + self.mb_manager.clear() + if self.verbose and self.stage_manager.is_first_stage(): + whole_timestamp.extend(self.timestamps) + + return output_sequence, whole_timestamp diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index b79867a2c651..d988015ceeda 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -12,6 +12,7 @@ class PipelineStageManager: Args: pg_mesh (ProcessGroupMesh): Process group mesh. pipeline_axis (int): The axis along which the pipeline is constructed. + is_virtual (bool): Whether to use circle p2p communication, it will make the first and last stage communicate with each other. Attributes: num_stages (int): Number of stages in the pipeline. @@ -24,6 +25,7 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bo self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + # init prev and next coord coord = self.pg_mesh.coordinate() # the prev rank of rank0 is the last rank diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 4bd7d5208a64..63b28701e879 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -77,7 +77,7 @@ Following are the description `ShardConfig`'s arguments: - `enable_sequence_parallelism`: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. -- `enable_sequence_overlap`: Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False. +- `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False. - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False. diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 2db83b912112..5a50e7379cdc 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -100,35 +100,24 @@ def pp_forward( embedding_output = self.embeddings( pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding ) + hidden_states = embedding_output else: assert ( hidden_states is not None ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" - # Go through encoder + encoder_outputs = _encoder_forward( + encoder=self.encoder, + start_idx=stage_index[0], + end_idx=stage_index[1], + hidden_states=hidden_states, + head_mask=head_mask, + return_dict=return_dict, + stage_manager=stage_manager, + ) if not stage_manager.is_last_stage(): - hidden_states = _encoder_forward( - encoder=self.encoder, - start_idx=stage_index[0], - end_idx=stage_index[1], - hidden_states=embedding_output, - head_mask=head_mask, - return_dict=return_dict, - stage_manager=stage_manager, - ) - return {"hidden_states": hidden_states} - else: - encoder_outputs = _encoder_forward( - encoder=self.encoder, - start_idx=stage_index[0], - end_idx=stage_index[1], - hidden_states=hidden_states, - head_mask=head_mask, - return_dict=return_dict, - stage_manager=stage_manager, - ) + return {"hidden_states": encoder_outputs} - # Go through rest layers sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 816bc0d7b6d7..4f2a4878e7ce 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -40,7 +40,7 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}" -def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): +def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False): assert len(list(d1.keys())) == len( list(d2.keys()) ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" @@ -58,6 +58,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool if not ignore_device: v1_i = v1_i.to("cpu") v2_i = v2_i.to("cpu") + if ignore_dtype: + v1_i = v1_i.to(v2_i.dtype) assert_close_loose(v1_i, v2_i) elif isinstance(v1_i, dict): assert isinstance(v2_i, dict) @@ -69,6 +71,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool if not ignore_device: v1 = v1.to("cpu") v2 = v2.to("cpu") + if ignore_dtype: + v1 = v1.to(v2.dtype) assert_close_loose(v1, v2) else: assert v1 == v2, f"{v1} not equals to {v2}" diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index bbef9013c20b..d3309fc5364f 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -160,6 +160,8 @@ def __init__( self.l2_norm_flag = False self.l2_norm = None + self.grad_chunk = None + @property def memory_usage(self) -> Dict[str, int]: cuda_memory = 0 @@ -414,7 +416,9 @@ def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> return self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) - def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + def copy_tensor_to_chunk_slice( + self, tensor: torch.Tensor, data_slice: torch.Tensor, update_ptr: bool = True + ) -> None: """ Copy data slice to the memory space indexed by the input tensor in the chunk. @@ -427,7 +431,23 @@ def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Ten tensor_info = self.tensors_info[tensor] self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten()) - tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) + if update_ptr: + tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) + + def add_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Add data slice to the memory space indexed by the input tensor in the chunk. + Only used when accumulating gradient chunks. + + Args: + tensor (torch.Tensor): the tensor used to retrieve meta information + data_slice (torch.Tensor): the tensor to be added to the chunk + """ + # sanity check + assert self.is_gathered + + tensor_info = self.tensors_info[tensor] + self.cuda_global_chunk[tensor_info.offset : tensor_info.end].add_(data_slice.data.flatten()) def get_valid_length(self) -> int: """Get the valid length of the chunk's payload.""" @@ -577,3 +597,46 @@ def print_tensor(tensor, prefix=""): output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st])) return "".join(output) + + def init_grad_chunk(self) -> "Chunk": + """Init grad chunk. This should be called in grad handler. + + Returns: + Chunk: Grad chunk + """ + if self.grad_chunk is None: + # grad chunk is not initialized + grad_chunk = Chunk( + chunk_size=self.chunk_size, + process_group=self.torch_pg, + dtype=self.dtype, + keep_gathered=self.keep_gathered, + pin_memory=self.pin_memory, + ) + grad_chunk.num_tensors = self.num_tensors + grad_chunk.utilized_size = self.utilized_size + grad_chunk.tensor_state_cnter[TensorState.HOLD] = self.num_tensors + for tensor, state in self.tensors_info.items(): + grad_chunk.tensors_info[tensor] = TensorInfo(TensorState.HOLD, state.offset, state.end) + + grad_chunk.valid_end = self.valid_end + + if grad_chunk.chunk_temp.device.type == "cpu": + grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device()) + else: + grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp + grad_chunk.chunk_temp = None + + if grad_chunk.pin_memory: + grad_chunk.cpu_shard = torch.empty( + grad_chunk.shard_size, dtype=grad_chunk.dtype, pin_memory=grad_chunk.pin_memory + ) + + self.grad_chunk = grad_chunk + else: + # grad chunk is initialized, just reallocate cuda global chunk + self.grad_chunk.cuda_shard = None + self.grad_chunk.is_gathered = True + alloc_storage(self.grad_chunk.cuda_global_chunk) + + return self.grad_chunk diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 957e41b02d49..d3c512fe978d 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -5,7 +5,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import get_current_device +from colossalai.utils import free_storage, get_current_device from .chunk import Chunk, ChunkFullError, TensorState @@ -245,3 +245,47 @@ def __sub_accessed_chunk(self, chunk: Chunk): chunk.release_chunk() self.accessed_chunks.remove(chunk) self.accessed_mem -= chunk.chunk_mem + + def init_grad_chunk(self, chunk: Chunk) -> Chunk: + if chunk.grad_chunk is not None: + self.__sub_memory_usage(chunk.grad_chunk.memory_usage) + grad_chunk = chunk.init_grad_chunk() + self.__add_memory_usage(grad_chunk.memory_usage) + if grad_chunk not in self.accessed_chunks: + self.accessed_chunks.add(grad_chunk) + self.accessed_mem += grad_chunk.chunk_mem + return grad_chunk + + def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk: + """Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction.""" + + assert chunk.grad_chunk is not None + + # Make a backup for gradient accumulated before. + # Here backup gradients should be multiplied, since it will be divided after gradient reduction. + if chunk.grad_chunk.is_gathered: + accumulated_grad = chunk.grad_chunk.cuda_global_chunk.clone().detach().mul_(chunk.pg_size) + accumulated_grad_gathered = True + else: + if chunk.grad_chunk.cuda_shard is not None: + accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size) + else: + accumulated_grad = ( + chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size) + ) + accumulated_grad_gathered = False + + # Reset grad_chunk, and chunk.grad_chunk will be accessed. + grad_chunk = self.init_grad_chunk(chunk) + grad_chunk.cuda_global_chunk.zero_() + + # Add backup gradients to grad_chunk. + if accumulated_grad_gathered: + grad_chunk.cuda_global_chunk.add_(accumulated_grad) + else: + grad_chunk.cuda_global_chunk[grad_chunk.shard_begin : grad_chunk.shard_end].add_(accumulated_grad) + + # Release accumulated_grad + free_storage(accumulated_grad) + + return grad_chunk diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 0ba9e53cfcd6..df7e1163c3d9 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -59,6 +59,7 @@ def __init__( chunk_config_dict: Optional[dict] = None, chunk_init_device: torch.device = torch.device("cpu"), placement_policy: str = "static", + enable_gradient_accumulation: bool = False, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -74,6 +75,7 @@ def __init__( mixed_precision: torch.dtype = torch.float16, process_group: Optional[ProcessGroup] = None, memstats: Optional[MemStats] = None, # genimi memory stats + master_weights: bool = True, verbose: bool = False, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) @@ -115,6 +117,14 @@ def __init__( self.mixed_precision = mixed_precision self.dp_process_group = process_group or _get_default_group() + self.reuse_fp16_chunk = master_weights + self.master_weights = master_weights + + self.enable_gradient_accumulation = enable_gradient_accumulation + if self.enable_gradient_accumulation: + self.reuse_fp16_chunk = False + self.accumulating_grads = False # Whether model is accumulating gradients + self._logger = get_dist_logger() if self.gemini_manager._premade_memstats_: @@ -294,6 +304,8 @@ def _post_backward(self): f"{error_str}", ) self._setup_grads_ptr() + if self.enable_gradient_accumulation and not self.accumulating_grads: + self.accumulating_grads = True # Turn on the state of gradient accumulation. self._logger.debug( f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}" ) @@ -321,20 +333,48 @@ def grad_handle(self, p, grad): f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " "Some unsupported torch function is operated upon this parameter." ) - self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) - chunk.copy_tensor_to_chunk_slice(p, grad) - reduced = self.chunk_manager.reduce_chunk(chunk) + grad_chunk = chunk + if not self.reuse_fp16_chunk: + if not self.accumulating_grads: + grad_chunk = self.chunk_manager.init_grad_chunk(chunk) + else: + assert chunk.grad_chunk is not None + if chunk.grad_chunk not in self.chunk_manager.accessed_chunks: + grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk) + else: + grad_chunk = chunk.grad_chunk + + # hold -> compute -> hold after bwd + grad_chunk.tensor_trans_state(p, TensorState.COMPUTE) + grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD) + # fp16 param chunk: hold after bwd -> ready for reduce -> hold + chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE) + chunk.tensor_trans_state(p, TensorState.HOLD) + + grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE) + if not self.accumulating_grads: + grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk) + else: + grad_chunk.add_tensor_to_chunk_slice(p, grad) + reduced = self.chunk_manager.reduce_chunk(grad_chunk) if reduced: - if chunk.is_gathered: - chunk.cuda_global_chunk.div_(chunk.pg_size) + if not self.reuse_fp16_chunk: + if chunk.keep_gathered: + self.chunk_manager.fake_release_chunk(chunk) + else: + self.chunk_manager.release_chunk(chunk) + if grad_chunk.is_gathered: + grad_chunk.cuda_global_chunk.div_(chunk.pg_size) else: - chunk.cuda_shard.div_(chunk.pg_size) + grad_chunk.cuda_shard.div_(chunk.pg_size) # check overflow elements - self.overflow_counter += chunk.has_inf_or_nan - # record l2 norm for gradient clipping + self.overflow_counter += grad_chunk.has_inf_or_nan + # record l2 norm for gradient clipping. flag is bound to fp16 chunk if chunk.l2_norm_flag: - chunk.set_l2_norm() - self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) + grad_chunk.set_l2_norm() + self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True) + if not (self.master_weights) or (self.enable_gradient_accumulation): + self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) return empty_grad def zero_grad(self, set_to_none: bool = False) -> None: @@ -344,9 +384,7 @@ def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict( - self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16 - ): + def state_dict(self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. @@ -365,7 +403,7 @@ def state_dict( destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) - self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype) + self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0) for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) @@ -373,7 +411,7 @@ def state_dict( destination = hook_result return destination - def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict: + def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: """ get gathered chunk content. @@ -386,9 +424,8 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch. """ # save parameters chunk_to_save_data = dict() - temp_chunk = get_temp_total_chunk_on_cuda(chunk) - if torch.is_floating_point(temp_chunk): - temp_chunk = temp_chunk.to(dtype) + temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) + for tensor, tensor_info in chunk.tensors_info.items(): record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) @@ -401,9 +438,7 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch. del temp_chunk return chunk_to_save_data - def _get_param_to_save_data( - self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype - ) -> Dict: + def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict: """ get param content from chunks. @@ -418,10 +453,10 @@ def _get_param_to_save_data( param_to_save_data = dict() chunk_list = self.chunk_manager.get_chunks(param_list) for chunk in chunk_list: - param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) + param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0)) return param_to_save_data - def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16): + def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): r"""Saves module state to `destination` dictionary, containing a state of the module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.state_dict`. @@ -438,14 +473,18 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, # get copies of fp32 parameters in CPU # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16 - param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype) + params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params + param_to_save_data = self._get_param_to_save_data(params, only_rank_0) # get the mapping between copies and fp16 parameters p_mapping = dict() - for p, fp32_p in zip(self.fp16_params, self.fp32_params): - name = self.param2name[p] - assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) - record_parameter = param_to_save_data[fp32_p] - p_mapping[p] = record_parameter + if self.reuse_fp16_chunk: + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + name = self.param2name[p] + assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) + record_parameter = param_to_save_data[fp32_p] + p_mapping[p] = record_parameter + else: + p_mapping = param_to_save_data for name, param in self.name2param.items(): if param is not None: if is_ddp_ignored(param): @@ -593,7 +632,7 @@ def load(param_name, dest_tensor, copy_func): elif strict: missing_keys.append(state_key) - def load_fp32_parameter(chunk_slice, data): + def load_parameter(chunk_slice, data): chunk_slice.copy_(data.flatten()) for name, param in self.named_parameters(): @@ -607,14 +646,15 @@ def load_fp32_parameter(chunk_slice, data): name = self.param2name[p] fp32_to_name[fp32_p] = name - chunk_list = self.chunk_manager.get_chunks(self.fp32_params) + params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params + chunk_list = self.chunk_manager.get_chunks(params_to_load) for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) + temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) for tensor, tensor_info in chunk.tensors_info.items(): - parameter_name = fp32_to_name[tensor] + parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor] parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] - load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) + load(parameter_name, tensor, partial(load_parameter, parameter_slice)) if chunk.is_gathered: chunk.cuda_global_chunk.copy_(temp_chunk) @@ -624,11 +664,11 @@ def load_fp32_parameter(chunk_slice, data): chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) del temp_chunk - - for chunk_32 in chunk_list: - chunk_16 = chunk_32.paired_chunk - assert chunk_16 is not None - chunk_16.payload.copy_(chunk_32.payload) + if self.reuse_fp16_chunk: + for chunk_32 in chunk_list: + chunk_16 = chunk_32.paired_chunk + assert chunk_16 is not None + chunk_16.payload.copy_(chunk_32.payload) for name, buf in persistent_buffers.items(): if buf is not None: @@ -668,12 +708,9 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision) continue - # create a fp32 parameter - fp32_p = p.data.float() # create a fp16 parameter p.data = p.data.to(self.mixed_precision) - - # register the fp16 parameter and fp32 parameter in the chunk manager + # register the fp16 parameter self.chunk_manager.register_tensor( tensor=p, group_type="fp16_param", @@ -682,22 +719,27 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi cpu_offload=cpu_offload, pin_memory=pin_memory, ) - self.chunk_manager.register_tensor( - tensor=fp32_p, - group_type="fp32_param", - config_key=dp_world_size, - process_group=self.dp_process_group, - cpu_offload=cpu_offload, - pin_memory=pin_memory, - ) - self.fp16_params.append(p) - self.fp32_params.append(fp32_p) + + if self.master_weights: + # create a fp32 parameter + fp32_p = p.data.float() + self.chunk_manager.register_tensor( + tensor=fp32_p, + group_type="fp32_param", + config_key=dp_world_size, + process_group=self.dp_process_group, + cpu_offload=cpu_offload, + pin_memory=pin_memory, + ) + self.fp32_params.append(fp32_p) self.chunk_manager.close_all_groups() self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device) + # move master weights to corresponding device and setup paired chunks + # if no master weights, fp32_params should be empty and this loop will be skipped for p, fp32_p in zip(self.fp16_params, self.fp32_params): chunk_16 = self.chunk_manager.get_chunk(p) chunk_32 = self.chunk_manager.get_chunk(fp32_p) @@ -734,7 +776,6 @@ def state_dict_shard( keep_vars: bool = False, max_shard_size: int = 1024, only_rank_0: bool = True, - dtype: torch.dtype = torch.float16, ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. @@ -769,11 +810,11 @@ def state_dict_shard( gathered_param = param if keep_vars else param.detach() else: # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16 - fp32_param = fp16_to_fp32[param] - if fp32_param not in gathered_param_buffer: - chunk = self.chunk_manager.get_chunk(fp32_param) - gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) - gathered_param = gathered_param_buffer.pop(fp32_param) + param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param + if param_to_save not in gathered_param_buffer: + chunk = self.chunk_manager.get_chunk(param_to_save) + gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) + gathered_param = gathered_param_buffer.pop(param_to_save) block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 1aece99541b9..0d0298e067f3 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -105,7 +105,7 @@ def __init__( self.gemini_manager = module.gemini_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() - self.param_to_chunk32: Dict[Parameter, Chunk] = dict() + self.param_to_chunk16: Dict[Parameter, Chunk] = dict() self.chunk16_set: Set[Chunk] = set() self.clipping_flag = max_norm > 0.0 self.max_norm = max_norm @@ -130,7 +130,7 @@ def __init__( else: ddp_param_list.append(param) - for p, fp32_p in zip(ddp_param_list, module.fp32_params): + for p in ddp_param_list: chunk_16 = self.chunk_manager.get_chunk(p) if chunk_16 not in self.chunk16_set: chunk_16.l2_norm_flag = self.clipping_flag @@ -174,13 +174,15 @@ def __init__( def _set_grad_ptr(self): for group in self.param_groups: for fake_param in group["params"]: - chunk32 = self.param_to_chunk32[fake_param] + chunk16 = self.param_to_chunk16[fake_param] begin, end = self.param_to_range[fake_param] - chunk16 = chunk32.paired_chunk - fake_param.data = chunk16.payload[begin:end] + grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk + fake_param.data = grad_chunk16.payload[begin:end] fake_param.grad = fake_param.data - fake_param.data = chunk32.payload[begin:end] + + to_update_chunk = chunk16.paired_chunk if self.module.master_weights else chunk16 + fake_param.data = to_update_chunk.payload[begin:end] def _update_fp16_params(self): none_tensor = torch.empty([0]) @@ -194,23 +196,25 @@ def _update_fp16_params(self): def _clear_global_norm(self) -> None: for c16 in self.chunk16_set: - c16.l2_norm = None + grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk + grad_chunk.l2_norm = None def _calc_global_norm(self) -> float: norm_sqr: float = 0.0 group_to_norm = dict() for c16 in self.chunk16_set: - assert c16.l2_norm is not None + grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk + assert grad_chunk.l2_norm is not None - if c16.is_gathered: - norm_sqr += c16.l2_norm + if grad_chunk.is_gathered: + norm_sqr += grad_chunk.l2_norm else: # this chunk is sharded, use communication to collect total norm - if c16.torch_pg not in group_to_norm: - group_to_norm[c16.torch_pg] = 0.0 - group_to_norm[c16.torch_pg] += c16.l2_norm + if grad_chunk.torch_pg not in group_to_norm: + group_to_norm[grad_chunk.torch_pg] = 0.0 + group_to_norm[grad_chunk.torch_pg] += grad_chunk.l2_norm - c16.l2_norm = None # clear l2 norm + grad_chunk.l2_norm = None # clear l2 norm comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) for group, part_norm in group_to_norm.items(): @@ -237,7 +241,8 @@ def zero_grad(self, *args, **kwargs): return self.optim.zero_grad(set_to_none=True) def step(self, *args, **kwargs): - self._maybe_move_fp32_params() + if self.module.master_weights: + self._maybe_move_fp32_params() self._set_grad_ptr() if self.mix_precision_mixin.should_skip_step(): @@ -245,7 +250,8 @@ def step(self, *args, **kwargs): self._logger.info(f"Found overflow. Skip step") self._clear_global_norm() # clear recorded norm self.zero_grad() # reset all gradients - self._update_fp16_params() + if self.module.reuse_fp16_chunk: + self._update_fp16_params() return # get combined scale. combined scale = loss scale * clipping norm @@ -255,7 +261,9 @@ def step(self, *args, **kwargs): ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) self._register_states() self.zero_grad() - self._update_fp16_params() + if self.module.master_weights: + self._update_fp16_params() + self.module.accumulating_grads = False return ret def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): @@ -282,8 +290,8 @@ def _maybe_move_fp32_params(self): for group in self.param_groups: for fake_param in group["params"]: - chunk32 = self.param_to_chunk32[fake_param] - chunk16 = chunk32.paired_chunk + chunk16 = self.param_to_chunk16[fake_param] + chunk32 = chunk16.paired_chunk if chunk32.device_type == "cuda": continue @@ -297,7 +305,8 @@ def _maybe_move_fp32_params(self): for group in self.param_groups: for fake_param in group["params"]: - chunk32 = self.param_to_chunk32[fake_param] + chunk16 = self.param_to_chunk16[fake_param] + chunk32 = chunk16.paired_chunk if chunk32.device_type == "cuda": state = self.optim.state[fake_param] for k, v in state.items(): @@ -341,7 +350,7 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter): continue grad_device = self.module.grads_device[param] fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device)) - self.param_to_chunk32[fake_param] = chunk16.paired_chunk + self.param_to_chunk16[fake_param] = chunk16 self.param_to_range[fake_param] = range_pair self.id_to_fake_params[param_id] = fake_param fake_params_list.append(fake_param) @@ -366,7 +375,7 @@ def get_offsets(self, param_id: int) -> tuple: if param_id not in self.id_to_fake_params: return -1, -1, -1 fake_param = self.id_to_fake_params[param_id] - chunk = self.param_to_chunk32[fake_param].paired_chunk + chunk = self.param_to_chunk16[fake_param] param = self.id_to_real_params[param_id] param_info = chunk.tensors_info[param] diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index 264099d22de2..5305953fe1ee 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -11,7 +11,7 @@ from .chunk import Chunk -def get_temp_total_chunk_on_cuda(chunk: Chunk): +def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype): if chunk.is_gathered: return chunk.cuda_global_chunk @@ -20,7 +20,9 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk): else: shard_temp = chunk.cpu_shard.to(get_current_device()) - total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device()) + shard_temp = shard_temp.to(dtype) + + total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device()) gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index 0a15f8ddd718..de08ecf3d57f 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -3,9 +3,7 @@ import torch import torch.distributed as dist -from torch import Tensor, inf from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from torch.distributed import ProcessGroup def flatten(input_): @@ -192,53 +190,6 @@ def calculate_global_norm_from_list(norm_list): total_norm += norm**2.0 return math.sqrt(total_norm) - -def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int: - """Clips gradient norm of an iterable of parameters. - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. - - Args: - gradients (Tensor): The gradients to compute norm - dp_group (ProcessGroup): The process group of ZeRO Data Parallelism - tp_group (ProcessGroup): The process group of Tensor Parallelism - norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. - - Returns: - int: The total norm of given gradients - """ - - norm_type = float(norm_type) - if norm_type == inf: - total_norm = max(g.data.abs().max() for g in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group) - - # Take max across all GPUs. - if tp_group is not None: - dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() - else: - total_norm = 0.0 - for g in gradients: - param_norm = g.data.double().norm(norm_type) - total_norm += param_norm.item() ** norm_type - - # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group) - - if tp_group is not None: - dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group) - - total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) - - if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm: - total_norm = -1 - - return total_norm - - def sync_tensor(flat_tensor, tensor_list): """ Synchronize the flattened tensor and unflattened tensor list. When diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 3ce688cfa930..1164532fa3a3 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -21,6 +21,8 @@ def __init__(self, *args, partition_grad: bool = False): # for zero2, it's `param_id: [grad_local_rank]` self._working_index = 0 if partition_grad else self._local_rank + self.grad_to_param_mapping = dict() + def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: """Return list of gradient slices of a specific parameter @@ -54,6 +56,8 @@ def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: in else: self._grads_of_params[group_id][param_id].append(grad) + self.grad_to_param_mapping[id(grad)] = param_id + def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int): """Add a gradient slice on an existing slice of the parameter's gradient Used when no_sync is not activated. @@ -83,8 +87,37 @@ def get_working_grads_by_group_id(self, group_id: int) -> List: return grad_list + def get_working_grad_by_param_id(self, param_id) -> Tensor: + """ + Return the working gradient for the specified parameter. + + Args: + param_id (int): The index of the parameter. + + Returns: + Tensor: The the working gradient slices for the specified param_id. + """ + + for group in self._grads_of_params.values(): + if param_id in group.keys(): + return group[param_id][self._working_index] + + raise KeyError(f"Working gradient for param_id {param_id} not found.") + def reset_grads_by_group_id(self, group_id: int): self._grads_of_params[group_id] = dict() def reset_all_gradients(self): self._grads_of_params = dict() + + def get_param_id_for_grad(self, grad: Tensor) -> int: + """Return the id of a parameter which the gradient slice belongs to + + Args: + grad (Tensor): the gradient slice + + Returns: + int: the id of a parameter which the gradient slice belongs to + """ + + return self.grad_to_param_mapping[id(grad)] diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 72df93ace302..e6974a6760ce 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -2,11 +2,12 @@ import copy from contextlib import contextmanager from functools import partial -from typing import Dict, Iterator, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple import torch import torch.distributed as dist import torch.nn as nn +from torch import Tensor, inf from torch.distributed import ProcessGroup from torch.optim import Optimizer @@ -21,14 +22,7 @@ # from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device -from ._utils import ( - calculate_global_norm_from_list, - compute_norm, - flatten, - has_inf_or_nan, - release_param_grad, - sync_tensor, -) +from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor from .bookkeeping import BucketStore, GradientStore, ParameterStore @@ -80,8 +74,8 @@ def __init__( partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None, + master_weights: bool = True, # master weights ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]["params"][0].dtype @@ -101,8 +95,6 @@ def __init__( self._local_rank = dist.get_rank(group=self.dp_pg) self._world_size = dist.get_world_size(group=self.dp_pg) - self.tp_pg = tp_process_group - # working and master params for mixed precision training self._working_param_groups = dict() self._master_param_groups_of_current_rank = dict() @@ -115,6 +107,9 @@ def __init__( # gradient clipping self._clip_grad_norm = clip_grad_norm + # master weights copy + self._master_weights = master_weights + if forced_dtype: for group in self.optim.param_groups: group_params = group["params"] @@ -144,7 +139,6 @@ def __init__( self._working_param_groups[group_id] = group_params master_param_current_rank = self._create_master_param_current_rank(group_params) - self._master_param_groups_of_current_rank[group_id] = master_param_current_rank # need to replace the params in the `params` field in the optimizer @@ -209,11 +203,18 @@ def _create_master_param_current_rank(self, param_list): with torch.no_grad(): if padding_size > 0: padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + # reset working params' ptr when no master weights + if self._master_weights == False: + param.data = padding_param[: param.numel()].view(param.shape) else: padding_param = param.data.view(-1) splited_params = padding_param.split(padding_param.numel() // self._world_size) - splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device) + # use fp32 when master_weights is True + if self._master_weights is True: + splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device) + else: + splited_param_current_rank = splited_params[self._local_rank] params_current_rank.append(splited_param_current_rank) self._param_store.link_master_and_working_param(splited_param_current_rank, param) @@ -411,9 +412,7 @@ def step(self, closure=None): # and should not be updated real_working_params = dict() real_master_params = dict() - grad_index = 0 if self._partition_grads else self._local_rank - for group_id in range(self.num_param_groups): master_params = self._master_param_groups_of_current_rank[group_id] real_working_params[group_id] = [] @@ -426,14 +425,19 @@ def step(self, closure=None): grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) if len(grads) > 0: real_working_params[group_id].append(working_param) - grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device) + # no need to copy fp32 grad if master_weights is False + grad = ( + grads[grad_index].to(splited_param.dtype).to(splited_param.device) + if self._master_weights + else grads[grad_index] + ) splited_param.grad = grad grad_partition_groups.append(grad) real_master_params[group_id].append(splited_param) # compute norm working_grads = self._grad_store.get_working_grads_by_group_id(group_id) - norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg) + norm_group = self._compute_grad_norm(gradients=working_grads) norm_groups.append(norm_group) self._grad_store.reset_grads_by_group_id(group_id) @@ -454,19 +458,56 @@ def step(self, closure=None): release_param_grad(self._master_param_groups_of_current_rank[group_id]) # update working partition updated by the current rank - dtype = real_working_params[0][0].dtype + # dtype = real_working_params[0][0].dtype for group_id in range(self.num_param_groups): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] all_splited_param = [ - torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size) + torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size) ] - dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg) + dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) - self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] + def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + gradients (List[Tensor]): The gradients to compute norm + norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. + + Returns: + float: The total norm of given gradients + """ + + if len(gradients) == 0: + return 0.0 + + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) + total_norm = total_norm_cuda.item() + + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + total_norm_exponentiated += grad_norm_exponentiated + + # Sum across all model parallel GPUs. + total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) + torch.distributed.all_reduce( + total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg + ) + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + ############################# # Mixed Precision Utilities # ############################# diff --git a/docs/sidebars.json b/docs/sidebars.json index 45e86afc1f61..123211db5897 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -64,7 +64,6 @@ "label": "Advanced Tutorials", "collapsed": true, "items": [ - "advanced_tutorials/train_vit_using_pipeline_parallelism", "advanced_tutorials/train_vit_with_hybrid_parallelism", "advanced_tutorials/train_gpt_using_hybrid_parallelism", "advanced_tutorials/meet_gemini", diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 0218264cc258..7a0e3b1a0276 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -1,10 +1,14 @@ -# Train GPT Using Hybrid Parallelism +# Fine-tune GPT-2 Using Hybrid Parallelism -Author: Hongxin Liu, Yongbin Li +Author: Hongxin Liu, Yongbin Li, Mingyan Jiang + +**Prerequisite:** +- [parallelism plugin](../basics/booster_plugins.md) +- [booster API](../basics/booster_api.md) **Example Code** -- [ColossalAI-Examples GPT2](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_2) -- [ColossalAI-Examples GPT3](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_3) +- [ColossalAI-Examples GPT](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/gpt/hybridparallelism/finetune.py) + **Related Paper** - [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) @@ -12,260 +16,192 @@ Author: Hongxin Liu, Yongbin Li ## Introduction -In the previous tutorial, we introduce how to train ViT with pipeline. In this tutorial, you will learn a more complex scenario -- train GPT with hybrid parallelism. In this case, GPT-3 is so large that CPU memory cannot fit it as well. Therefore, you must split the model by yourself. +In the previous tutorial, we introduce how to train ViT with pipeline. In this tutorial, you will learn a more complex scenario -- fine-tune GPT-2 with hybrid parallelism. In this case, GPT-2 is so large that CPU memory cannot fit it as well. Therefore, you must split the model. ## Table of content In this tutorial we will cover: -1. The definition of GPT model, based on colossalai/model_zoo -2. Processing the dataset -3. Training GPT using hybrid parallelism +1. Initialize the hybrid parallelism plugin. +2. Defining the Training Components of the GPT-2 Model +3. Boost the GPT-2 Model with [`HybridParallelPlugin`](../basics/booster_plugins.md) +4. Training GPT-2 using hybrid parallelism ## Import libraries ```python -import json -import os -from typing import Callable - -import colossalai -import colossalai.utils as utils -import model_zoo.gpt.gpt as col_gpt +from typing import Callable, List, Union import torch +import torch.distributed as dist import torch.nn as nn -from colossalai import nn as col_nn -from colossalai.amp import AMP_TYPE -from colossalai.legacy.builder.pipeline import partition_uniform -from colossalai.legacy.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, - PipelineSchedule) -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper -from colossalai.legacy.trainer import Trainer, hooks -from colossalai.utils.timer import MultiTimer -from model_zoo.gpt import GPTLMLoss -from torch.nn import functional as F -from torch.utils.data import Dataset -from transformers import GPT2Tokenizer -``` - - - -## Define GPT model - -In the previous tutorial, we introduced 3 ways to build a pipelined model. But for huge models like GPT-3, you can't even build the model in CPU. In this case, you must split the model by yourself. - -GPT dataloader returns `input_ids` and `attention_mask`, so we use two keyword arguments in `forward()` to get them. Note that for stages except the first stage, the first positional argument of `forward()` is the output tensor from the previous stage. So the `hidden_states` is from the previous stage, and for the first stage it's `None`. - -For GPT, the *word embedding layer* shares the weights with the *output head*. We provide `PipelineSharedModuleWrapper` to share parameters among pipeline stages. It takes a `list` of `int` as argument, which means those ranks share the parameters. You can use `register_module()` or `register_parameter()` to register a module or a parameter as the shared module or parameter. If you have multiple sets of shared modules / parameters, you should have multiple `PipelineSharedModuleWrapper` instance. If the parameter is shared within **one** stage, you should not use `PipelineSharedModuleWrapper`, and just use the same module / parameter instance. In this example, the *word embedding layer* is at the first stage, and the *output head* is at the last stage. Thus, they are shared among ranks `[0, pipeline_size - 1]`. - -For the first stage, it maintains the embedding layer and some transformer blocks. For the last stage, it maintains some transformer blocks and the output head layer. For other stages, they just maintain some transformer blocks. `partition_uniform(num_layers, pipeline_size, num_chunks)` returns the parts of all ranks, and the part is a `tuple` of `(start, end)` (exclude end). `start == 0` means that it's the first stage, and `end == num_layers` means it's the last stage. +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from tqdm import tqdm +from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup +from transformers import AutoTokenizer +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device +``` +## Define Plugin +Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously. ```python -class PipelineGPTHybrid(nn.Module): - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: Callable = F.gelu, - mlp_ratio: int = 4, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - first: bool = False, - last: bool = False): - super().__init__() - self.embedding = None - self.norm = None - self.head = None - if first: - self.embedding = col_gpt.GPTEmbedding( - hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype) - self.blocks = nn.ModuleList([ - col_gpt.GPTBlock(hidden_size, num_attention_heads, mlp_ratio=mlp_ratio, attention_dropout=attn_drop_rate, - dropout=drop_rate, dtype=dtype, checkpoint=checkpoint, activation=act_func) - for _ in range(num_layers) - ]) - if last: - self.norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.head = col_gpt.GPTLMHead(vocab_size=vocab_size, - dim=hidden_size, - dtype=dtype, - bias=False) - - def forward(self, hidden_states=None, input_ids=None, attention_mask=None): - if self.embedding is not None: - hidden_states = self.embedding(input_ids=input_ids) - batch_size = hidden_states.shape[0] - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * -10000.0 - for block in self.blocks: - hidden_states, attention_mask = block(hidden_states, attention_mask) - if self.norm is not None: - hidden_states = self.head(self.norm(hidden_states)) - return hidden_states - - -def build_gpt_pipeline(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): - logger = get_dist_logger() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - rank = gpc.get_global_rank() - wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) - parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] - models = [] - for start, end in parts: - kwargs['num_layers'] = end - start - kwargs['first'] = start == 0 - kwargs['last'] = end == num_layers - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') - chunk = PipelineGPTHybrid(**kwargs).to(device) - if start == 0: - wrapper.register_module(chunk.embedding.word_embeddings) - elif end == num_layers: - wrapper.register_module(chunk.head) - models.append(chunk) - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - return model - - -def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float): - cfg = dict(hidden_size=1600, num_attention_heads=32, checkpoint=checkpoint, dtype=dtype) - return build_gpt_pipeline(48, num_chunks, **cfg) - - -def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float): - cfg = dict(hidden_size=12288, num_attention_heads=96, - checkpoint=checkpoint, max_position_embeddings=2048, dtype=dtype) - return build_gpt_pipeline(96, num_chunks, **cfg) +plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, +) ``` +## Define GPT-2's Training Components -## Process the dataset - -We provide a small GPT web-text dataset here. The original format is loose JSON, and we will save the processed dataset. +Before using hybrid parallelism, you need to define the components used for training. +Define hyperparameters ```python -class WebtextDataset(Dataset): - def __init__(self, path, seq_len=1024) -> None: - super().__init__() - root = os.path.dirname(path) - encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') - if os.path.isfile(encoded_data_cache_path): - seq_len_, data, attention_mask = torch.load( - encoded_data_cache_path) - if seq_len_ == seq_len: - self.data = data - self.attention_mask = attention_mask - return - raw_data = [] - with open(path) as f: - for line in f.readlines(): - raw_data.append(json.loads(line)['text']) - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - tokenizer.pad_token = tokenizer.unk_token - encoded_data = tokenizer( - raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') - self.data = encoded_data['input_ids'] - self.attention_mask = encoded_data['attention_mask'] - torch.save((seq_len, self.data, self.attention_mask), - encoded_data_cache_path) - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - return { - 'input_ids': self.data[index], - 'attention_mask': self.attention_mask[index] - }, self.data[index] +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 ``` - -## Training GPT using hybrid parallelism - -In the previous tutorial, we explained the meanings of some pipeline arguments. In this case, we can determine the shape of each output tensor which is exchanged among pipeline stages. For GPT, the shape is `(MICRO BATCH SIZE, SEQUENCE LEN, HIDDEN SIZE)`. By setting this, we can avoid exchanging the tensor shape of each stage. When you are not sure of the tensor shape, you can just leave it `None`, and the shape is inferred automatically. Make sure that the `dtype` of your model is correct. When you use `fp16`, the `dtype` of your model must be `torch.half`. Otherwise, the `dtype` must be `torch.float`. For pipeline parallelism, only `AMP_TYPE.NAIVE` is supported. - -You can easily use tensor parallel by setting `parallel` in `CONFIG`. The data parallelism size is automatically set based on the number of GPUs. - +we create a distributed environment. ```python -NUM_EPOCHS = 60 -SEQ_LEN = 1024 -BATCH_SIZE = 192 -NUM_CHUNKS = None -TENSOR_SHAPE = (1, 1024, 1600) -# only pipeline parallel -# CONFIG = dict(parallel=dict(pipeline=2), fp16=dict(mode=AMP_TYPE.NAIVE)) -# pipeline + 1D model parallel -CONFIG = dict(NUM_MICRO_BATCHES = 192, parallel=dict(pipeline=2, tensor=dict(mode='1d', size=2)), fp16=dict(mode=AMP_TYPE.NAIVE)) - - -def train(): - disable_existing_loggers() - parser = colossalai.get_default_parser() - args = parser.parse_args() - colossalai.launch_from_torch(config=CONFIG, backend=args.backend) - logger = get_dist_logger() - - train_ds = WebtextDataset(os.environ['DATA'], seq_len=SEQ_LEN) - train_dataloader = utils.get_dataloader(train_ds, - seed=42, - batch_size=BATCH_SIZE, - pin_memory=True, - shuffle=True, - drop_last=True) - - use_interleaved = NUM_CHUNKS is not None - num_chunks = 1 if not use_interleaved else NUM_CHUNKS - model = GPT2_exlarge_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half) - # model = GPT3_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half) - if use_interleaved and not isinstance(model, nn.ModuleList): - model = nn.ModuleList([model]) - - criterion = GPTLMLoss() - - optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2,) +# Launch ColossalAI +colossalai.launch_from_torch(config={}, seed=42) +coordinator = DistCoordinator() +``` +prepare the dataset. You can use `plugin.prepare_dataloader` to generate a dataloader or customize your own dataloader. +```python +def tokenize_batch(batch, tokenizer: Optional[AutoTokenizer] = None, max_length: int = 2048): + texts = [sample["sentence1"] + sample["sentence2"] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) + data = {k: v.cuda() for k, v in data.items()} + data["labels"] = data["input_ids"].clone() + return data + +tokenizer = AutoTokenizer.from_pretrained("gpt2") +dataset = datasets.load_dataset("glue", "mrpc") +train_dataloader = plugin.prepare_dataloader( + dataset["train"], + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=512), +) +``` +Prepare gpt-2 model +```python +cfg = AutoConfig.from_pretrained("gpt2", num_labels=2) +model = GPT2ForSequenceClassification.from_pretrained("gpt2", config=cfg).cuda() - engine, train_dataloader, _, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader=train_dataloader) - global_batch_size = BATCH_SIZE * \ - gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) - logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0]) +``` +prepare optimizer +```python +lr = LEARNING_RATE * coordinator.world_size +no_decay = ["bias", "LayerNorm.weight"] +optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, +] +optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) +``` +Prepare the lr_scheduler and criterion, and it's important to note that when hybrid parallelism with pipeline parallelism is used, a criterion function should also be defined. This function should take the input and output of the model's forward pass as parameters and return the loss. +```python +# lr scheduler +total_steps = len(train_dataloader) * NUM_EPOCHS +num_warmup_steps = int(WARMUP_FRACTION * total_steps) +lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, +) + +def _criterion(outputs, inputs): + return outputs.loss +``` +## Boost the GPT-2 Model +Define a booster with `HybridParallelPlugin`. Based on the configured plugin parameters, the booster will inject one or more parallel strategies into the model. In this example, pipeline parallelism, zero1, and mixed-precision training optimizations are utilized. +```python +booster = Booster(plugin=plugin) +``` +Boost these components with the defined booster. +```python +model, optimizer, _criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler +) +``` - timer = MultiTimer() - trainer = Trainer( - engine=engine, - logger=logger, - timer=timer - ) +## Training GPT-2 using hybrid parallelism - hook_list = [ - hooks.LossHook(), - hooks.LogMetricByEpochHook(logger), - hooks.ThroughputHook(), - hooks.LogMetricByStepHook(), - ] +In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. +Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training. +```python +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + _criterion: Callable, + lr_scheduler: LRScheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) + total_step = len(train_dataloader) + + model.train() + optimizer.zero_grad() + train_dataloader_iter = iter(train_dataloader) + with tqdm( + range(total_step), + desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", + disable=not print_flag, + ) as pbar: + # Forward pass + for _ in pbar: + if use_pipeline: + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + data = next(train_dataloader_iter) + data = move_to_cuda(data) + outputs = model(**data) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() - trainer.fit( - train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_interval=1, - hooks=hook_list, - display_progress=True, - return_output_label=False, - ) ``` - +Training the gpt-2 model +```python +for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) +``` + \ No newline at end of file diff --git a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md deleted file mode 100644 index 6dbe338008fa..000000000000 --- a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md +++ /dev/null @@ -1,248 +0,0 @@ -# Train ViT Using Pipeline Parallelism - -Author: Hongxin Liu, Yongbin Li - -**Example Code** -- [ColossalAI-Examples Pipeline Parallel ViT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer/pipeline_parallel) - -**Related Paper** -- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) - -## Introduction - -In this tutorial, you will learn how to train Vision Transformer for image classification from scratch, using pipeline. -Pipeline parallelism is a kind of model parallelism, which is useful when your GPU memory cannot fit your model. -By using it, we split the original model into multi stages, and each stage maintains a part of the original model. -We assume that your GPU memory cannot fit ViT/L-16, and your memory can fit this model. - -## Table of contents - -In this tutorial we will cover: - -1. The definition of ViT model, based on [TIMM](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) -2. Processing the dataset -3. Training ViT using pipeline - -## Import libraries - -```python -import os -from collections import OrderedDict -from functools import partial - -import colossalai -import colossalai.nn as col_nn -import torch -import torch.nn as nn -from colossalai.legacy.builder import build_pipeline_model -from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, - PipelineSchedule) -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.legacy.trainer import Trainer, hooks -from colossalai.utils import MultiTimer, get_dataloader -from timm.models import vision_transformer as vit -from torchvision import transforms -from torchvision.datasets import CIFAR10 -``` - - - -## Define Vision Transformer model - -Generally, we provide 3 ways to build a pipelined model: - -1. `colossalai.legacy.builder.build_pipeline_model_from_cfg` -2. `colossalai.legacy.builder.build_pipeline_model` -3. Split the model by stages by yourself - -When your memory can fit the model, you can use the first two methods to build your model, otherwise you must split the model by yourself. The first two methods first build the whole model on CPU, then split the model, and finally you can just move the corresponding part of model to GPU. - -`colossalai.legacy.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size). - -If you are familiar with `PyTorch`, you can use `colossalai.legacy.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and split it by layer uniformly. - -In this tutorial, we will modify [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.legacy.builder.build_pipeline_model()` to build the pipelined model. - -When the data is **one** `Tensor`, you can use the positional argument in `forward()` of your model to get the data tensor. For the first stage of pipeline, the first positional argument of `forward()` is the data tensor loaded from data loader. For other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return of `forward()` must be a `Tensor`. - -When the data is a `dict` of `Tensor`, you can use named keyword arguments in `forward()` of your model to get the data `dict`. - -```python -class ViTEmbedding(nn.Module): - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, embed_layer=vit.PatchEmbed, drop_rate=0., distilled=False): - super().__init__() - self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 2 if distilled else 1 - self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) - self.init_weights() - - def forward(self, x): - x = self.patch_embed(x) - cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks - if self.dist_token is None: - x = torch.cat((cls_token, x), dim=1) - else: - x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) - x = self.pos_drop(x + self.pos_embed) - return x - - def init_weights(self): - vit.trunc_normal_(self.pos_embed, std=.02) - if self.dist_token is not None: - vit.trunc_normal_(self.dist_token, std=.02) - vit.trunc_normal_(self.cls_token, std=.02) - self.apply(vit._init_vit_weights) - - -class ViTHead(nn.Module): - def __init__(self, embed_dim=768, num_classes=1000, norm_layer=None, distilled=False, representation_size=None): - super().__init__() - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - self.norm = norm_layer(embed_dim) - self.num_classes = num_classes - self.distilled = distilled - self.num_features = embed_dim - # Representation layer - if representation_size and not distilled: - self.num_features = representation_size - self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(embed_dim, representation_size)), - ('act', nn.Tanh()) - ])) - else: - self.pre_logits = nn.Identity() - # Classifier head(s) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = None - if distilled: - self.head_dist = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() - self.init_weights() - - def forward(self, x): - x = self.norm(x) - if self.distilled: - x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1]) - if self.training and not torch.jit.is_scripting(): - # during inference, return the average of both classifier predictions - return x, x_dist - else: - return (x + x_dist) / 2 - else: - x = self.pre_logits(x[:, 0]) - x = self.head(x) - return x - - def init_weights(self): - self.apply(vit._init_vit_weights) - - -def sequential_vit(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=vit.PatchEmbed, norm_layer=None, - act_layer=None): - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU - embedding = ViTEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans, - embed_dim=embed_dim, embed_layer=embed_layer, drop_rate=drop_rate, distilled=distilled) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - blocks = [vit.Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, - attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) - for i in range(depth)] - for block in blocks: - block.apply(vit._init_vit_weights) - head = ViTHead(embed_dim=embed_dim, num_classes=num_classes, norm_layer=norm_layer, - distilled=distilled, representation_size=representation_size) - return nn.Sequential(embedding, *blocks, head) - - -def vit_large_patch16_224(**kwargs): - model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) - return sequential_vit(**model_kwargs) -``` - -## Process the dataset - -Generally, we train ViT on large dataset like Imagenet. For simplicity, we just use CIFAR-10 here, since this tutorial is just for pipeline training. - -```python -def build_cifar(batch_size): - transform_train = transforms.Compose([ - transforms.RandomCrop(224, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - transform_test = transforms.Compose([ - transforms.Resize(224), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - - train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) - test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) - test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) - return train_dataloader, test_dataloader -``` - -## Training ViT using pipeline - -You can set the size of pipeline parallel and number of microbatches in config. `NUM_CHUNKS` is useful when using interleaved-pipeline (for more details see [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) ). The original batch will be split into `num_microbatches`, and each stage will load a micro batch each time. Then we will generate an appropriate schedule for you to execute the pipeline training. If you don't need the output and label of model, you can set `return_output_label` to `False` when calling `trainer.fit()` which can further reduce GPU memory usage. - -You should `export DATA=/path/to/cifar`. - -```python -BATCH_SIZE = 16 -NUM_EPOCHS = 60 -NUM_CHUNKS = 1 -CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) - - -def train(): - disable_existing_loggers() - parser = colossalai.get_default_parser() - args = parser.parse_args() - colossalai.launch_from_torch(backend=args.backend, config=CONFIG) - logger = get_dist_logger() - - # build model - model = vit_large_patch16_224() - model = build_pipeline_model(model, num_chunks=NUM_CHUNKS, verbose=True) - - # build criterion - criterion = nn.CrossEntropyLoss() - - # optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0) - - # build dataloader - train_dataloader, test_dataloader = build_cifar(BATCH_SIZE) - - engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, optimizer, criterion, - train_dataloader, test_dataloader) - timer = MultiTimer() - - trainer = Trainer(engine=engine, timer=timer, logger=logger) - - hook_list = [ - hooks.LossHook(), - hooks.AccuracyHook(col_nn.metric.Accuracy()), - hooks.LogMetricByEpochHook(logger), - ] - - trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True) -``` - diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md index 0ec9d5c3c5de..93fed61c34da 100644 --- a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -1,10 +1,14 @@ # Step By Step: Accelerate ViT Training With Colossal-AI (From Data Parallel to Hybrid Parallel) -Author: Yuxuan Lou +Author: Yuxuan Lou, Mingyan Jiang + +**Prerequisite:** +- [parallelism plugin](../basics/booster_plugins.md) +- [booster API](../basics/booster_api.md) **Example Code** -- [Colossal-AI Examples ViT on Cifar10](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer) +- [Colossal-AI Examples ViT on `beans`](https://github.com/hpcaitech/ColossalAI/blob/main/examples/images/vit/vit_train_demo.py) **Related Paper** - [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf) @@ -13,14 +17,14 @@ Author: Yuxuan Lou ## Introduction In this example for ViT model, Colossal-AI provides three different parallelism techniques which accelerate model training: data parallelism, pipeline parallelism and tensor parallelism. -We will show you how to train ViT on CIFAR-10 dataset with these parallelism techniques. To run this example, you will need 2-4 GPUs. +We will show you how to train ViT on `beans` dataset with these parallelism techniques. To run this example, you will need 2-4 GPUs. ## Table of Contents 1. Colossal-AI installation -2. Steps to train ViT with data parallelism -3. Steps to train ViT with pipeline parallelism -4. Steps to train ViT with tensor parallelism or hybrid parallelism +2. Define the ViT model and related training components. +3. Boost the VIT Model with [`HybridParallelPlugin`](../basics/booster_plugins.md) +4. Train the VIT model using data parallelism, pipeline parallelism, and tensor parallelism. ## Colossal-AI Installation You can install Colossal-AI package and its dependencies with PyPI. @@ -29,619 +33,250 @@ pip install colossalai ``` - -## Data Parallelism -Data parallelism is one basic way to accelerate model training process. You can apply data parallelism to training by only two steps: -1. Define a configuration file -2. Change a few lines of code in train script - -### Define your configuration file (`data_parallel/config.py`) -To use Colossal-AI, the first step is to define a configuration file. And there are two kinds of variables here: - -1. **Colossal-AI feature specification** - -There is an array of features Colossal-AI provides to speed up training (parallel mode, mixed precision, ZeRO, etc.). Each feature is defined by a corresponding field in the config file. If we apply data parallel only, we do not need to specify the parallel mode. In this example, we use mixed precision training natively provided by PyTorch by define the mixed precision configuration `fp16 = dict(mode=AMP_TYPE.TORCH)`. - -2. **Global hyper-parameters** - -Global hyper-parameters include model-specific hyper-parameters, training settings, dataset information, etc. - +## Import libraries ```python -from colossalai.amp import AMP_TYPE - -# ViT Base -BATCH_SIZE = 256 -DROP_RATE = 0.1 -NUM_EPOCHS = 300 - -# mix precision -fp16 = dict( - mode=AMP_TYPE.TORCH, -) - -gradient_accumulation = 16 -clip_grad_norm = 1.0 - -dali = dict( - gpu_aug=True, - mixup_alpha=0.2 -) -``` +from typing import Any, Callable, Iterator -### Modify train script (`/data_parallel/train_with_cifar10.py`) +import torch +import torch.distributed as dist +import torch.nn as nn +import transformers +from data import BeansDataset, beans_collator +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor -#### Import modules -- Colossal-AI related modules -```python import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.lr_scheduler import LinearWarmupLR -from colossalai.legacy.nn.metric import Accuracy -from colossalai.legacy.trainer import Trainer, hooks -``` - -- Other modules -```python -import os - -import torch -from timm.models import vit_base_patch16_224 - - -from torchvision import transforms -from torchvision.datasets import CIFAR10 +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam ``` - -#### Launch Colossal-AI - -In train script, you need to initialize the distributed environment for Colossal-AI after your config file is prepared. We call this process `launch`. In Colossal-AI, we provided several launch methods to initialize the distributed backend. In most cases, you can use `colossalai.launch` and `colossalai.get_default_parser` to pass the parameters via command line. Besides, Colossal-AI can utilize the existing launch tool provided by PyTorch as many users are familiar with by using `colossalai.launch_from_torch`. For more details, you can view the related [documents](https://www.colossalai.org/docs/basics/launch_colossalai). - +## Define the Vision Transformer (VIT) model. +Define hyperparameters. ```python -# initialize distributed setting -parser = colossalai.get_default_parser() -args = parser.parse_args() -colossalai.launch_from_torch(config=args.config) - -disable_existing_loggers() -logger = get_dist_logger() +SEED = 42 +MODEL_PATH = "google/vit-base-patch16-224" +LEARNING_RATE = 5e-5 +WEIGHT_DECAY = 0.0 +NUM_EPOCH = 3 +WARMUP_RATIO = 0.3 +TP_SIZE = 2 +PP_SIZE = 2 ``` - -After initialization, you can access the variables in the config file by using `colossalai.core.global_context`. - +Create a distributed environment. ```python -#access parameters -print(gpc.config.BATCH_SIZE) +# Launch ColossalAI +colossalai.launch_from_torch(config={}, seed=SEEDå) +coordinator = DistCoordinator() +world_size = coordinator.world_size ``` +Before training, you can define the relevant components of the model training process as usual, such as defining the model, data loaders, optimizer, and so on. It's important to note that when using pipeline parallelism, you also need to define a criterion function. This function takes the input and output of the model forward pass as inputs and returns the loss. +Prepare the dataset. BeansDataset is defined in [data.py](https://github.com/hpcaitech/ColossalAI/blob/main/examples/images/vit/data.py). -#### Build Model - -If only data parallelism is required, you do not need to make any changes to your model. Here, we use `vit_base_patch16_224` from `timm`. ```python -# build model -model = vit_base_patch16_224(drop_rate=0.1, num_classes=gpc.config.NUM_CLASSES) +image_processor = ViTImageProcessor.from_pretrained(MODEL_PATH) +train_dataset = BeansDataset(image_processor, TP_SIZE, split="train") +eval_dataset = BeansDataset(image_processor, RP_SIZE, split="validation") +num_labels = train_dataset.num_labels ``` - -#### Build CIFAR-10 Dataloader -`colossalai.utils.get_dataloader` can help you build dataloader easily. - +Define the VIT model: ```python -def build_cifar(batch_size): - transform_train = transforms.Compose([ - transforms.RandomCrop(224, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - transform_test = transforms.Compose([ - transforms.Resize(224), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - - train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) - test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) - test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) - return train_dataloader, test_dataloader - - -# build dataloader -train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE) +config = ViTConfig.from_pretrained(MODEL_PATH) +config.num_labels = num_labels +config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} +config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} +model = ViTForImageClassification.from_pretrained( + MODEL_PATH, config=config, ignore_mismatched_sizes=True +) ``` - -#### Define optimizer, loss function and LR scheduler - -Colossal-AI provides its own optimizer, loss function and LR scheduler. Those from PyTorch are also compatible. - +Define the optimizer: ```python -# build optimizer -optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1) - -# build loss -criterion = torch.nn.CrossEntropyLoss() - -# lr_scheduler -lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) +optimizer = HybridAdam(model.parameters(), lr=(LEARNING_RATE * world_size), weight_decay=WEIGHT_DECAY) ``` - -#### Start Colossal-AI engine - -Engine is essentially a wrapper class for model, optimizer and loss function. When we call `colossalai.initialize`, an engine object will be returned, and it has already been equipped with functionalities such as gradient clipping, gradient accumulation and zero optimizer as specified in your configuration file. Further model training is based on Colossal-AI engine. - +Define the learning rate scheduler: ```python -engine, train_dataloader, test_dataloader, _ = colossalai.initialize( - model, optimizer, criterion, train_dataloader, test_dataloader +total_steps = len(train_dataloader) * NUM_EPOCH +num_warmup_steps = int(WARMUP_RATIO * total_steps) +lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=(len(train_dataloader) * NUM_EPOCH), warmup_steps=num_warmup_steps ) ``` - -#### Train: Trainer API -Trainer is a more high-level wrapper for the user to execute training with fewer lines of code. It is easy to create a trainer object by passing the engine object. - -Besides, In trainer, the user can customize some hooks and attach these hooks to the trainer object. A hook object will execute life-cycle methods periodically based on the training scheme. For example, The `LRSchedulerHook` will execute `lr_scheduler.step()` to update the learning rate of the model during either `after_train_iter` or `after_train_epoch` stages. - -```python -# build trainer -trainer = Trainer(engine=engine, logger=logger) - -# build hooks -hook_list = [ - hooks.LossHook(), - hooks.AccuracyHook(accuracy_func=MixupAccuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LRSchedulerHook(lr_scheduler, by_epoch=True), - - # comment if you do not need to use the hooks below - hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'), - hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]), -] -``` - -Use `trainer.fit` for training: - -```python -# start training -trainer.fit( - train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - epochs=gpc.config.NUM_EPOCHS, - hooks=hook_list, - display_progress=True, - test_interval=1 -) -``` - -### Start training -`DATA` is the filepath where CIFAR-10 dataset will be automatically downloaded and stored. - -`` is the number of GPUs you want to use to train ViT on CIFAR-10 with data parallelism. - -```bash -export DATA= -# If your torch >= 1.10.0 -torchrun --standalone --nproc_per_node train_dp.py --config ./configs/config_data_parallel.py -# If your torch >= 1.9.0 -# python -m torch.distributed.run --standalone --nproc_per_node= train_dp.py --config ./configs/config_data_parallel.py -# Otherwise -# python -m torch.distributed.launch --nproc_per_node --master_addr --master_port 29500 train_dp.py --config ./configs/config.py -``` - - - -## Pipeline Parallelism -Aside from data parallelism, Colossal-AI also support pipeline parallelism. In specific, Colossal-AI uses 1F1B pipeline introduced by NVIDIA. For more details, you can view the related [documents](https://www.colossalai.org/tutorials/features/pipeline_parallel). - -### Define your configuration file(`hybrid_parallel/configs/vit_pipeline.py`) -To apply pipeline parallel on the data parallel basis, you only need to add a **parallel dict** +Define the criterion function: ```python -from colossalai.amp import AMP_TYPE - -parallel = dict( - pipeline=2 -) -# pipeline config -NUM_MICRO_BATCHES = parallel['pipeline'] -TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE) - -fp16 = dict(mode=AMP_TYPE.NAIVE) -clip_grad_norm = 1.0 +def _criterion(outputs, inputs): + return outputs.loss ``` - -Other configs: +## Boost the VIT Model +We begin using ColossalAI's hybrid parallelism strategy to enhance the model. First, let's define an object of `HybridParallelPlugin`. `HybridParallelPlugin` encapsulates various parallelism strategies in ColossalAI. Afterward, we use the `HybridParallelPlugin` object to initialize the booster and boost the VIT model. +### Training with AMP +In the HybridParallelPlugin plugin, you can determine the training precision by setting the precision parameter, which supports three types: 'fp16', 'bf16', and 'fp32'. 'fp16' and 'bf16' are half-precision types. Half-precision is used in two scenarios in the HybridParallelPlugin: +1. When using zero-data parallelism, you should set it to half-precision. +2. When specifying the use of AMP (Automatic Mixed Precision) for training. +You can set related parameters when using half-precision. +`initial_scale` (float, optional): Initial loss scaling factor for AMP. Default value is 2**16. +`min_scale` (float, optional): Minimum loss scaling factor for AMP. Default value is 1. +`growth_factor` (float, optional): Multiplicative factor used to increase the loss scaling factor when using AMP. Default value is 2. +`backoff_factor` (float, optional): Multiplicative factor used to decrease the loss scaling factor when using AMP. Default value is 0.5. +`growth_interval` (integer, optional): Number of steps to increase the loss scaling factor when using AMP, in cases where there is no overflow. Default value is 1000. +`hysteresis` (integer, optional): Number of overflows required before reducing the loss scaling factor when using AMP. Default value is 2. +`max_scale` (float, optional): Maximum loss scaling factor for AMP. Default value is 2**32. +Plugin example when using amp: ```python -# hyper parameters -# BATCH_SIZE is as per GPU -# global batch size = BATCH_SIZE x data parallel size -BATCH_SIZE = 256 -LEARNING_RATE = 3e-3 -WEIGHT_DECAY = 0.3 -NUM_EPOCHS = 300 -WARMUP_EPOCHS = 32 - -# model config -IMG_SIZE = 224 -PATCH_SIZE = 16 -HIDDEN_SIZE = 768 -DEPTH = 12 -NUM_HEADS = 12 -MLP_RATIO = 4 -NUM_CLASSES = 10 -CHECKPOINT = True -SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token +plugin = HybridParallelPlugin( + precision="fp16", + initial_scale=1, + ) ``` - -### Build pipeline model (`/hybrid_parallel/model/vit.py`) -Colossal-AI provides two methods to build a pipeline model from the existing model. -- `colossalai.legacy.builder.build_pipeline_model_from_cfg` -- `colossalai.legacy.builder.build_pipeline_model` - -Besides, you can also build a pipeline model from scratch with Colossal-AI. +### Tensor parallelism +`HybridParallelPlugin` achieves tensor parallelism through Shardformer. In this plugin, you can set the `tp_size` to determine the size of tensor parallel groups. Additionally, there are multiple parameters that can be configured to optimize tensor parallelism features when using this plugin: +`enable_all_optimization` (boolean, optional): Whether to enable all optimization methods supported by Shardformer. Currently, all optimization methods include fused normalization, flash attention, and JIT. Default is False. +`enable_fused_normalization` (boolean, optional): Whether to enable fused normalization in Shardformer. Default is False. +`enable_flash_attention` (boolean, optional): Whether to enable flash attention in Shardformer. Default is False. +`enable_jit_fused` (boolean, optional): Whether to enable JIT (Just-In-Time) fusion in Shardformer. Default is False. +`enable_sequence_parallelism` (boolean): Whether to enable sequence parallelism in Shardformer. Default is False. +`enable_sequence_overlap` (boolean): Whether to enable sequence overlap in Shardformer. Default is False. +Example of a tensor parallelism plugin: ```python -import math -from typing import Callable - -import inspect -import torch -from colossalai import nn as col_nn -from colossalai.legacy.registry import LAYERS, MODELS -from colossalai.logging import get_dist_logger -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.legacy.builder.pipeline import partition_uniform -from torch import dtype, nn -from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead - - -@MODELS.register_module -class PipelineVisionTransformer(nn.Module): - def __init__(self, - img_size: int = 224, - patch_size: int = 16, - in_chans: int = 3, - num_classes: int = 1000, - depth: int = 12, - num_heads: int = 12, - dim: int = 768, - mlp_ratio: int = 4, - attention_dropout: float = 0., - dropout: float = 0.1, - drop_path: float = 0., - layernorm_epsilon: float = 1e-6, - activation: Callable = nn.functional.gelu, - representation_size: int = None, - dtype: dtype = None, - bias: bool = True, - checkpoint: bool = False, - init_method: str = 'torch', - first_stage=True, - last_stage=True, - start_idx=None, - end_idx=None,): - super().__init__() - - layers = [] - - if first_stage: - embed = ViTEmbedding(img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embedding_dim=dim, - dropout=dropout, - dtype=dtype, - init_method=init_method) - layers.append(embed) - - # stochastic depth decay rule - dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] - - if start_idx is None and end_idx is None: - start_idx = 0 - end_idx = depth - - blocks = [ - ViTBlock( - dim=dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - attention_dropout=attention_dropout, - dropout=dropout, - drop_path=dpr[i], - activation=activation, - dtype=dtype, - bias=bias, - checkpoint=checkpoint, - init_method=init_method, - ) for i in range(start_idx, end_idx) - ] - layers.extend(blocks) - - if last_stage: - norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - head = ViTHead(dim=dim, - num_classes=num_classes, - representation_size=representation_size, - dtype=dtype, - bias=bias, - init_method=init_method) - layers.extend([norm, head]) - - self.layers = nn.Sequential( - *layers +plugin = HybridParallelPlugin( + tp_size=4, + enable_all_optimization=True ) - - def forward(self, x): - x = self.layers(x) - return x - - -def _filter_kwargs(func, kwargs): - sig = inspect.signature(func) - return {k: v for k, v in kwargs.items() if k in sig.parameters} - - -def _build_pipeline_vit(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs): - logger = get_dist_logger() - if gpc.is_initialized(ParallelMode.PIPELINE): - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - else: - pipeline_size = 1 - pipeline_rank = 0 - rank = gpc.get_global_rank() - parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] - models = [] - - for start, end in parts: - kwargs['first_stage'] = start == 0 - kwargs['last_stage'] = end == num_layers - kwargs['start_idx'] = start - kwargs['end_idx'] = end - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') - chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device) - models.append(chunk) - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - return model - - -def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): - return _build_pipeline_vit(PipelineVisionTransformer, num_layers, num_chunks, device, **kwargs) ``` +### Pipeline Parallelism -### Modify train script (`/hybrid_parallel/train_with_cifar10.py`) - -#### Import modules +`HybridParallelPlugin` determines the size of pipeline parallelism groups by setting `pp_size`. `num_microbatches` is used to specify the number of microbatches into which the entire batch is divided during pipeline parallelism, and `microbatch_size` can be set to define the size of these microbatches. The plugin will prioritize using `num_microbatches` to determine the microbatch configuration. +Example of a plugin for pipeline parallelism: ```python -from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, - PipelineSchedule) -from colossalai.utils import MultiTimer -import os - -import colossalai - -import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import is_using_pp, get_dataloader -from model.vit import build_pipeline_vit -from model_zoo.vit.vit import _create_vit_model -from tqdm import tqdm - -from torchvision import transforms -from torchvision.datasets import CIFAR10 +plugin = HybridParallelPlugin( + pp_size=4, + num_microbatches=None, + microbatch_size=1 + ) ``` - -#### Launch Colossal-AI -`colossalai.utils.is_using_pp` can help check whether pipeline parallelism is required in config file. - +### Data Parallelism +The `HybridParallelPlugin`'s data parallelism includes both the zero-dp series and Torch DDP. When `zero_stage` is set to 0 (the default), it means using Torch DDP. Please note that Torch DDP conflicts with pipeline parallelism and cannot be used together. When `zero_stage` is set to 1, it indicates the use of the zero1 strategy. When `zero_stage` is set to 2, it implies the use of the zero2 strategy. The zero2 strategy also cannot be used together with pipeline parallelism. If you want to use zero3, please use the [`GeminiPlugin`](../basics/booster_plugins.md). +When using data parallelism with the zero series, please set the training precision to half-precision. If you haven't specified the use of zero or pipeline parallelism, and if `world_size//(tp_size*pp_size)` is greater than 1, the HybridParallelPlugin will automatically enable the Torch DDP parallel strategy for you. +Here are some related parameters for configuring Torch DDP: +`broadcast_buffers` (boolean, optional): Whether to broadcast buffers at the beginning of training when using DDP. Default is True. +`ddp_bucket_cap_mb` (integer, optional): Size of the bucket (in MB) when using DDP. Default is 25. +`find_unused_parameters` (boolean, optional): Whether to search for unused parameters when using DDP. Default is False. +`check_reduction` (boolean, optional): Whether to check the reduction operation when using DDP. Default is False. +`gradient_as_bucket_view` (boolean, optional): Whether to use gradients as bucket views when using DDP. Default is False. +`static_graph` (boolean, optional): Whether to use a static graph when using DDP. Default is False. +Example of a plugin for Torch DDP. ```python -# initialize distributed setting -parser = colossalai.get_default_parser() -args = parser.parse_args() - -# launch from torch -colossalai.launch_from_torch(config=args.config) - -# get logger -logger = get_dist_logger() -logger.info("initialized distributed environment", ranks=[0]) - -if hasattr(gpc.config, 'LOG_PATH'): - if gpc.get_global_rank() == 0: - log_path = gpc.config.LOG_PATH - if not os.path.exists(log_path): - os.mkdir(log_path) - logger.log_to_file(log_path) - -use_pipeline = is_using_pp() +plugin = HybridParallelPlugin( + tp_size=2, + pp_size=1, + zero_stage=0, + precision="fp16", + initial_scale=1, + ) ``` - -#### Define model - +If there are 4 parallel processes, the parallel group size for Torch DDP is 2. +ZeRO-related parameters: +`zero_bucket_size_in_m` (integer, optional): The bucket size for gradient reduction in megabytes when using ZeRO. Default is 12. +`cpu_offload` (boolean, optional): Whether to enable cpu_offload when using ZeRO. Default is False. +`communication_dtype` (torch data type, optional): The data type for communication when using ZeRO. If not specified, the data type of the parameters will be used. Default is None. +`overlap_communication` (boolean, optional): Whether to overlap communication and computation when using ZeRO. Default is True. +Example of a plugin for ZERO1. ```python -# create model -model_kwargs = dict(img_size=gpc.config.IMG_SIZE, - patch_size=gpc.config.PATCH_SIZE, - dim=gpc.config.HIDDEN_SIZE, - depth=gpc.config.DEPTH, - num_heads=gpc.config.NUM_HEADS, - mlp_ratio=gpc.config.MLP_RATIO, - num_classes=gpc.config.NUM_CLASSES, - init_method='jax', - checkpoint=gpc.config.CHECKPOINT) - -if use_pipeline: - model = build_pipeline_vit(num_layers=model_kwargs['depth'], num_chunks=1, **model_kwargs) -else: - model = _create_vit_model(**model_kwargs) -``` - -#### Count number of parameters - -You can count model parameters on different pipeline stages easily. - -``` -# count number of parameters -total_numel = 0 -for p in model.parameters(): - total_numel += p.numel() -if not gpc.is_initialized(ParallelMode.PIPELINE): - pipeline_stage = 0 -else: - pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) -logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") +plugin = HybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=1, + cpu_offload=True, + precision="fp16", + initial_scale=1, + ) ``` -#### Build dataloader, optimizer, etc. - +### Hybrid Parallelism +You can refer to the above-mentioned strategies to customize an appropriate hybrid parallelism strategy. And use this plugin to define a booster. ```python -def build_cifar(batch_size): - transform_train = transforms.Compose([ - transforms.RandomCrop(224, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - transform_test = transforms.Compose([ - transforms.Resize(224), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - - train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) - test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) - test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) - return train_dataloader, test_dataloader - - -# create dataloaders -train_dataloader , test_dataloader = build_cifar() - -# create loss function -criterion = CrossEntropyLoss(label_smoothing=0.1) - -# create optimizer -optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) - -# create lr scheduler -lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=gpc.config.NUM_EPOCHS, - warmup_steps=gpc.config.WARMUP_EPOCHS) +plugin = HybridParallelPlugin( + tp_size=TP_SIZE, + pp_size=PP_SIZE, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision="fp16", + initial_scale=1, + ) +booster = Booster(plugin=plugin) ``` - -#### Start Colossal-AI engine - +Next, we use `booster.boost` to inject the features encapsulated by the plugin into the model training components. ```python -# initialize -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader) - -logger.info("Engine is built", ranks=[0]) +model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost( + model=model, optimizer=optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) ``` - -#### Train: based on engine - -In the data parallelism example, we show how to train a model with Trainer API. We can also directly train a model based on engine. In this way, you can customize your training with more features. - +## Train ViT using hybrid parallelism. +Finally, we can use the hybrid parallelism strategy to train the model. Let's first define a training function that describes the training process. It's important to note that if the pipeline parallelism strategy is used, you should call `booster.execute_pipeline` to perform the model training. This function will invoke the `scheduler` to manage the model's forward and backward operations. ```python -data_iter = iter(train_dataloader) - -for epoch in range(gpc.config.NUM_EPOCHS): - # training - engine.train() - - if gpc.get_global_rank() == 0: - description = 'Epoch {} / {}'.format( - epoch, - gpc.config.NUM_EPOCHS +def run_forward_backward( + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + data_iter: Iterator, + booster: Booster, +): + if optimizer is not None: + optimizer.zero_grad() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # run pipeline forward backward when enabling pp in hybrid parallel plugin + output_dict = booster.execute_pipeline( + data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True ) - progress = tqdm(range(len(train_dataloader)), desc=description) + loss, outputs = output_dict["loss"], output_dict["outputs"] else: - progress = range(len(train_dataloader)) - for _ in progress: - engine.zero_grad() - engine.execute_schedule(data_iter, return_output_label=False) - engine.step() - lr_scheduler.step() -``` - -### Start training -```bash -export DATA= -# If your torch >= 1.10.0 -torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_pipeline_parallel.py -# If your torch >= 1.9.0 -# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_pipeline_parallel.py -``` - - - - -## Tensor Parallelism and Hybrid Parallelism -Tensor parallelism partitions each weight parameter across multiple devices in order to reduce memory load. Colossal-AI support 1D, 2D, 2.5D and 3D tensor parallelism. Besides, you can combine tensor parallelism with pipeline parallelism and data parallelism to reach hybrid parallelism. Colossal-AI also provides an easy way to apply tensor parallelism and hybrid parallelism. On the basis of pipeline parallelism, a few lines of code changing in config file is all you need. - -### Define your configuration file(`/hybrid_parallel/configs/vit_1d_tp2_pp2.py`) -To use tensor parallelism, you only need to add related information to the **parallel dict**. To be specific, `TENSOR_PARALLEL_MODE` can be '1d', '2d', '2.5d', '3d'. And the size of different parallelism should satisfy: `#GPUs = pipeline parallel size x tensor parallel size x data parallel size`. `data parallel size` will automatically computed after you specify the number of GPUs, pipeline parallel size and tensor parallel size. - -```python -from colossalai.amp import AMP_TYPE -# parallel setting -TENSOR_PARALLEL_SIZE = 2 -TENSOR_PARALLEL_MODE = '1d' - -parallel = dict( - pipeline=2, - tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE) -) - -fp16 = dict(mode=AMP_TYPE.NAIVE) -clip_grad_norm = 1.0 - - -# pipeline config -NUM_MICRO_BATCHES = parallel['pipeline'] -TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE) + batch = next(data_iter) + batch = move_to_cuda(batch, torch.cuda.current_device()) + outputs = model(**batch) + loss = criterion(outputs, None) + if optimizer is not None: + booster.backward(loss, optimizer) + +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + lr_scheduler: LRScheduler, + dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + torch.cuda.synchronize() + + num_steps = len(dataloader) + data_iter = iter(dataloader) + enable_pbar = coordinator.is_master() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar + tp_rank = dist.get_rank(booster.plugin.tp_group) + dp_rank = dist.get_rank(booster.plugin.dp_group) + enable_pbar = tp_rank == 0 and dp_rank == 0 and booster.plugin.stage_manager.is_last_stage() + model.train() + + with tqdm(range(num_steps), desc=f"Epoch [{epoch + 1}]", disable=not enable_pbar) as pbar: + for _ in pbar: + loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) + optimizer.step() + lr_scheduler.step() + + # Print batch loss + if enable_pbar: + pbar.set_postfix({"loss": loss.item()}) ``` - -Other configs: +Start training the model. ```python -# hyper parameters -# BATCH_SIZE is as per GPU -# global batch size = BATCH_SIZE x data parallel size -BATCH_SIZE = 256 -LEARNING_RATE = 3e-3 -WEIGHT_DECAY = 0.3 -NUM_EPOCHS = 300 -WARMUP_EPOCHS = 32 - -# model config -IMG_SIZE = 224 -PATCH_SIZE = 16 -HIDDEN_SIZE = 768 -DEPTH = 12 -NUM_HEADS = 12 -MLP_RATIO = 4 -NUM_CLASSES = 10 -CHECKPOINT = True -SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token -``` - -### Start training -```bash -export DATA= -# If your torch >= 1.10.0 -torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_hybrid_parallel.py -# If your torch >= 1.9.0 -# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py +for epoch in range(NUM_EPOCH): + train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index feb37fc15de2..fa360a4b9213 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -15,7 +15,7 @@ We currently provide the following plugins: - [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp. - [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2. - [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management. -- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below. +- [Hybrid Parallel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below. More plugins are coming soon. diff --git a/docs/source/en/features/gradient_accumulation_with_booster.md b/docs/source/en/features/gradient_accumulation_with_booster.md index 347cd6e519bb..ea97dd92e885 100644 --- a/docs/source/en/features/gradient_accumulation_with_booster.md +++ b/docs/source/en/features/gradient_accumulation_with_booster.md @@ -1,6 +1,6 @@ # Gradient Accumulation -Author: [Mingyan Jiang](https://github.com/jiangmingyan) +Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003) **Prerequisite** - [Training Booster](../basics/booster_api.md) @@ -126,6 +126,7 @@ for idx, (img, label) in enumerate(train_dataloader): ``` + ### Step 6. Invoke Training Scripts To verify gradient accumulation, we can just check the change of parameter values. When gradient accumulation is set, parameters are only updated in the last step. You can run the script using this command: ```shell @@ -142,4 +143,30 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0 iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) ``` + +## Gradient Accumulation on GeminiPlugin + +Currently the plugins supporting `no_sync()` method include `TorchDDPPlugin` and `LowLevelZeroPlugin` set to stage 1. `GeminiPlugin` doesn't support `no_sync()` method, but it can enable synchronized gradient accumulation in a torch-like way. + +To enable gradient accumulation feature, the argument `enable_gradient_accumulation` should be set to `True` when initializing `GeminiPlugin`. Following is the pseudocode snippet of enabling gradient accumulation for `GeminiPlugin`: + +```python +... +plugin = GeminiPlugin(..., enable_gradient_accumulation=True) +booster = Booster(plugin=plugin) +... + +... +for idx, (input, label) in enumerate(train_dataloader): + output = gemini_model(input.cuda()) + train_loss = criterion(output, label.cuda()) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, gemini_optimizer) + + if idx % (GRADIENT_ACCUMULATION - 1) == 0: + gemini_optimizer.step() # zero_grad is automatically done +... +``` + + diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index a1d58e9fddc2..11740698057f 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -1,10 +1,13 @@ -# 使用混合并行训练 GPT +# 使用混合并行训练 GPT-2 -作者: Hongxin Liu, Yongbin Li +作者: Hongxin Liu, Yongbin Li, Mingyan Jiang + +**前置教程** +- [并行插件](../basics/booster_plugins.md) +- [booster API](../basics/booster_api.md) **示例代码** -- [ColossalAI-Examples GPT2](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_2) -- [ColossalAI-Examples GPT3](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_3) +- [ColossalAI-Examples GPT2](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/gpt/hybridparallelism/finetune.py) **相关论文** - [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883) @@ -12,265 +15,190 @@ ## 引言 -在上一篇教程中,我们介绍了如何用流水并行训练 ViT。在本教程中,你将学习一个更复杂的场景--用混合并行方式训练GPT。在这种情况下,由于GPT-3过大,即使CPU内存也无法容纳它。因此,你必须自己分割模型。 +在上一篇教程中,我们介绍了如何用流水并行训练 ViT。在本教程中,你将学习一个更复杂的场景--用混合并行方式训练GPT-2。在这种情况下,由于GPT-2过大,即使CPU内存也无法容纳它。因此,该模型必须被分割。 ## 目录 在本教程中,我们将介绍: - -1. 基于 colossalai/model_zoo 定义 GPT 模型 -2. 处理数据集 -3. 使用混合并行训练 GPT +1. 初始化混合并行插件 +2. 定义 GPT-2 模型的训练组件 +3. 使用 [HybridParallelPlugin](../basics/booster_plugins.md) 增强GPT-2模型 +4. 使用混合并行训练 GPT-2 ## 导入依赖库 ```python -import json -import os -from typing import Callable - -import colossalai -import colossalai.utils as utils -import model_zoo.gpt.gpt as col_gpt +from typing import Callable, List, Union import torch +import torch.distributed as dist import torch.nn as nn -from colossalai import nn as col_nn -from colossalai.amp import AMP_TYPE -from colossalai.legacy.builder.pipeline import partition_uniform -from colossalai.legacy.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, - PipelineSchedule) -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper -from colossalai.legacy.trainer import Trainer, hooks -from colossalai.utils.timer import MultiTimer -from model_zoo.gpt import GPTLMLoss -from torch.nn import functional as F -from torch.utils.data import Dataset -from transformers import GPT2Tokenizer -``` - - - -## 定义 GPT 模型 - -在前面的教程中,我们介绍了3种建立流水并行模型的方法,但对于像 GPT-3 这样的巨大模型,你甚至不能在 CPU 中建立模型。在这种情况下,你必须自己分割模型。 - -GPT 数据加载器返回 `input_ids` 和 `attention_mask`, 因此我们在 `forward()` 中使用两个关键字参数来获得它们。请注意,对于除第一阶段以外的其他阶段, `forward()` 的第一个位置参数是上一阶段的输出张量。所以 `hidden_states` 来自前一阶段,并且对于第一阶段来说,它是 `None`。 - -对于 GPT, *word embedding layer* 与 *output head* 共享权重。我们提供 `PipelineSharedModuleWrapper` 在流水阶段间共享参数。它需要一个 `int` 型的 `list` 作为参数, 这意味着 rank 们共享这些参数。你可以使用 `register_module()` -或 `register_parameter()` 来注册一个模块或一个参数作为共享模块或参数。如果你有多组共享模块/参数,你应该有多个 `PipelineSharedModuleWrapper` 实例。 如果参数在**一个**阶段内共享, 你不应该使用 -`PipelineSharedModuleWrapper`, 而只是使用同一个模块/参数实例。在这个例子中,*word embedding layer* 在第一阶段, 而 *output head* 在最后一个阶段。因此,他们在 rank `[0, pipeline_size - 1]` 之间共享参数。 - -对于第一阶段,它维护 embedding layer 和一些 transformer blocks。对于最后一个阶段,它维护一些 transformer blocks 和 output head layer。对于其他阶段,他们只维护一些 transformer blocks。 -`partition_uniform(num_layers, pipeline_size, num_chunks)` 返回所有 rank 的 parts, part 是一个 `(start, end)` (不包括end) 的 `tuple`。`start == 0` 表示这是第一阶段, 而 `end == num_layers` 表示这是最后一个阶段。 +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from tqdm import tqdm +from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup +from transformers import AutoTokenizer +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device +``` +### 定义plugin +定义一个[`HybridParallelPlugin`](../basics/booster_plugins.md)对象,指定所需要使用的并行策略,在该例子中,同时使用了流水线并行和zero1. ```python -class PipelineGPTHybrid(nn.Module): - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: Callable = F.gelu, - mlp_ratio: int = 4, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - first: bool = False, - last: bool = False): - super().__init__() - self.embedding = None - self.norm = None - self.head = None - if first: - self.embedding = col_gpt.GPTEmbedding( - hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype) - self.blocks = nn.ModuleList([ - col_gpt.GPTBlock(hidden_size, num_attention_heads, mlp_ratio=mlp_ratio, attention_dropout=attn_drop_rate, - dropout=drop_rate, dtype=dtype, checkpoint=checkpoint, activation=act_func) - for _ in range(num_layers) - ]) - if last: - self.norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.head = col_gpt.GPTLMHead(vocab_size=vocab_size, - dim=hidden_size, - dtype=dtype, - bias=False) - - def forward(self, hidden_states=None, input_ids=None, attention_mask=None): - if self.embedding is not None: - hidden_states = self.embedding(input_ids=input_ids) - batch_size = hidden_states.shape[0] - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * -10000.0 - for block in self.blocks: - hidden_states, attention_mask = block(hidden_states, attention_mask) - if self.norm is not None: - hidden_states = self.head(self.norm(hidden_states)) - return hidden_states - - -def build_gpt_pipeline(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): - logger = get_dist_logger() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - rank = gpc.get_global_rank() - wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1]) - parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] - models = [] - for start, end in parts: - kwargs['num_layers'] = end - start - kwargs['first'] = start == 0 - kwargs['last'] = end == num_layers - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') - chunk = PipelineGPTHybrid(**kwargs).to(device) - if start == 0: - wrapper.register_module(chunk.embedding.word_embeddings) - elif end == num_layers: - wrapper.register_module(chunk.head) - models.append(chunk) - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - return model - - -def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float): - cfg = dict(hidden_size=1600, num_attention_heads=32, checkpoint=checkpoint, dtype=dtype) - return build_gpt_pipeline(48, num_chunks, **cfg) - - -def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float): - cfg = dict(hidden_size=12288, num_attention_heads=96, - checkpoint=checkpoint, max_position_embeddings=2048, dtype=dtype) - return build_gpt_pipeline(96, num_chunks, **cfg) +plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, +) ``` -## 处理数据集 - -我们在这里提供了一个小型 GPT web-text 数据集。 原始格式是 loose JSON, 我们将保存处理后的数据集。 - +## 创建分布式环境. ```python -class WebtextDataset(Dataset): - def __init__(self, path, seq_len=1024) -> None: - super().__init__() - root = os.path.dirname(path) - encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') - if os.path.isfile(encoded_data_cache_path): - seq_len_, data, attention_mask = torch.load( - encoded_data_cache_path) - if seq_len_ == seq_len: - self.data = data - self.attention_mask = attention_mask - return - raw_data = [] - with open(path) as f: - for line in f.readlines(): - raw_data.append(json.loads(line)['text']) - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - tokenizer.pad_token = tokenizer.unk_token - encoded_data = tokenizer( - raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') - self.data = encoded_data['input_ids'] - self.attention_mask = encoded_data['attention_mask'] - torch.save((seq_len, self.data, self.attention_mask), - encoded_data_cache_path) - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - return { - 'input_ids': self.data[index], - 'attention_mask': self.attention_mask[index] - }, self.data[index] +# Launch ColossalAI +colossalai.launch_from_torch(config={}, seed=42) +coordinator = DistCoordinator() ``` - -## 使用混合并行训练 GPT - -在上一个教程中,我们解释了一些流水并行的参数含义。在本例中,我们可以确定在流水阶段之间交换的每个输出张量的形状。对于 GPT,该形状为 -`(MICRO BATCH SIZE, SEQUENCE LEN, HIDDEN SIZE)`。通过设置该参数,我们可以避免交换每个阶段的张量形状。当你不确定张量的形状时,你可以把它保留为 -`None`, 形状会被自动推测。请确保你的模型的 `dtype` 是正确的:当你使用 `fp16`,模型的 `dtype` 必须是 `torch.half`;否则,`dtype` 必须是 `torch.float`。对于流水并行,仅支持 `AMP_TYPE.NAIVE`。 - -你可以通过在 `CONFIG` 里使用 `parallel` 来轻松使用张量并行。数据并行的大小是根据 GPU 的数量自动设置的。 - +## 定义GPT-2模型的训练组件 +在使用混合并行之前,您需要定义训练所使用的组件。 +定义超参数。 ```python -NUM_EPOCHS = 60 -SEQ_LEN = 1024 -BATCH_SIZE = 192 -NUM_CHUNKS = None -TENSOR_SHAPE = (1, 1024, 1600) -# only pipeline parallel -# CONFIG = dict(NUM_MICRO_BATCHES = 192, parallel=dict(pipeline=2), fp16=dict(mode=AMP_TYPE.NAIVE)) -# pipeline + 1D model parallel -CONFIG = dict(NUM_MICRO_BATCHES = 192, parallel=dict(pipeline=2, tensor=dict(mode='1d', size=2)), fp16=dict(mode=AMP_TYPE.NAIVE)) - - -def train(): - disable_existing_loggers() - parser = colossalai.get_default_parser() - args = parser.parse_args() - colossalai.launch_from_torch(config=CONFIG, backend=args.backend) - logger = get_dist_logger() - - train_ds = WebtextDataset(os.environ['DATA'], seq_len=SEQ_LEN) - train_dataloader = utils.get_dataloader(train_ds, - seed=42, - batch_size=BATCH_SIZE, - pin_memory=True, - shuffle=True, - drop_last=True) - - use_interleaved = NUM_CHUNKS is not None - num_chunks = 1 if not use_interleaved else NUM_CHUNKS - model = GPT2_exlarge_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half) - # model = GPT3_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half) - if use_interleaved and not isinstance(model, nn.ModuleList): - model = nn.ModuleList([model]) - - criterion = GPTLMLoss() - - optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2,) - - engine, train_dataloader, _, _ = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader=train_dataloader) - global_batch_size = BATCH_SIZE * \ - gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) - logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0]) +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 +``` +获取数据集。您可以使用`plugin.prepare_dataloader`生成dataloader,也可以自定义您的dataloader。 +```python +def tokenize_batch(batch, tokenizer: Optional[AutoTokenizer] = None, max_length: int = 2048): + texts = [sample["sentence1"] + sample["sentence2"] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) + data = {k: v.cuda() for k, v in data.items()} + data["labels"] = data["input_ids"].clone() + return data + +tokenizer = AutoTokenizer.from_pretrained("gpt2") +dataset = datasets.load_dataset("glue", "mrpc") +train_dataloader = plugin.prepare_dataloader( + dataset["train"], + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=512), +) +``` +定义GPT-2模型。 +```python +cfg = AutoConfig.from_pretrained("gpt2", num_labels=2) +model = GPT2ForSequenceClassification.from_pretrained("gpt2", config=cfg).cuda() +``` +准备优化器 +```python +lr = LEARNING_RATE * coordinator.world_size +no_decay = ["bias", "LayerNorm.weight"] +optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, +] + +optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) +``` +准备 `lr_scheduler` 和 `criterion`,需要注意的是,当混合并行使用了管道并行时,还需定义`criterion`函数。这个函数应该以模型前后向的输入和输出作为参数,并返回loss。 +```python +# lr scheduler +total_steps = len(train_dataloader) * NUM_EPOCHS +num_warmup_steps = int(WARMUP_FRACTION * total_steps) +lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, +) + +def _criterion(outputs, inputs): + return outputs.loss +``` +## 增强GPT-2模型 +使用 HybridParallelPlugin 定义一个 booster(增强器)。根据设置的插件参数,booster会将一种或者多种并行策略注入到模型中。该例子中使用了管道并行,zero1,及半精度训练等优化。 +```python +booster = Booster(plugin=plugin) +``` +使用定义的 booster 来增强这些组件。 +```python +model, optimizer, _criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler +) +``` - timer = MultiTimer() - trainer = Trainer( - engine=engine, - logger=logger, - timer=timer - ) +## 使用混合并行训练 GPT-2 - hook_list = [ - hooks.LossHook(), - hooks.LogMetricByEpochHook(logger), - hooks.ThroughputHook(), - hooks.LogMetricByStepHook(), - ] +在前面的教程中,我们已经解释了如何使用 Booster 和 HybridParallelPlugin 将各种并行特性注入到模型及其训练组件中。现在我们可以开始模型训练。 +定义一个训练函数。当使用了管道并行时,需要调用`booster.execute_pipeline`进行模型训练的阶段调度。 +```python +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + _criterion: Callable, + lr_scheduler: LRScheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) + total_step = len(train_dataloader) + + model.train() + optimizer.zero_grad() + train_dataloader_iter = iter(train_dataloader) + with tqdm( + range(total_step), + desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", + disable=not print_flag, + ) as pbar: + # Forward pass + for _ in pbar: + if use_pipeline: + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + data = next(train_dataloader_iter) + data = move_to_cuda(data) + outputs = model(**data) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() - trainer.fit( - train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_interval=1, - hooks=hook_list, - display_progress=True, - return_output_label=False, - ) ``` - +训练 GPT-2 模型。 +```python +for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) +``` + \ No newline at end of file diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md deleted file mode 100644 index 5ef863dcd423..000000000000 --- a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md +++ /dev/null @@ -1,247 +0,0 @@ -# 使用流水并行训练 ViT - -作者: Hongxin Liu, Yongbin Li - -**示例代码** -- [ColossalAI-Examples Pipeline Parallel ViT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer/pipeline_parallel) - -**相关论文** -- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) - -## 引言 - -在本教程中,你将学习如何使用流水并行从头开始训练用于图像分类的 Vision Transformer (ViT)。流水并行是一种模型并行,主要针对 GPU 内存不能满足模型容量的情况。 -通过使用流水并行,我们将原始模型分割成多个阶段,每个阶段保留原始模型的一部分。我们假设你的 GPU 内存不能容纳 ViT/L-16,而你的内存可以容纳这个模型。 - -## 目录 - -在本教程中,我们将介绍: - -1. 基于 [TIMM](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) 定义 ViT 模型 -2. 处理数据集 -3. 使用流水并行训练 ViT - -## 导入依赖库 - -```python -import os -from collections import OrderedDict -from functools import partial - -import colossalai -import colossalai.nn as col_nn -import torch -import torch.nn as nn -from colossalai.legacy.builder import build_pipeline_model -from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, - PipelineSchedule) -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.legacy.trainer import Trainer, hooks -from colossalai.utils import MultiTimer, get_dataloader -from timm.models import vision_transformer as vit -from torchvision import transforms -from torchvision.datasets import CIFAR10 -``` - - -## 定义 Vision Transformer 模型 - -总的来说, 我们提供3种方法来建立一个流水并行的模型: - -1. `colossalai.legacy.builder.build_pipeline_model_from_cfg` -2. `colossalai.legacy.builder.build_pipeline_model` -3. 自己按阶段拆分模型 - -当你的内存能够容纳模型时,你可以使用前两种方法来建立你的模型,否则你必须自己分割模型。前两种方法首先在 CPU 上建立整个模型,然后分割模型,最后你可以直接把模型的相应部分移到 GPU 上。 - -`colossalai.legacy.builder.build_pipeline_model_from_cfg()` 接收一个模型的配置文件,它可以均匀地(按层)或平衡地(按参数大小)分割模型。 - -如果你熟悉 `PyTorch`, 你可以使用 `colossalai.legacy.builder.build_pipeline_model()` 它接收一个 `torch.nn.Sequential` 模型并按层均匀分割。 - -在本教程中,我们将修改 [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential`,然后使用 `colossalai.legacy.builder.build_pipeline_model()` 来建立流水线模型。 - -当数据是 **一个** `Tensor`, 你可以使用你的模型 `forward()` 中的位置参数来获得数据张量。对于流水线的第一阶段,`forward()` 的第一个位置参数是从数据加载器加载的数据张量。对于其他阶段,`forward()` 的第一个位置参数是上一阶段的输出张量。注意,如果该阶段不是最后一个阶段,则 `forward()` 的返回必须是一个 `Tensor`。 - -当数据是一个 `Tensor` 的 `dict`, 你可以使用你模型 `forward()` 的命名关键字参数来获得数据的 `dict`。 - -```python -class ViTEmbedding(nn.Module): - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, embed_layer=vit.PatchEmbed, drop_rate=0., distilled=False): - super().__init__() - self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 2 if distilled else 1 - self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) - self.init_weights() - - def forward(self, x): - x = self.patch_embed(x) - cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks - if self.dist_token is None: - x = torch.cat((cls_token, x), dim=1) - else: - x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) - x = self.pos_drop(x + self.pos_embed) - return x - - def init_weights(self): - vit.trunc_normal_(self.pos_embed, std=.02) - if self.dist_token is not None: - vit.trunc_normal_(self.dist_token, std=.02) - vit.trunc_normal_(self.cls_token, std=.02) - self.apply(vit._init_vit_weights) - - -class ViTHead(nn.Module): - def __init__(self, embed_dim=768, num_classes=1000, norm_layer=None, distilled=False, representation_size=None): - super().__init__() - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - self.norm = norm_layer(embed_dim) - self.num_classes = num_classes - self.distilled = distilled - self.num_features = embed_dim - # Representation layer - if representation_size and not distilled: - self.num_features = representation_size - self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(embed_dim, representation_size)), - ('act', nn.Tanh()) - ])) - else: - self.pre_logits = nn.Identity() - # Classifier head(s) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = None - if distilled: - self.head_dist = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() - self.init_weights() - - def forward(self, x): - x = self.norm(x) - if self.distilled: - x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1]) - if self.training and not torch.jit.is_scripting(): - # during inference, return the average of both classifier predictions - return x, x_dist - else: - return (x + x_dist) / 2 - else: - x = self.pre_logits(x[:, 0]) - x = self.head(x) - return x - - def init_weights(self): - self.apply(vit._init_vit_weights) - - -def sequential_vit(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=vit.PatchEmbed, norm_layer=None, - act_layer=None): - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU - embedding = ViTEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans, - embed_dim=embed_dim, embed_layer=embed_layer, drop_rate=drop_rate, distilled=distilled) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - blocks = [vit.Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, - attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) - for i in range(depth)] - for block in blocks: - block.apply(vit._init_vit_weights) - head = ViTHead(embed_dim=embed_dim, num_classes=num_classes, norm_layer=norm_layer, - distilled=distilled, representation_size=representation_size) - return nn.Sequential(embedding, *blocks, head) - - -def vit_large_patch16_224(**kwargs): - model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) - return sequential_vit(**model_kwargs) -``` - -## 处理数据集 - -一般来说, 我们在大型数据集如 ImageNet 上训练 ViT。为了简单期间,我们在这里只使用 CIFAR-10, 因为本教程只是用于流水并行训练。 - -```python -def build_cifar(batch_size): - transform_train = transforms.Compose([ - transforms.RandomCrop(224, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - transform_test = transforms.Compose([ - transforms.Resize(224), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - - train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) - test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) - test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) - return train_dataloader, test_dataloader -``` - -## 使用流水并行训练 ViT - -你可以在配置文件中设置流水并行的大小。`NUM_CHUNKS` 在使用交错流水线时很有用 (更多细节见 [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) )。 -原始 batch 将会被分割为 `num_microbatches`, 每个阶段每次将加载一个 micro batch。如果你确定性地知道每个阶段输出张量的形状,你可以在配置文件中设置 `tensor_shape` 来减少通信。 -我们的仓库会自动为用户生成合适的schedule来支持流水并行训练。如果你不需要模型的输出和标签,你可以在调用 `trainer.fit()` 时,将 `return_output_label` 设置为 `False`,这样能进一步减少 GPU 显存使用。 - -你应当使用 `export DATA=/path/to/cifar`。 - -```python -BATCH_SIZE = 16 -NUM_EPOCHS = 60 -NUM_CHUNKS = 1 -CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2)) - - -def train(): - disable_existing_loggers() - parser = colossalai.get_default_parser() - args = parser.parse_args() - colossalai.launch_from_torch(backend=args.backend, config=CONFIG) - logger = get_dist_logger() - - # build model - model = vit_large_patch16_224() - model = build_pipeline_model(model, num_chunks=NUM_CHUNKS, verbose=True) - - # build criterion - criterion = nn.CrossEntropyLoss() - - # optimizer - optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0) - - # build dataloader - train_dataloader, test_dataloader = build_cifar(BATCH_SIZE) - - engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, optimizer, criterion, - train_dataloader, test_dataloader) - timer = MultiTimer() - - trainer = Trainer(engine=engine, timer=timer, logger=logger) - - hook_list = [ - hooks.LossHook(), - hooks.AccuracyHook(col_nn.metric.Accuracy()), - hooks.LogMetricByEpochHook(logger), - ] - - trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - test_dataloader=test_dataloader, - test_interval=1, - hooks=hook_list, - display_progress=True) -``` - diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md index f7dd8d477a66..3de41601a231 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -1,10 +1,14 @@ # 使用 Colossal-AI (从数据并行到异构并行)加速 ViT 训练详解 -作者:Yuxuan Lou +作者:Yuxuan Lou, Mingyan Jiang + +**前置教程** +- [并行插件](../basics/booster_plugins.md) +- [booster API](../basics/booster_api.md) **示例代码** -- [Colossal-AI Examples ViT on Cifar10](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer) +- [Colossal-AI Examples ViT on `beans`](https://github.com/hpcaitech/ColossalAI/blob/main/examples/images/vit/vit_train_demo.py) **相关文献** - [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf) @@ -12,14 +16,14 @@ ## 引言 -在这个ViT模型的样例中,Colossal-AI 提供了三种不同的并行技术来加速模型训练:数据并行,流水线并行和张量并行。我们将展示如何使用这三种并行技术在 CIFAR-10 数据集上训练 ViT。为了运行项目,需要2-4个 GPU。 +在这个ViT模型的样例中,Colossal-AI 提供了三种不同的并行技术来加速模型训练:数据并行,流水线并行和张量并行。我们将展示如何使用这三种并行技术在 `beans` 数据集上训练 ViT。为了运行项目,需要2-4个 GPU。 ## 目录 1. Colossal-AI 安装方法 -2. 使用数据并行训练 ViT 步骤 -3. 使用数据流水线并行训练 ViT 步骤 -4. 使用张量并行或异构并行训练 ViT 步骤 +2. 定义VIT模型及相关训练组件 +3. 使用使用 [HybridParallelPlugin](../basics/booster_plugins.md) 增强VIT模型 +4. 使用数据并行、流水线并行及张量并行训练VIT模型 ## Colossal-AI 安装 可以通过 Python 的官方索引来安装 Colossal-AI 软件包。 @@ -27,566 +31,255 @@ pip install colossalai ``` - - -## 数据并行 -数据并行是实现加速模型训练的基本方法。通过两步可以实现训练的数据并行: -1. 构建一个配置文件 -2. 在训练脚本中修改很少的几行代码 - -### 构建配置文件 (`data_parallel/config.py`) -为了使用 Colossal-AI,第一步是构建配置文件。并且,在这里有两种变量: - -1. **Colossal-AI 功能配置** - -Colossal-AI 提供了一系列的功能来加快训练速度(包括模型并行,混合精度,零冗余优化器等)。每个功能都是由配置文件中的相应字段定义的。如果我们只用到数据并行,那么我们只需要具体说明并行模式。在本例中,我们使用 PyTorch 最初提出的混合精度训练,只需要定义混合精度配置 `fp16 = dict(mode=AMP_TYPE.TORCH)` 。 - -2. **全局超参数** - -全局超参数包括特定于模型的超参数、训练设置、数据集信息等。 +## 导入依赖库 ```python -from colossalai.amp import AMP_TYPE -# ViT Base -BATCH_SIZE = 256 -DROP_RATE = 0.1 -NUM_EPOCHS = 300 -# mix precision -fp16 = dict( - mode=AMP_TYPE.TORCH, -) -gradient_accumulation = 16 -clip_grad_norm = 1.0 -dali = dict( - gpu_aug=True, - mixup_alpha=0.2 -) -``` +from typing import Any, Callable, Iterator -### 修改训练脚本 (`/data_parallel/train_with_cifar10.py`) +import torch +import torch.distributed as dist +import torch.nn as nn +import transformers +from data import BeansDataset, beans_collator +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor -#### 导入模块 -- Colossal-AI 相关模块 -```python import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.lr_scheduler import LinearWarmupLR -from colossalai.legacy.nn.metric import Accuracy -from colossalai.legacy.trainer import Trainer, hooks +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam ``` - -- 其他模块 +## 定义 Vision Transformer 模型 +定义超参数 ```python -import os -import torch -from timm.models import vit_base_patch16_224 -from torchvision import transforms -from torchvision.datasets import CIFAR10 +SEED = 42 +MODEL_PATH = "google/vit-base-patch16-224" +LEARNING_RATE = 5e-5 +WEIGHT_DECAY = 0.0 +NUM_EPOCH = 3 +WARMUP_RATIO = 0.3 +TP_SIZE = 2 +PP_SIZE = 2 ``` - -#### 启动 Colossal-AI - -在训练脚本中,在构建好配置文件后,需要为 Colossal-AI 初始化分布式环境。我们将此过程称为 `launch` 。在 Colossal-AI 中,我们提供了几种启动方法来初始化分布式后端。在大多数情况下,您可以使用 `colossalai.launch` 和 `colossalai.get_default_parser ` 来实现使用命令行传递参数。此外,Colossal-AI 可以利用 PyTorch 提供的现有启动工具,正如许多用户通过使用熟知的 `colossalai.launch_from_torch` 那样。更多详细信息,您可以查看相关[文档](https://www.colossalai.org/docs/basics/launch_colossalai)。 - - +首先我们创建一个分布式环境 ```python -# initialize distributed setting -parser = colossalai.get_default_parser() -args = parser.parse_args() -colossalai.launch_from_torch(config=args.config) -disable_existing_loggers() -logger = get_dist_logger() +# Launch ColossalAI +colossalai.launch_from_torch(config={}, seed=SEEDå) +coordinator = DistCoordinator() +world_size = coordinator.world_size ``` - -初始化后,您可以使用 `colossalai.core.global_context` 访问配置文件中的变量。 - +在训练之前您可以按照正常流程定义模型训练的相关组,如定义模型,数据加载器,优化器等。需要注意的是,当使用管道并行时,还需定义一个criterion函数,该函数的输入是模型前向的输入和输出,返回的是loss。 +获取数据集, `BeansDataset`定义在[data.py](https://github.com/hpcaitech/ColossalAI/blob/main/examples/images/vit/data.py) ```python -#access parameters -print(gpc.config.BATCH_SIZE) +image_processor = ViTImageProcessor.from_pretrained(MODEL_PATH) +train_dataset = BeansDataset(image_processor, TP_SIZE, split="train") +eval_dataset = BeansDataset(image_processor, RP_SIZE, split="validation") +num_labels = train_dataset.num_labels ``` - -#### 构建模型 - -如果只需要数据并行性,则无需对模型代码进行任何更改。这里,我们使用 `timm` 中的 `vit_base_patch16_224`。 - +定义VIT模型: ```python -# build model -model = vit_base_patch16_224(drop_rate=0.1, num_classes=gpc.config.NUM_CLASSES) +config = ViTConfig.from_pretrained(MODEL_PATH) +config.num_labels = num_labels +config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} +config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} +model = ViTForImageClassification.from_pretrained( + MODEL_PATH, config=config, ignore_mismatched_sizes=True +) ``` - -#### 构建 CIFAR-10 数据加载器 -`colossalai.utils.get_dataloader` 可以帮助您轻松构建数据加载器。 - +定义optimizer: ```python -def build_cifar(batch_size): - transform_train = transforms.Compose([ - transforms.RandomCrop(224, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - transform_test = transforms.Compose([ - transforms.Resize(224), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) - test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) - test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) - return train_dataloader, test_dataloader -# build dataloader -train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE) +optimizer = HybridAdam(model.parameters(), lr=(LEARNING_RATE * world_size), weight_decay=WEIGHT_DECAY) ``` - -#### 定义优化器,损失函数和学习率调度器 - -Colossal-AI 提供了自己的优化器、损失函数和学习率调度器。PyTorch 的这些组件与Colossal-AI也兼容。 - +定义lr scheduler: ```python -# build optimizer -optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1) -# build loss -criterion = torch.nn.CrossEntropyLoss() -# lr_scheduler -lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS) -``` - -#### 启动用于训练的 Colossal-AI 引擎 - -Engine 本质上是对模型、优化器和损失函数的封装类。当我们使用 `colossalai.initialize` ,将返回一个 engine 对象,并且它已经按照配置文件中的指定内容,配置了梯度剪裁、梯度累积和零冗余优化器等功能。之后,基于 Colossal-AI 的 engine 我们可以进行模型训练。 - -```python -engine, train_dataloader, test_dataloader, _ = colossalai.initialize( - model, optimizer, criterion, train_dataloader, test_dataloader +total_steps = len(train_dataloader) * NUM_EPOCH +num_warmup_steps = int(WARMUP_RATIO * total_steps) +lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=(len(train_dataloader) * NUM_EPOCH), warmup_steps=num_warmup_steps ) ``` - -#### 训练:Trainer 应用程序编程接口 -Trainer 是一个更高级的封装类,用户可以用更少的代码就可以实现训练。通过传递 engine 对象很容易创建 trainer 对象。 - -此外,在 trainer 中,用户可以自定义一些挂钩,并将这些挂钩连接到 trainer 对象。钩子对象将根据训练方案定期执行生命周期方法。例如,`LRSchedulerHook` 将执行`lr_scheduler.step()` 在 `after_train_iter` 或 `after_train_epoch` 阶段更新模型的学习速率。 - -```python -# build trainer -trainer = Trainer(engine=engine, logger=logger) -# build hooks -hook_list = [ - hooks.LossHook(), - hooks.AccuracyHook(accuracy_func=MixupAccuracy()), - hooks.LogMetricByEpochHook(logger), - hooks.LRSchedulerHook(lr_scheduler, by_epoch=True), - # comment if you do not need to use the hooks below - hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'), - hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]), -] -``` - -使用 `trainer.fit` 进行训练: - +定义criterion函数: ```python -# start training -trainer.fit( - train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - epochs=gpc.config.NUM_EPOCHS, - hooks=hook_list, - display_progress=True, - test_interval=1 -) +def _criterion(outputs, inputs): + return outputs.loss ``` - -### 开始训练 -`DATA` 是自动下载和存储 CIFAR-10 数据集的文件路径。 - -`` 是要用于使用 CIFAR-10 数据集,以数据并行方式训练 ViT 的 GPU 数。 - -```bash -export DATA= -# If your torch >= 1.10.0 -torchrun --standalone --nproc_per_node train_dp.py --config ./configs/config_data_parallel.py -# If your torch >= 1.9.0 -# python -m torch.distributed.run --standalone --nproc_per_node= train_dp.py --config ./configs/config_data_parallel.py -# Otherwise -# python -m torch.distributed.launch --nproc_per_node --master_addr --master_port 29500 train_dp.py --config ./configs/config.py -``` - - - -## 流水线并行 -除了数据并行性,Colossal-AI 还支持流水线并行。具体而言,Colossal-AI 使用 NVIDIA 引入的 1F1B 流水线。更多详细信息,您可以查看相关[文档](https://www.colossalai.org/tutorials/features/pipeline_parallel)。 - -### 构建配置文件(`hybrid_parallel/configs/vit_pipeline.py`) -要在数据并行的基础上应用流水线并行,只需添加一个 **parallel dict** +## 增强VIT模型 +我们开始使用colossalai的混合并行策略来增强模型,首先我们先定义一个`HybridParallelPlugin`的对象,[`HybridParallelPlugin`](../basics/booster_plugins.md)封装了colossalai的多种并行策略,之后我们使用`HybridParallelPlugin`对象来初始化booster并调用`booster.boost`来增强模型。 +### 半精度训练 +在`HybridParallelPlugin`插件中,通过设置`precision`确定训练精度,可支持'fp16','bf16','fp32'三种类型。'fp16','bf16'为半精度类型,半精度在`HybridParallelPlugin`中有两种应用场景,一是使用zero数据并行时,需设置为半精度;二是指定使用amp半精度进行训练。 + +使用amp半精度时,可设置相关参数。 +`initial_scale`(浮点数,可选项):AMP的初始损失缩放比例。默认值为2**16。 +`min_scale`(浮点数,可选项):AMP的最小损失缩放比例。默认值为1。 +`growth_factor`(浮点数,可选项):在使用AMP时,用于增加损失缩放比例的乘法因子。默认值为2。 +`backoff_factor`(浮点数,可选项):在使用AMP时,用于减少损失缩放比例的乘法因子。默认值为0.5。 +`growth_interval`(整数,可选项):在使用AMP时,当没有溢出时增加损失缩放比例的步数。默认值为1000。 +`hysteresis`(整数,可选项):在使用AMP时,减少损失缩放比例之前的溢出次数。默认值为2。 +`max_scale`(浮点数,可选项):AMP的最大损失缩放比例。默认值为2**32。 + +使用AMP的plugin示例: ```python -from colossalai.amp import AMP_TYPE -parallel = dict( - pipeline=2 -) -# pipeline config -NUM_MICRO_BATCHES = parallel['pipeline'] -TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE) -fp16 = dict(mode=AMP_TYPE.NAIVE) -clip_grad_norm = 1.0 +plugin = HybridParallelPlugin( + precision="fp16", + initial_scale=1, + ) ``` -其他配置: -```python -# hyperparameters -# BATCH_SIZE is as per GPU -# global batch size = BATCH_SIZE x data parallel size -BATCH_SIZE = 256 -LEARNING_RATE = 3e-3 -WEIGHT_DECAY = 0.3 -NUM_EPOCHS = 300 -WARMUP_EPOCHS = 32 -# model config -IMG_SIZE = 224 -PATCH_SIZE = 16 -HIDDEN_SIZE = 768 -DEPTH = 12 -NUM_HEADS = 12 -MLP_RATIO = 4 -NUM_CLASSES = 10 -CHECKPOINT = True -SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token -``` +### 张量并行 +`HybridParallelPlugin`是通过shardformer实现张量并行,在该插件中,可设置`tp_size`确定张量并行组的大小,此外,还有多个参数可设置张量并行时的优化特性: -### 构建流水线模型 (`/hybrid_parallel/model/vit.py`) -Colossal-AI 提供了两种从现有模型构建流水线模型的方法。 -- `colossalai.legacy.builder.build_pipeline_model_from_cfg` -- `colossalai.legacy.builder.build_pipeline_model` +`enable_all_optimization`(布尔类型,可选项):是否启用Shardformer支持的所有优化方法,目前所有优化方法包括融合归一化、flash attention和JIT。默认为False。 +`enable_fused_normalization`(布尔类型,可选项):是否在Shardformer中启用融合归一化。默认为False。 +`enable_flash_attention`(布尔类型,可选项):是否在Shardformer中启用flash attention。默认为False。 +`enable_jit_fused`(布尔类型,可选项):是否在Shardformer中启用JIT。默认为False。 +`enable_sequence_parallelism`(布尔类型):是否在Shardformer中启用序列并行性。默认为False。 +`enable_sequence_overlap`(布尔类型):是否在Shardformer中启用序列重叠性。默认为False。 -此外,您还可以使用 Colossal-AI 从头开始构建流水线模型。 +张量并行的plugin示例 ```python -import math -from typing import Callable -import inspect -import torch -from colossalai import nn as col_nn -from colossalai.legacy.registry import LAYERS, MODELS -from colossalai.logging import get_dist_logger -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.legacy.builder.pipeline import partition_uniform -from torch import dtype, nn -from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead -@MODELS.register_module -class PipelineVisionTransformer(nn.Module): - def __init__(self, - img_size: int = 224, - patch_size: int = 16, - in_chans: int = 3, - num_classes: int = 1000, - depth: int = 12, - num_heads: int = 12, - dim: int = 768, - mlp_ratio: int = 4, - attention_dropout: float = 0., - dropout: float = 0.1, - drop_path: float = 0., - layernorm_epsilon: float = 1e-6, - activation: Callable = nn.functional.gelu, - representation_size: int = None, - dtype: dtype = None, - bias: bool = True, - checkpoint: bool = False, - init_method: str = 'torch', - first_stage=True, - last_stage=True, - start_idx=None, - end_idx=None,): - super().__init__() - layers = [] - if first_stage: - embed = ViTEmbedding(img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embedding_dim=dim, - dropout=dropout, - dtype=dtype, - init_method=init_method) - layers.append(embed) - # stochastic depth decay rule - dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] - if start_idx is None and end_idx is None: - start_idx = 0 - end_idx = depth - blocks = [ - ViTBlock( - dim=dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - attention_dropout=attention_dropout, - dropout=dropout, - drop_path=dpr[i], - activation=activation, - dtype=dtype, - bias=bias, - checkpoint=checkpoint, - init_method=init_method, - ) for i in range(start_idx, end_idx) - ] - layers.extend(blocks) - if last_stage: - norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - head = ViTHead(dim=dim, - num_classes=num_classes, - representation_size=representation_size, - dtype=dtype, - bias=bias, - init_method=init_method) - layers.extend([norm, head]) - self.layers = nn.Sequential( - *layers +plugin = HybridParallelPlugin( + tp_size=4, + enable_all_optimization=True ) - def forward(self, x): - x = self.layers(x) - return x -def _filter_kwargs(func, kwargs): - sig = inspect.signature(func) - return {k: v for k, v in kwargs.items() if k in sig.parameters} -def _build_pipeline_vit(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs): - logger = get_dist_logger() - if gpc.is_initialized(ParallelMode.PIPELINE): - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - else: - pipeline_size = 1 - pipeline_rank = 0 - rank = gpc.get_global_rank() - parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] - models = [] - for start, end in parts: - kwargs['first_stage'] = start == 0 - kwargs['last_stage'] = end == num_layers - kwargs['start_idx'] = start - kwargs['end_idx'] = end - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') - chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device) - models.append(chunk) - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - return model -def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): - return _build_pipeline_vit(PipelineVisionTransformer, num_layers, num_chunks, device, **kwargs) ``` - -### 修改训练脚本 (`/hybrid_parallel/train_with_cifar10.py`) - -#### 导入模块 +### 流水线并行 +`HybridParallelPlugin`通过设置`pp_size`确定流水线并行组的大小,`num_microbatches`设置流水线并行时将整个batch划分为小batch的数量,`microbatch_size`可设置小batch的大小,插件会优先使用`num_microbatches`来确定micro batch的配置。 +流水线并行的plugin示例 ```python -from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule, - PipelineSchedule) -from colossalai.utils import MultiTimer -import os -import colossalai -import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import is_using_pp, get_dataloader -from model.vit import build_pipeline_vit -from model_zoo.vit.vit import _create_vit_model -from tqdm import tqdm -from torchvision import transforms -from torchvision.datasets import CIFAR10 +plugin = HybridParallelPlugin( + pp_size=4, + num_microbatches=None, + microbatch_size=1 + ) ``` - -#### 启动 Colossal-AI -`colossalai.utils.is_using_pp` 可以帮您检查配置文件是否满足流水线并行的要求。 - +### 数据并行 +`HybridParallelPlugin`插件的数据并行包括zero-dp系列及torch DDP。当`zero_stage`为0(默认值)时表示使用torch DDP,注意torch DDP与流水线并行有冲突,不能一起使用。`zero_stage`为1时表示使用zero1策略。`zero_stage`为2使用zero2,zero2策略也无法与流水线并行一起使用。如果想使用zero3,请使用[`GeminiPlugin`](../basics/booster_plugins.md)。使用zero系列的数据并行,请设置训练精度为半精度。当未指定使用zero及流水线并行,且world_size//(tp_size*pp_size)大于1时,`HybridParallelPlugin`会为您打开torch DDP并行策略。 +torch DDP相关参数设置: +`broadcast_buffers`(布尔值,可选项):在使用DDP时,在训练开始时是否广播缓冲区。默认为True。 +`ddp_bucket_cap_mb`(整数,可选项):在使用DDP时的桶大小(以MB为单位)。默认为25。 +`find_unused_parameters`(布尔值,可选项):在使用DDP时是否查找未使用的参数。默认为False。 +`check_reduction(布尔值,可选项):在使用DDP时是否检查减少。默认为False。 +`gradient_as_bucket_view`(布尔值,可选项):在使用DDP时是否将梯度作为桶视图使用。默认为False。 +`static_graph`(布尔值,可选项):在使用DDP时是否使用静态图。默认为False。 + +Torch DDP的plugin示例 ```python -# initialize distributed setting -parser = colossalai.get_default_parser() -args = parser.parse_args() -# launch from torch -colossalai.launch_from_torch(config=args.config) -# get logger -logger = get_dist_logger() -logger.info("initialized distributed environment", ranks=[0]) -if hasattr(gpc.config, 'LOG_PATH'): - if gpc.get_global_rank() == 0: - log_path = gpc.config.LOG_PATH - if not os.path.exists(log_path): - os.mkdir(log_path) - logger.log_to_file(log_path) -use_pipeline = is_using_pp() +plugin = HybridParallelPlugin( + tp_size=2, + pp_size=1, + zero_stage=0, + precision="fp16", + initial_scale=1, + ) ``` +若并行进程为4,则torch DDP的并行组大小为2. +zero相关参数设置: +`zero_bucket_size_in_m`(整数,可选项):在使用ZeRO时,以百万元素为单位的梯度减小桶大小。默认为12。 +`cpu_offload`(布尔值,可选项):在使用ZeRO时是否打开`cpu_offload`。默认为False。 +`communication_dtype`(torch数据类型,可选项):在使用ZeRO时的通信数据类型。如果未指定,则将使用参数的数据类型。默认为None。 +`overlap_communication`(布尔值,可选项):在使用ZeRO时是否重叠通信和计算。默认为True。 -#### 定义模型 +zero1的plugin示例 ```python -# create model -model_kwargs = dict(img_size=gpc.config.IMG_SIZE, - patch_size=gpc.config.PATCH_SIZE, - dim=gpc.config.HIDDEN_SIZE, - depth=gpc.config.DEPTH, - num_heads=gpc.config.NUM_HEADS, - mlp_ratio=gpc.config.MLP_RATIO, - num_classes=gpc.config.NUM_CLASSES, - init_method='jax', - checkpoint=gpc.config.CHECKPOINT) -if use_pipeline: - model = build_pipeline_vit(num_layers=model_kwargs['depth'], num_chunks=1, **model_kwargs) -else: - model = _create_vit_model(**model_kwargs) -``` - -#### 计算参数个数 - -您可以轻松计算不同流水线阶段上的模型参数个数。 - -``` -# count number of parameters -total_numel = 0 -for p in model.parameters(): - total_numel += p.numel() -if not gpc.is_initialized(ParallelMode.PIPELINE): - pipeline_stage = 0 -else: - pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) -logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") +plugin = HybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=1, + cpu_offload=True, + precision="fp16", + initial_scale=1, + ) ``` -#### 构建数据加载器,优化器等组件 +### 混合并行 +可参考上述的策略自定义合适的混合并行策略。定义混合并行的插件,并使用该插件定义一个booster: ```python -def build_cifar(batch_size): - transform_train = transforms.Compose([ - transforms.RandomCrop(224, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - transform_test = transforms.Compose([ - transforms.Resize(224), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train) - test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True) - test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True) - return train_dataloader, test_dataloader - - -# create dataloaders -train_dataloader , test_dataloader = build_cifar() -# create loss function -criterion = CrossEntropyLoss(label_smoothing=0.1) -# create optimizer -optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) -# create lr scheduler -lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=gpc.config.NUM_EPOCHS, - warmup_steps=gpc.config.WARMUP_EPOCHS) +plugin = HybridParallelPlugin( + tp_size=TP_SIZE, + pp_size=PP_SIZE, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision="fp16", + initial_scale=1, + ) +booster = Booster(plugin=plugin) ``` - -#### 启动 Colossal-AI 引擎 - +接着我们使用`booster.boost`来将plugin所封装的特性注入到模型训练组件中。 ```python -# initialize -engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader) -logger.info("Engine is built", ranks=[0]) +model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost( + model=model, optimizer=optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) ``` - -#### 训练:基于engine - -在数据并行示例中,我们展示了如何使用 Trainer API 训练模型。我们还可以直接训练基于 engine 的模型。通过这种方式,您可以使用更多功能自定义训练方法。 - +## 使用混合并行训练 ViT +最后就可以使用混合并行策略来训练模型了。我们先定义一个训练函数,描述训练过程。需要注意的是,如果使用了管道并行策略,需要调用`booster.execute_pipeline`来执行模型的训练,它会调用`scheduler`管理模型的前后向操作。 ```python -data_iter = iter(train_dataloader) -for epoch in range(gpc.config.NUM_EPOCHS): - # training - engine.train() - if gpc.get_global_rank() == 0: - description = 'Epoch {} / {}'.format( - epoch, - gpc.config.NUM_EPOCHS +def run_forward_backward( + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + data_iter: Iterator, + booster: Booster, +): + if optimizer is not None: + optimizer.zero_grad() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # run pipeline forward backward when enabling pp in hybrid parallel plugin + output_dict = booster.execute_pipeline( + data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True ) - progress = tqdm(range(len(train_dataloader)), desc=description) + loss, outputs = output_dict["loss"], output_dict["outputs"] else: - progress = range(len(train_dataloader)) - for _ in progress: - engine.zero_grad() - engine.execute_schedule(data_iter, return_output_label=False) - engine.step() - lr_scheduler.step() -``` - -### 开始训练 -```bash -export DATA= -# If your torch >= 1.10.0 -torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_pipeline_parallel.py -# If your torch >= 1.9.0 -# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_pipeline_parallel.py + batch = next(data_iter) + batch = move_to_cuda(batch, torch.cuda.current_device()) + outputs = model(**batch) + loss = criterion(outputs, None) + if optimizer is not None: + booster.backward(loss, optimizer) + +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + lr_scheduler: LRScheduler, + dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): + torch.cuda.synchronize() + + num_steps = len(dataloader) + data_iter = iter(dataloader) + enable_pbar = coordinator.is_master() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar + tp_rank = dist.get_rank(booster.plugin.tp_group) + dp_rank = dist.get_rank(booster.plugin.dp_group) + enable_pbar = tp_rank == 0 and dp_rank == 0 and booster.plugin.stage_manager.is_last_stage() + model.train() + + with tqdm(range(num_steps), desc=f"Epoch [{epoch + 1}]", disable=not enable_pbar) as pbar: + for _ in pbar: + loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) + optimizer.step() + lr_scheduler.step() + + # Print batch loss + if enable_pbar: + pbar.set_postfix({"loss": loss.item()}) ``` - - - - -## 张量并行和异构并行 -张量并行将每个权重参数跨多个设备进行分区,以减少内存负载。Colossal-AI 支持 1D、2D、2.5D 和 3D 张量并行。此外,还可以将张量并行、流水线并行和数据并行结合起来,实现混合并行。Colossal-AI 还提供了一种简单的方法来应用张量并行和混合并行。只需在配置文件中更改几行代码即可实现流水线并行。 - -### 构造您的配置文件 (`/hybrid_parallel/configs/vit_1d_tp2_pp2.py`) -使用张量并行,只需将相关信息添加到 **parallel dict**。具体而言,`TENSOR_PARALLEL_MODE` 可以是“1d”、“2d”、“2.5d”、“3d”。不同并行度的大小应满足:`#GPUs = pipeline parallel size x tensor parallel size x data parallel size`。在指定 GPU 数量、流水线并行大小和张量并行大小后 `data parallel size` 会自动计算。 - -```python -from colossalai.amp import AMP_TYPE -# parallel setting -TENSOR_PARALLEL_SIZE = 2 -TENSOR_PARALLEL_MODE = '1d' -parallel = dict( - pipeline=2, - tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE) -) -fp16 = dict(mode=AMP_TYPE.NAIVE) -clip_grad_norm = 1.0 -# pipeline config -NUM_MICRO_BATCHES = parallel['pipeline'] -TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE) -``` - -其他配置: +开始训练模型 ```python -# hyperparameters -# BATCH_SIZE is as per GPU -# global batch size = BATCH_SIZE x data parallel size -BATCH_SIZE = 256 -LEARNING_RATE = 3e-3 -WEIGHT_DECAY = 0.3 -NUM_EPOCHS = 300 -WARMUP_EPOCHS = 32 -# model config -IMG_SIZE = 224 -PATCH_SIZE = 16 -HIDDEN_SIZE = 768 -DEPTH = 12 -NUM_HEADS = 12 -MLP_RATIO = 4 -NUM_CLASSES = 10 -CHECKPOINT = True -SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token -``` - -### 开始训练 -```bash -export DATA= -# If your torch >= 1.10.0 -torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_hybrid_parallel.py -# If your torch >= 1.9.0 -# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py +for epoch in range(NUM_EPOCH): + train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` diff --git a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md index 3ad9b2e07a95..824308f94654 100644 --- a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md +++ b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md @@ -1,6 +1,6 @@ # 梯度累积 -作者: [Mingyan Jiang](https://github.com/jiangmingyan) +作者: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003) **前置教程** - [训练中使用Booster](../basics/booster_api.md) @@ -93,6 +93,7 @@ model, optimizer, criterion, train_dataloader, _ = booster.boost(model=model, dataloader=train_dataloader) ``` + ### 步骤 5. 使用booster训练 使用booster构建一个普通的训练循环,验证梯度累积。 `param_by_iter` 记录分布训练的信息。 ```python @@ -144,4 +145,29 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0 iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) ``` +## 在Gemini插件中使用梯度累积 + +目前支持`no_sync()`方法的插件包括 `TorchDDPPlugin` 和 `LowLevelZeroPlugin`(需要设置参数`stage`为1). `GeminiPlugin` 不支持 `no_sync()` 方法, 但是它可以通过和`pytorch`类似的方式来使用同步的梯度累积。 + +为了开启梯度累积功能,在初始化`GeminiPlugin`的时候需要将参数`enable_gradient_accumulation`设置为`True`。以下是 `GeminiPlugin` 进行梯度累积的伪代码片段: + +```python +... +plugin = GeminiPlugin(..., enable_gradient_accumulation=True) +booster = Booster(plugin=plugin) +... + +... +for idx, (input, label) in enumerate(train_dataloader): + output = gemini_model(input.cuda()) + train_loss = criterion(output, label.cuda()) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, gemini_optimizer) + + if idx % (GRADIENT_ACCUMULATION - 1) == 0: + gemini_optimizer.step() # zero_grad is automatically done +... +``` + + diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 90d49f6a264a..0ca1953c6a41 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -3,7 +3,6 @@ import time import torch -from torch.profiler import ProfilerActivity, profile, record_function from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai @@ -16,6 +15,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3): + torch.cuda.empty_cache() # trim warmup queries latency_set = list(latency_set) latency_set = latency_set[warmup:] @@ -38,24 +38,29 @@ def run_llama_test(args): max_batch_size = args.batch_size max_input_len = args.input_len max_output_len = args.output_len + args.test_mode + + print("max_batch_size : " + str(max_batch_size)) tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) model = model.half() - model_config = model.config + model.config shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) - generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + generate_kwargs = dict(max_new_tokens=1, do_sample=False) input_tokens = { "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } iters = 10 - times = [] + prefill_times = [] + + warmup = 3 for i in range(iters): torch.cuda.synchronize() @@ -65,17 +70,39 @@ def run_llama_test(args): end = time.time() out_len = outputs.shape[1] print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) + prefill_times.append((end - start) / (out_len - max_input_len)) + + prefill_times = prefill_times[warmup:] + prefill_time_avg = sum(prefill_times) / len(prefill_times) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + times = [] + decoder_times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) times.append((end - start) / (out_len - max_input_len)) + if args.test_mode == "decoder_test": + decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1)) + + times = times[warmup:] + latency = sum(times) / len(times) + print("total process latency is : " + str(latency) + " s") + print("total throughput is : " + str(1 / latency * max_batch_size)) - print("outputs, ", len(outputs)) - print_perf_stats(times, model_config, max_batch_size) + if args.test_mode == "decoder_test": + decoder_times = decoder_times[warmup:] + latency = sum(decoder_times) / len(decoder_times) - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("model_inference"): - torch.cuda.synchronize() - outputs = infer_engine.generate(input_tokens, **generate_kwargs) - torch.cuda.synchronize() - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print("decoder process latency is : " + str(latency) + " s") + print("decoder throughput is : " + str(1 / latency * max_batch_size)) def check_llama(rank, world_size, port, args): @@ -95,8 +122,11 @@ def test_llama(args): parser.add_argument("-p", "--path", type=str, help="Model path", required=True) parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") - parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--input_len", type=int, default=256, help="Maximum input length") parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + parser.add_argument( + "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] + ) args = parser.parse_args() diff --git a/examples/inference/smoothquant_llama.py b/examples/inference/smoothquant_llama.py new file mode 100644 index 000000000000..ce7a00aa2739 --- /dev/null +++ b/examples/inference/smoothquant_llama.py @@ -0,0 +1,69 @@ +import argparse +import os + +import torch +from datasets import load_dataset +from transformers import LlamaTokenizer + +from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM + + +def build_model_and_tokenizer(model_name): + tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512) + kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"} + model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs) + model = model.to(torch.float32) + return model, tokenizer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, help="model name") + parser.add_argument( + "--output-path", + type=str, + help="where to save the checkpoint", + ) + parser.add_argument( + "--dataset-path", + type=str, + help="location of the calibration dataset", + ) + parser.add_argument("--num-samples", type=int, default=512) + parser.add_argument("--seq-len", type=int, default=512) + args = parser.parse_args() + return args + + +@torch.no_grad() +def main(): + args = parse_args() + model_path = args.model_name + dataset_path = args.dataset_path + output_path = args.output_path + num_samples = 10 + seq_len = 512 + + model, tokenizer = build_model_and_tokenizer(model_path) + if not os.path.exists(dataset_path): + print(f"Cannot find the dataset at {args.dataset_path}") + raise FileNotFoundError + dataset = load_dataset("json", data_files=dataset_path, split="train") + + model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len) + model = model.cuda() + + model.save_quantized(output_path, model_basename="llama-7b") + + model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b") + model = model.cuda() + + generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True) + input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda") + out = model.generate(**input_tokens, **generate_kwargs) + text = tokenizer.batch_decode(out) + print("out is:", text) + + +if __name__ == "__main__": + main() diff --git a/op_builder/smoothquant.py b/op_builder/smoothquant.py new file mode 100644 index 000000000000..d562a4c4f626 --- /dev/null +++ b/op_builder/smoothquant.py @@ -0,0 +1,52 @@ +import torch + +from .builder import Builder +from .utils import append_nvcc_threads + + +class SmoothquantBuilder(Builder): + NAME = "cu_smoothquant" + PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant" + + def __init__(self): + super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "smoothquant/binding.cpp", + "smoothquant/linear.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + compute_capability = torch.cuda.get_device_capability() + cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 + + extra_cuda_flags = [ + "-v", + f"-DCUDA_ARCH={cuda_arch}", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) + + def builder(self): + try: + super().builder() + except: + warnings.warn("build smoothquant lib not successful") diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 8a4b0f1a0ffd..095617d76355 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -13,3 +13,6 @@ safetensors einops pydantic ray +sentencepiece +google +protobuf diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py index f0061ad18c84..9f0eef75ae93 100644 --- a/tests/components_to_test/bert.py +++ b/tests/components_to_test/bert.py @@ -52,7 +52,6 @@ def bert_model_builder(checkpoint: bool = False): hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, ) - print("building BertForSequenceClassification model") # adapting huggingface BertForSequenceClassification for single unittest calling interface class ModelAdaptor(BertForSequenceClassification): diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 9cc12f96bd4d..104ca254c572 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -14,6 +14,8 @@ _AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"] # These models have no parameters _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"] +# These models will cause stuck, to be fixed +_STUCK_MODELS = ["transformers_albert_for_multiple_choice"] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: @@ -53,7 +55,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): """ passed_models = [] failed_info = {} # (model_name, error) pair - ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS skipped_models = [] for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 634e81bb225d..f876040384b3 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -58,9 +58,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b dist.barrier() new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) - check_state_dict_equal( - bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False - ) + check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False) @clear_cache_before_run() @@ -100,7 +98,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha dist.barrier() booster.load_model(new_model, model_ckpt_path) - check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False) + check_state_dict_equal( + model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True + ) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal( @@ -136,7 +136,7 @@ def exam_lazy_from_pretrained(): booster.save_model(model, save_path, shard=False) dist.barrier() state_dict = torch.load(save_path, map_location="cpu") - check_state_dict_equal(state_dict, orig_state_dict, False) + check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True) def run_dist(rank, world_size, port): diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index d46e5380d944..bb7a60035e02 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -60,9 +60,10 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): # Add prefix to get aligned with pytorch parameter names. check_state_dict_equal( - model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + model.state_dict(only_rank_0=False, prefix="module.module."), new_model.state_dict(), False, + ignore_dtype=True, ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) @@ -125,9 +126,10 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): # Add prefix to get aligned with pytorch parameter names. check_state_dict_equal( - new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + new_model.state_dict(only_rank_0=False, prefix="module.module."), model.state_dict(), False, + ignore_dtype=True, ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 2a046a298dd7..8431036df6b7 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -6,6 +6,7 @@ from torchvision.models import resnet18 from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.testing import check_state_dict_equal, clear_cache_before_run, parameterize # ======== @@ -22,6 +23,7 @@ def test_unsharded_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) + lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10) # create test data sample x = torch.randn(1, 3, 224, 224) @@ -31,6 +33,7 @@ def test_unsharded_checkpoint(use_safetensors: bool): loss = y.sum() loss.backward() optimizer.step() + lr_scheduler.step() # create a temp file for checkpoint if use_safetensors: @@ -39,19 +42,23 @@ def test_unsharded_checkpoint(use_safetensors: bool): suffix = ".bin" model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix) optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() + lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile() - # save the model and optimizer + # save the model, optimizer, lr_scheduler ckpt_io = GeneralCheckpointIO() ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors) ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name) + ckpt_io.save_lr_scheduler(lr_scheduler, lr_scheduler_ckpt_tempfile.name) # create new model new_model = resnet18() new_optimizer = Adam(new_model.parameters(), lr=0.001) + new_lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10) - # load the model and optimizer + # load the model, optimizer, lr_scheduler ckpt_io.load_model(new_model, model_ckpt_tempfile.name) ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) + ckpt_io.load_lr_scheduler(new_lr_scheduler, lr_scheduler_ckpt_tempfile.name) # check for model and optimizer state dict recursively check_state_dict_equal(model.state_dict(), new_model.state_dict()) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 8a4724c8a82c..e7f44f97e3cf 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -66,6 +66,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) + torch.cuda.empty_cache() def run_dist(rank, world_size, port): diff --git a/tests/test_infer/test_llama2_infer.py b/tests/test_infer/test_llama2_infer.py new file mode 100644 index 000000000000..0eebed8892ea --- /dev/null +++ b/tests/test_infer/test_llama2_infer.py @@ -0,0 +1,69 @@ +import os + +import pytest +import torch +from packaging import version +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" +TPSIZE = 2 +BATCH_SIZE = 8 +MAX_INPUT_LEN = 12 +MAX_OUTPUT_LEN = 100 + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +@parameterize( + "test_config", + [ + { + "tp_size": TPSIZE, + } + ], +) +def run_llama_test(test_config): + llama_config = LlamaConfig( + num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024 + ) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + } + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + + assert outputs is not None + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_llama_test() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, TPSIZE) + + +if __name__ == "__main__": + test_llama() diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py new file mode 100644 index 000000000000..ad8e32b48bae --- /dev/null +++ b/tests/test_infer/test_pipeline_infer.py @@ -0,0 +1,62 @@ +import pytest +import torch +import torch.distributed as dist +import transformers + +import colossalai +from colossalai.inference.pipeline.engine import PPInferEngine +from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + + +def data_gen(): + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +inputs = data_gen() +for k, v in inputs.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 16 + inputs[k] = v.to("cuda").repeat(*new_shape) + + +def pipeline_inference_test(pp_size, new_length, micro_batch_size): + model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) + engine = PPInferEngine( + pp_size=pp_size, + model=model, + model_policy=GPT2LMHeadModelPipelinePolicy(), + new_length=new_length, + micro_batch_size=micro_batch_size, + ) + output = engine.inference([inputs]) + if dist.get_rank() == 0: + assert len(output[0]) == new_length, f"{len(output)}, {new_length}" + + +@parameterize("pp_size", [4]) +@parameterize("new_length", [4, 8, 16]) +@parameterize("micro_batch_size", [1, 4]) +@clear_cache_before_run() +def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): + pipeline_inference_test(pp_size, new_length, micro_batch_size) + torch.cuda.empty_cache() + + +def check_pipeline_inference(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_pipeline_inference_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_pipeline_inference(): + spawn(check_pipeline_inference, nprocs=4) + + +if __name__ == "__main__": + test_pipeline_inference() diff --git a/tests/test_infer_ops/triton/test_llama2_token_attn.py b/tests/test_infer_ops/triton/test_llama2_token_attn.py deleted file mode 100644 index 0537a3d76129..000000000000 --- a/tests/test_infer_ops/triton/test_llama2_token_attn.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, 1, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - - logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5) - prob = torch.softmax(logics, dim=1) - prob = prob.view(bs, seqlen, num_head, 1) - - return torch.sum(prob * xv, dim=1, keepdim=False) - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test(): - Z, head_num, seq_len, head_dim = 2, 32, 2048, 128 - dtype = torch.float16 - - # attn out: 2,4096 - q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda") - max_kv_cache_len = seq_len - kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - other_kv_index = 2048 - - kv_cache_seq_len[:] = seq_len - kv_cache_start_loc[0] = 0 - kv_cache_start_loc[1] = seq_len - - for i in range(Z): - kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") - - Llama2TokenAttentionForwards.token_attn( - q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index - ) - torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) - assert torch.allclose(torch_out, o, atol=1e-3, rtol=0) - - -if __name__ == "__main__": - test() diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py deleted file mode 100644 index fc5f8cd6c9dc..000000000000 --- a/tests/test_infer_ops/triton/test_token_attn_1.py +++ /dev/null @@ -1,74 +0,0 @@ -import math - -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, 1, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - keys = xk - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - scores = ( - (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) - ) - return scores - - -def torch_attn_1(xq, xk, seqlen, num_head, head_dim): - xq = xq.view(1, num_head, head_dim) - xk = xk.view(seqlen, num_head, head_dim) - logics = torch.sum(xq * xk, dim=-1, keepdim=False) - - logics = logics.transpose(0, 1) / math.sqrt(head_dim) - return logics - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_attn_1(): - pass - - batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 - - dtype = torch.float16 - - q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") - - b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - - token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - - torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze() - o = attn_out.squeeze() - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_attn_1() diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py deleted file mode 100644 index 2dd756f2ba91..000000000000 --- a/tests/test_infer_ops/triton/test_token_attn_2.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_attn(V, P, bs, seqlen, num_head, head_dim): - V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) - P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) - attn_out = torch.matmul(P, V) - - return attn_out - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_token_attn_2(): - pass - - batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 - dtype = torch.float16 - - V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) - Prob = ( - torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") - .normal_(mean=0.4, std=0.2) - .reshape(head_num, batch_size, seq_len) - .softmax(-1) - .reshape(head_num, batch_size * seq_len) - ) - attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda") - - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - - token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - - torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze() - o = attn_out - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_token_attn_2() diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index 9c7a53798317..a7fc3d29b77a 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -3,16 +3,13 @@ from packaging import version try: - pass - from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd - HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6") def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 8131ea3234d8..6bbe3e4e8172 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -13,9 +13,7 @@ _FUSED_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), (torch.float, torch.float), - (torch.half, torch.float), (torch.half, torch.half), - (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), (torch.bfloat16, torch.bfloat16), ] @@ -23,7 +21,6 @@ _CPU_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), (torch.float, torch.float), - (torch.half, torch.float), (torch.half, torch.half), ] @@ -138,8 +135,8 @@ def check_adam_kernel( master_exp_avg_sq = torch.zeros_like(master_p) p = master_p.clone().to(p_dtype) g = master_g.clone().to(g_dtype) - exp_avg = master_exp_avg.clone() - exp_avg_sq = master_exp_avg_sq.clone() + exp_avg = master_exp_avg.clone().to(p_dtype) + exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype) for step in range(1, 1 + n_steps): torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq) diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py index 59b40a0afa3c..68d71e3c4194 100644 --- a/tests/test_optimizer/test_adam_optim.py +++ b/tests/test_optimizer/test_adam_optim.py @@ -21,8 +21,6 @@ (torch.float, torch.float), # pure fp32 (torch.float, torch.half), # fp16 amp (torch.float, torch.bfloat16), # bfloat16 amp - # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16 - # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16 ] N_STEPS = 3 diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py new file mode 100644 index 000000000000..0192afc99ae4 --- /dev/null +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py @@ -0,0 +1,258 @@ +import pytest +import torch +from torch.nn.utils.clip_grad import clip_grad_norm_ + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + + col_layer_for_check = ["encoder.layer[0].output.dense"] + row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] + + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + + # Check grads + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + col_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + check_all_grad_tensors(grads_to_check) + + # Check gradient norm + # Convert the gradient data of the working parameter to float and assign it to the master parameter's gradient + # Note that this operation should have been done in the 'step' function, but it is performed here in advance for gradient norm calculation purposes. + # Although it will be done again in the 'step' function, it does not affect correctness. + for group in sharded_optimizer.optim.param_groups: + for p in group["params"]: + working_param = sharded_optimizer.master_to_working_map[p] + if p is working_param: + continue + if working_param.grad is not None: + p.grad = working_param.grad.data.float() + working_param.grad = None + # Create a list of parameter-gradient pairs containing working parameters and their gradients + param_gradient_pairs = [ + (sharded_optimizer.master_to_working_map[p], p.grad) + for group in sharded_optimizer.param_groups + for p in group["params"] + if p.grad is not None + ] + + origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"]) + # Calculate the gradient norm of the sharded optimizer + device = origin_norm.device + hybrid_norm = torch.tensor(sharded_optimizer._compute_grad_norm(param_gradient_pairs)).to(device) + + # If using fp16 precision, divide by the initial scale + if test_config["precision"] == "fp16": + hybrid_norm /= test_config["initial_scale"] + + # Assert that the gradient norm of the original model is close to the gradient norm of the hybrid model + assert torch.allclose( + origin_norm, hybrid_norm, atol=atol, rtol=rtol + ), f"Original model grad norm is not equal to sharded model grad norm\n{origin_norm}\n{hybrid_norm}" + + # Optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # Check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + if org_model.__class__.__name__ == "BertModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # Check weights + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + ], +) +def run_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + ], +) +def run_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_grad_clip_norm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_test() + + +def check_grad_clip_norm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm(): + spawn(check_grad_clip_norm, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm_3d(): + spawn(check_grad_clip_norm_3d, 8) + + +if __name__ == "__main__": + test_grad_clip_norm() + test_grad_clip_norm_3d() diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py new file mode 100644 index 000000000000..da298f5c0be1 --- /dev/null +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py @@ -0,0 +1,197 @@ +import pytest +import torch +from torch.nn.utils.clip_grad import clip_grad_norm_ + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + + col_layer_for_check = ["encoder.layer[0].output.dense"] + row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] + + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + + # Check grads + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + col_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + check_all_grad_tensors(grads_to_check) + + # Check grad norm + param_gradient_pairs = [ + (p, p.grad) for group in sharded_optimizer.param_groups for p in group["params"] if p.grad is not None + ] + origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"]) + device = origin_norm.device + hybrid_norm = torch.tensor(sharded_optimizer._compute_grad_norm(param_gradient_pairs)).to(device) + assert torch.allclose( + origin_norm, hybrid_norm, atol=atol, rtol=rtol + ), f"orgin origin model grad norm is not equal to shard model grad norm\n{origin_norm}\n{hybrid_norm}" + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + + if org_model.__class__.__name__ == "BertModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "fp32", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "max_norm": 5, + }, + ], +) +def run_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "max_norm": 5, + }, + ], +) +def run_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_grad_clip_norm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_test() + + +def check_grad_clip_norm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm(): + spawn(check_grad_clip_norm, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm_3d(): + spawn(check_grad_clip_norm_3d, 8) + + +if __name__ == "__main__": + test_grad_clip_norm() + test_grad_clip_norm_3d() diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py new file mode 100644 index 000000000000..f1ac1de1acc9 --- /dev/null +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py @@ -0,0 +1,241 @@ +import math + +import pytest +import torch +import torch.distributed as dist +from torch.nn.utils.clip_grad import clip_grad_norm_ + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + + col_layer_for_check = ["encoder.layer[0].output.dense"] + + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + + dist.barrier() + # Check gradient norm + origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"]) + + # Calculate the gradient norm of the sharded optimizer + device = origin_norm.device + norm_groups = [] + for group_id in range(sharded_optimizer.num_param_groups): + working_grads = sharded_optimizer._grad_store.get_working_grads_by_group_id(group_id) + norm_group = sharded_optimizer._compute_grad_norm(gradients=working_grads) + norm_groups.append(norm_group) + total_norm = 0.0 + for norm in norm_groups: + total_norm += norm**2.0 + hybrid_norm = torch.tensor(math.sqrt(total_norm)).to(device) + + # If using fp16 precision, divide by the initial scale + if test_config["precision"] == "fp16": + hybrid_norm /= test_config["initial_scale"] + + # Assert that the gradient norm of the original model is close to the gradient norm of the hybrid model + assert torch.allclose( + origin_norm, hybrid_norm, atol=atol, rtol=rtol + ), f"Original model grad norm is not equal to sharded model grad norm\n{origin_norm}\n{hybrid_norm}" + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + if org_model.__class__.__name__ == "BertModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "zero_stage": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 1, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 1, + "zero_stage": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + ], +) +def run_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + ], +) +def run_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_grad_clip_norm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_test() + + +def check_grad_clip_norm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm(): + spawn(check_grad_clip_norm, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm_3d(): + spawn(check_grad_clip_norm_3d, 8) + + +if __name__ == "__main__": + test_grad_clip_norm() + test_grad_clip_norm_3d() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 0a2b151d4274..6acbe4ff523d 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn import Module from torch.optim import Adam, Optimizer +from torch.testing import assert_close from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin @@ -160,7 +161,7 @@ def _criterion(outputs, inputs): input_shape = data["input_ids"].shape for k, v in data.items(): if v.shape == input_shape: - data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) + data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) sharded_model.train() if booster.plugin.stage_manager is not None: @@ -207,15 +208,11 @@ def check_output_hidden_state( else: sharded_hidden_state = sharded_output.last_hidden_state - assert torch.allclose( - org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol - ), f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): - assert torch.allclose( - org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol - ), f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" + assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol) def check_weight( @@ -242,9 +239,7 @@ def check_weight( if verbose and dist.get_rank() == 0: print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") - assert torch.allclose( - org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol - ), f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" + assert_close(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol) def get_grad_tensors_for_check( @@ -310,9 +305,7 @@ def check_grad( if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - assert torch.allclose( - org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol - ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + assert_close(org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol) def unwrap_model( @@ -337,6 +330,4 @@ def check_all_grad_tensors(check_tensors): shard_grad = check_info["shard_grad"] rtol = check_info["rtol"] atol = check_info["atol"] - assert torch.allclose( - org_grad, shard_grad, atol=atol, rtol=rtol - ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + assert_close(org_grad, shard_grad, atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 1c934bd22340..3a8af2d6d481 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -43,7 +43,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config["precision"] == "fp32": - atol, rtol = 1e-5, 1e-3 + atol, rtol = 2e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 row_layer_grads = get_grad_tensors_for_check( @@ -62,7 +62,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config["precision"] == "fp32": - atol, rtol = 1e-5, 1e-3 + atol, rtol = 2e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -154,15 +154,6 @@ def run_vit_test(test_config): "precision": "fp32", "initial_scale": 1, }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": False, - "use_lazy_init": False, - "precision": "fp32", - "initial_scale": 1, - }, ], ) def run_vit_3d_test(test_config): diff --git a/tests/test_smoothquant/test_llama_attention.py b/tests/test_smoothquant/test_llama_attention.py new file mode 100644 index 000000000000..f8c79145c952 --- /dev/null +++ b/tests/test_smoothquant/test_llama_attention.py @@ -0,0 +1,136 @@ +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.triton import int8_rotary_embedding_fwd + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +try: + from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + +import math + +import torch +from torch.nn import functional as F + + +def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): + """ + adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + """ + xq = xq.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.0] = -100000000.0 + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) + scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + + return output + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT, + reason="triton requires cuda version to be higher than 11.4 or not install torch_int", +) +def test_llama_context_attention(): + head_num = 2 + seq_len = 32 + head_dim = 64 + dtype = torch.float + hidden_size = head_num * head_dim + + smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num) + + smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8) + + qkv_weight_scale = 1.0 + + ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda") + + smooth_attn = smooth_attn.to("cuda") + + input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda") + input_scale = 1 / 20.0 + + output = torch.matmul(input.to(torch.float) * input_scale, ones) + qkv_max_out = torch.max(torch.abs(output)) / 127 + smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + + q = smooth_attn.q_proj(input) + k = smooth_attn.k_proj(input) + v = smooth_attn.v_proj(input) + + cos_shape = (seq_len, head_dim // 2) + cos = torch.ones(cos_shape, dtype=dtype, device="cuda") + sin = torch.zeros(cos_shape, dtype=dtype, device="cuda") + in_scale = torch.tensor([qkv_max_out], device="cuda") + out_scale = torch.tensor([qkv_max_out], device="cuda") + int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + + q = q.to(torch.float) * out_scale + k = k.to(torch.float) * out_scale + v = v.to(torch.float) * out_scale + torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim) + attn_out_max = torch.max(torch.abs(torch_out)) / 127 + + output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones) + smooth_attn.q_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.v_output_scale = torch.tensor(qkv_max_out) + smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.attn_output_scale = torch.tensor(attn_out_max) + smooth_attn.out_proj.a = torch.tensor([attn_out_max]) + + torch_out = ( + (torch_out / smooth_attn.attn_output_scale) + .round() + .clamp(-128, 127) + .to(torch.int8) + .view(-1, seq_len, head_num * head_dim) + ) + + torch_out = smooth_attn.out_proj(torch_out) + torch_out = torch_out.to(torch.float) + + smooth_attn = smooth_attn.to("cuda") + smooth_out, _, _ = smooth_attn(input, (cos, sin)) + smooth_out = smooth_out.to(torch.float) + + assert torch.allclose( + torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1 + ), "outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_llama_context_attention() diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py new file mode 100644 index 000000000000..236edb10cb7f --- /dev/null +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -0,0 +1,84 @@ +import warnings + +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except: + warnings.warn("CUDA smoothquant linear is not installed") + HAS_SMOOTHQUANT_CUDA = False + + +try: + from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP + + HAS_TORCH_INT = True +except: + HAS_TORCH_INT = False + warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_llama_mlp(gate_proj, up_proj, down_proj, x): + gate_out = torch.mm(x, gate_proj) + silu = torch.nn.SiLU() + gate_out = silu(gate_out) + up_out = torch.mm(x, up_proj) + + o_out = gate_out * up_out + + max_up = torch.max(torch.abs(o_out)) + min_up = torch.min(torch.abs(o_out)) + + torch_out = torch.mm(o_out, down_proj) + + return (torch_out, max_up, min_up) + + +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT, + reason="smoothquant linear not installed properly or not install torch_int", +) +def test_llama_mlp(): + hidden_size = 256 + intermediate_size = 512 + + smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size) + + smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda") + + smooth_mlp.up_proj.weight = torch.randint( + -10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda" + ) + smooth_mlp.down_proj.weight = torch.randint( + -10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda" + ) + + x = torch.ones((1, 256), dtype=torch.int8, device="cuda") + + torch_out, max_inter, min_inter = torch_llama_mlp( + smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size, + smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127, + smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127, + x.to(torch.float), + ) + + smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127) + smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size) + smooth_mlp.up_proj.a = torch.tensor(1 / 127) + smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127)) + + smooth_out = smooth_mlp(x) + + assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01) + + +if __name__ == "__main__": + test_llama_mlp() diff --git a/tests/test_smoothquant/test_smoothquant_linear.py b/tests/test_smoothquant/test_smoothquant_linear.py new file mode 100644 index 000000000000..58a0b82f6759 --- /dev/null +++ b/tests/test_smoothquant/test_smoothquant_linear.py @@ -0,0 +1,39 @@ +import warnings + +import pytest +import torch + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except: + warnings.warn("CUDA smoothquant linear is not installed") + HAS_SMOOTHQUANT_CUDA = False + + +@pytest.mark.skipif( + not HAS_SMOOTHQUANT_CUDA, + reason="smoothquant linear not installed properly", +) +def test_linear(): + a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda") + b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda") + c = torch.rand(256, dtype=torch.float, device="cuda") + + alpha = 1 / 127 + beta = 1.0 + torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c + + silu = torch.nn.SiLU() + torch_out = silu(torch_out) + + b = b.transpose(0, 1).contiguous() + cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta) + + assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02) + + +if __name__ == "__main__": + test_linear() diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_smoothquant/test_sq_rotary_embedding.py similarity index 73% rename from tests/test_infer_ops/triton/test_rotary_embedding.py rename to tests/test_smoothquant/test_sq_rotary_embedding.py index 7e05ccafbfc4..4cc76f00474d 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_smoothquant/test_sq_rotary_embedding.py @@ -6,9 +6,7 @@ from packaging import version try: - pass - - from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd + from colossalai.kernel.triton import int8_rotary_embedding_fwd HAS_TRITON = True except ImportError: @@ -36,7 +34,7 @@ def test_rotary_emb(): SEQ_LEN = 1 HEAD_NUM = 32 HEAD_DIM = 128 - dtype = torch.half + dtype = torch.float # create data x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") @@ -45,10 +43,16 @@ def test_rotary_emb(): sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") # forward pass y_torch = torch_rotary_emb(x, cos, sin) - rotary_embedding_fwd(x, cos, sin) - y_triton = x - # compare - assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0) + + input_scale = torch.max(torch.abs(x)) / 127 + output_scale = torch.max(torch.abs(y_torch)) / 127 + + x = x / input_scale + x = x.to(torch.int8) + + int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item()) + y_triton = x.to(torch.float) * output_scale + assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True) if __name__ == "__main__": diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 94e70040019c..2fb2bcbc851a 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -27,6 +27,8 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): chunk_manager = model.chunk_manager param_list = [p for p in model.parameters()] chunk_list = chunk_manager.get_chunks(param_list) + if not model.reuse_fp16_chunk: + chunk_list = [chunk.grad_chunk for chunk in chunk_list] for chunk in chunk_list: chunk_manager.access_chunk(chunk) @@ -36,13 +38,15 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gather", [False, True]) -@parameterize("model_name", ["gpt2", "bert", "albert"]) +@parameterize("model_name", ["gpt2", "bert"]) @parameterize("use_grad_checkpoint", [False, True]) +@parameterize("master_weights", [False, True]) def exam_gpt_fwd_bwd( placement_config, keep_gather, model_name: str, use_grad_checkpoint: bool = False, + master_weights: bool = True, ): init_device = get_current_device() get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -60,12 +64,14 @@ def exam_gpt_fwd_bwd( config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gather - model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config) + model = GeminiDDP( + model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1) rank = dist.get_rank() - amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, master_weights=master_weights) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[rank]) @@ -106,4 +112,4 @@ def test_gpt(world_size): if __name__ == "__main__": - test_gpt(4) + test_gpt(1) diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py new file mode 100644 index 000000000000..334a57410817 --- /dev/null +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -0,0 +1,147 @@ +import pytest +import torch +import torch.distributed as dist +from apex import amp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed +from colossalai.utils.cuda import get_current_device +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration +from tests.components_to_test import run_fwd +from tests.components_to_test.registry import non_distributed_component_funcs + +PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, +] + + +def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): + chunk_manager = model.chunk_manager + grad_chunk_list = [] + device_list = [] + + # Access gradient chunks. + for p in model.parameters(): + grad_chunk = chunk_manager.get_chunk(p).grad_chunk + if grad_chunk not in grad_chunk_list: + chunk_manager.access_chunk(grad_chunk) + grad_chunk_list.append(grad_chunk) + device_list.append(model.grads_device[p]) + + # Compare gradients. + for p0, p1 in zip(model.parameters(), torch_model.parameters()): + assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) + + # Release gradient chunks and move them to gradient device. + for grad_chunk, device in zip(grad_chunk_list, device_list): + chunk_manager.release_chunk(grad_chunk) + chunk_manager.move_chunk(grad_chunk, device, force_copy=True) + + +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [False, True]) +@parameterize("model_name", ["gpt2", "bert"]) +@parameterize("use_grad_checkpoint", [False, True]) +@parameterize("master_weights", [False, True]) +def exam_gemini_grad_acc( + placement_config, keep_gathered: bool, model_name: str, use_grad_checkpoint: bool, master_weights: bool +): + init_device = get_current_device() + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + set_seed(42) + gemini_model = model_builder(use_grad_checkpoint) + + set_seed(42) + torch_model = model_builder(use_grad_checkpoint).cuda() + for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered + gemini_model = GeminiDDP( + gemini_model, + config_dict, + init_device, + pin_memory=True, + enable_gradient_accumulation=True, + master_weights=master_weights, + **placement_config, + ) + optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) + gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1) + + rank = dist.get_rank() + + # setting master_weights to False will cause overflow after optimizer.step() + amp_config = dict( + opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, min_loss_scale=1, max_loss_scale=1, master_weights=True + ) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = amp.initialize(torch_model, torch_optim, **amp_config) + torch_model = DDP(torch_model, device_ids=[rank]) + + set_seed(rank) + accum_iter = 4 + for i, (input_ids, label) in enumerate(train_dataloader): + delay_unscale = False if (i + 1) % accum_iter == 0 else True + input_ids, label = input_ids.cuda(), label.cuda() + + set_seed(42 + rank) + torch_loss = run_fwd(torch_model, input_ids, label, criterion) + torch_loss = torch_loss / accum_iter + with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss: + scaled_loss.backward() + + set_seed(42 + rank) + gemini_loss = run_fwd(gemini_model, input_ids, label, criterion) + gemini_loss = gemini_loss / accum_iter + gemini_optim.backward(gemini_loss) + + assert torch.allclose(torch_loss, gemini_loss, rtol=1e-3, atol=1e-5) + + check_grad(gemini_model, torch_model) + + if (i + 1) % accum_iter == 0: + torch_optim.step() + gemini_optim.step() + torch_optim.zero_grad() + + # check updated param + torch_dict = torch_model.state_dict() + gemini_dict = gemini_model.state_dict(only_rank_0=False) + + for key, value in gemini_dict.items(): + torch_key = "module." + key + torch_value = torch_dict[torch_key].to(value.device).to(value.dtype) + assert_close(value, torch_value, rtol=1e-3, atol=2e-3) + + if i == accum_iter: + break + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_gemini_grad_acc() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_grad_accumulation(): + spawn(run_dist, 2) + + +if __name__ == "__main__": + test_grad_accumulation() diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index d8bcc555a15d..4c84e9e5a89a 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", ["gpt2"]) -def exam_grad_clipping(placement_config, model_name: str): +@parameterize("master_weights", [True, False]) +def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): set_seed(1912) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -78,7 +79,12 @@ def exam_grad_clipping(placement_config, model_name: str): init_device = None model = GeminiDDP( - model, chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, **placement_config + model, + chunk_config_dict=config_dict, + chunk_init_device=init_device, + pin_memory=True, + master_weights=master_weights, + **placement_config, ) optimizer = HybridAdam(model.parameters(), lr=1e-3) @@ -99,7 +105,10 @@ def exam_grad_clipping(placement_config, model_name: str): torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim) loss = run_fwd_bwd(model, data, label, criterion, zero_optim) - assert_close(torch_loss, loss) + + # as no master weights leads to error accumulation, we don't check the loss + if master_weights: + assert_close(torch_loss, loss) import apex.amp as apex_amp @@ -107,7 +116,8 @@ def exam_grad_clipping(placement_config, model_name: str): torch_optim.step() zero_optim.step() - check_param(model, torch_model) + if master_weights: + check_param(model, torch_model) def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index b7c08392600f..0cf9aa073f9f 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -1,6 +1,7 @@ import pytest import torch import torch.distributed as dist +from packaging.version import Version from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close @@ -44,7 +45,7 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype): - zero_dict = model.state_dict(only_rank_0=False, dtype=dtype) + zero_dict = model.state_dict(only_rank_0=False) torch_dict = torch_model.state_dict() for key, value in torch_dict.items(): @@ -70,12 +71,14 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", TEST_MODELS) @parameterize("mixed_precision", [torch.half, torch.bfloat16]) -def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype): +@parameterize("master_weights", [True, False]) +def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() + # apex no master weights leads to nan, so we don't use it amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) @@ -90,7 +93,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = False - model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision) + model = GeminiDDP( + model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128) @@ -109,12 +114,15 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=rtol, atol=atol) + # as no master weights leads to error accumulation, we don't check the loss + if master_weights: + assert_close(torch_loss, loss, rtol=rtol, atol=atol) zero_optim.step() torch_optim.step() - check_param(model, torch_model, mixed_precision) + if master_weights: + check_param(model, torch_model, mixed_precision) @parameterize("placement_config", PLACEMENT_CONFIGS) @@ -154,6 +162,9 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. rtol, atol = 1.5e-6, 2e-5 if mixed_precision is torch.bfloat16: rtol, atol = 2e-3, 2e-3 + elif Version(torch.__version__) >= Version("2.0.0"): + rtol, atol = 4e-5, 3e-5 + for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 3130440bd925..bf16a301cd8a 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -27,7 +27,8 @@ def ignore_the_first_parameter(model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [True, False]) @parameterize("model_name", ["gpt2", "bert"]) -def exam_state_dict(placement_config, keep_gathered, model_name: str): +@parameterize("master_weights", [False, True]) +def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -42,7 +43,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str): config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights) model.train() zero_dict = model.state_dict(only_rank_0=False) @@ -57,7 +58,8 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [True, False]) @parameterize("model_name", ["gpt2", "bert"]) -def exam_load_state_dict(placement_config, keep_gathered, model_name: str): +@parameterize("master_weights", [False, True]) +def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -72,7 +74,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str): config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights) torch_dict = torch_model.state_dict() model.load_state_dict(torch_dict, strict=False) @@ -86,7 +88,8 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", ["gpt2", "bert"]) -def exam_state_dict_shard(placement_config, model_name: str): +@parameterize("master_weights", [False, True]) +def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -95,7 +98,7 @@ def exam_state_dict_shard(placement_config, model_name: str): model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - model = GeminiDDP(model, config_dict, **placement_config) + model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights) model.train() zero_dict = model.state_dict(only_rank_0=False) diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index 8aa656b74cf9..c65c6d292467 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -72,6 +72,7 @@ def run_dist(rank, world_size, port): exam_zero_optim_state_dict() +@pytest.mark.skip @pytest.mark.dist @pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index ebda9f6f25c5..e2196cfbf0f2 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -106,7 +106,8 @@ def exam_zero_1_2(): @parameterize("dtype", [torch.float16, torch.bfloat16]) -def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): +@parameterize("master_weights", [True, False]) +def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -131,7 +132,11 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=1024 * 1024 + zero_optimizer, + overlap_communication=True, + initial_scale=1, + reduce_bucket_size=1024 * 1024, + master_weights=master_weights, ) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)