From 611a5a80cac780e3a6e30c0f81f4ceb2a3b3f2c0 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Mon, 16 Oct 2023 11:28:44 +0800 Subject: [PATCH] [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license --- LICENSE | 50 ++ .../inference/quant/smoothquant/__init__.py | 0 .../quant/smoothquant/models/__init__.py | 12 + .../quant/smoothquant/models/base_model.py | 482 ++++++++++ .../quant/smoothquant/models/linear.py | 177 ++++ .../quant/smoothquant/models/llama.py | 846 ++++++++++++++++++ .../cuda_native/csrc/smoothquant/binding.cpp | 8 + .../cuda_native/csrc/smoothquant/linear.cu | 162 ++++ .../cuda_native/csrc/smoothquant/linear.h | 12 + colossalai/kernel/triton/__init__.py | 5 + .../triton/int8_rotary_embedding_kernel.py | 117 +++ colossalai/kernel/triton/smooth_attention.py | 652 ++++++++++++++ examples/inference/smoothquant_llama.py | 69 ++ op_builder/smoothquant.py | 52 ++ .../test_smoothquant/test_llama_attention.py | 136 +++ tests/test_smoothquant/test_llama_mlp.py | 84 ++ .../test_smoothquant_linear.py | 39 + .../test_sq_rotary_embedding.py | 59 ++ 18 files changed, 2962 insertions(+) create mode 100644 colossalai/inference/quant/smoothquant/__init__.py create mode 100644 colossalai/inference/quant/smoothquant/models/__init__.py create mode 100644 colossalai/inference/quant/smoothquant/models/base_model.py create mode 100644 colossalai/inference/quant/smoothquant/models/linear.py create mode 100644 colossalai/inference/quant/smoothquant/models/llama.py create mode 100644 colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu create mode 100644 colossalai/kernel/cuda_native/csrc/smoothquant/linear.h create mode 100644 colossalai/kernel/triton/int8_rotary_embedding_kernel.py create mode 100644 colossalai/kernel/triton/smooth_attention.py create mode 100644 examples/inference/smoothquant_llama.py create mode 100644 op_builder/smoothquant.py create mode 100644 tests/test_smoothquant/test_llama_attention.py create mode 100644 tests/test_smoothquant/test_llama_mlp.py create mode 100644 tests/test_smoothquant/test_smoothquant_linear.py create mode 100644 tests/test_smoothquant/test_sq_rotary_embedding.py diff --git a/LICENSE b/LICENSE index 59d456c5b8a1..b3eb43520a6f 100644 --- a/LICENSE +++ b/LICENSE @@ -477,3 +477,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + ---------------- LICENSE FOR torch-int ---------------- + + MIT License + + Copyright (c) 2022 Guangxuan Xiao + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + + ---------------- LICENSE FOR smoothquant ---------------- + + MIT License + + Copyright (c) 2022 MIT HAN Lab + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/colossalai/inference/quant/smoothquant/__init__.py b/colossalai/inference/quant/smoothquant/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py new file mode 100644 index 000000000000..77541d8610c5 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/__init__.py @@ -0,0 +1,12 @@ +try: + import torch_int + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + raise ImportError( + "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int" + ) + +if HAS_TORCH_INT: + from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py new file mode 100644 index 000000000000..180e6c6e8fa6 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -0,0 +1,482 @@ +# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py + +import os +import warnings +from abc import abstractmethod +from functools import partial +from os.path import isdir, isfile, join +from typing import Dict, List, Optional, Union + +import accelerate +import numpy as np +import torch +import torch.nn as nn +import transformers +from safetensors.torch import save_file as safe_save +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel +from transformers.modeling_utils import no_init_weights +from transformers.utils.generic import ContextManagers +from transformers.utils.hub import PushToHubMixin, cached_file + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager + +SUPPORTED_MODELS = ["llama"] + + +class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): + layer_type: str = None + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__() + + self.model = model + self.model_type = self.model.config.model_type + self._quantized = quantized + self.config = self.model.config + self.cache_manager = None + self.max_total_token_num = 0 + + @property + def quantized(self): + return self._quantized + + def init_cache_manager(self, max_total_token_num=2048): + if self.config.model_type == "llama": + head_num = self.config.num_key_value_heads + layer_num = self.config.num_hidden_layers + head_dim = self.config.hidden_size // head_num + + self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num) + self.max_total_token_num = max_total_token_num + + def init_batch_state(self, max_output_len=256, **kwargs): + input_ids = kwargs["input_ids"] + batch_size = len(input_ids) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + max_len_in_batch = -1 + + for i in range(batch_size): + seq_len = len(input_ids[i]) + seq_lengths[i] = seq_len + seq_start_indexes[i] = start_index + start_index += seq_len + max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch + + if "max_total_token_num" in kwargs.keys(): + max_total_token_num = kwargs["max_total_token_num"] + self.init_cache_manager(max_total_token_num) + + if "max_new_tokens" in kwargs.keys(): + max_output_len = kwargs["max_new_tokens"] + + if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num: + max_total_token_num = batch_size * (max_len_in_batch + max_output_len) + warnings.warn(f"reset max tokens to {max_total_token_num}") + self.init_cache_manager(max_total_token_num) + + block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda") + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") + batch_infer_state.block_loc = block_loc + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) + batch_infer_state.cache_manager.free_all() + return batch_infer_state + + @abstractmethod + @torch.inference_mode() + def quantize( + self, + examples: List[Dict[str, Union[List[int], torch.LongTensor]]], + ): + if self.quantized: + raise EnvironmentError("can't execute quantize because the model is quantized.") + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def generate(self, **kwargs): + """shortcut for model.generate""" + + batch_infer_state = self.init_batch_state(**kwargs) + if self.config.model_type == "llama": + setattr(self.model.model, "infer_state", batch_infer_state) + + with torch.inference_mode(): + return self.model.generate(**kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + """shortcut for model.prepare_inputs_for_generation""" + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512): + for text in tqdm(dataset): + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + + def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512): + pbar = tqdm(dataset) + for text in pbar: + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + + def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = {} + + def stat_tensor(name, tensor): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float().cpu() + if name in act_scales: + act_scales[name] = torch.max(act_scales[name], comming_max) + else: + act_scales[name] = comming_max + + def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x) + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name))) + + self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len) + + for h in hooks: + h.remove() + + return act_scales + + @torch.no_grad() + def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): + if not isinstance(fcs, list): + fcs = [fcs] + for fc in fcs: + assert isinstance(fc, nn.Linear) + assert ln.weight.numel() == fc.in_features == act_scales.numel() + + device, dtype = fcs[0].weight.device, fcs[0].weight.dtype + act_scales = act_scales.to(device=device, dtype=dtype) + weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) + weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) + + scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) + + ln.weight.div_(scales) + if hasattr(ln, "bias"): + ln.bias.div_(scales) + + for fc in fcs: + fc.weight.mul_(scales.view(1, -1)) + + @classmethod + def create_quantized_model(model): + raise NotImplementedError("Not implement create_quantized_model method") + + def save_quantized( + self, + save_dir: str, + model_basename: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + ): + """save quantized model and configs to local disk""" + os.makedirs(save_dir, exist_ok=True) + + if not self.quantized: + raise EnvironmentError("can only save quantized model, please execute .quantize first.") + + self.model.to("cpu") + + model_base_name = model_basename # or f"smooth-" + if use_safetensors: + model_save_name = model_base_name + ".safetensors" + state_dict = self.model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + if safetensors_metadata is None: + safetensors_metadata = {} + elif not isinstance(safetensors_metadata, dict): + raise TypeError("safetensors_metadata must be a dictionary.") + else: + print(f"Received safetensors_metadata: {safetensors_metadata}") + new_safetensors_metadata = {} + converted_keys = False + for key, value in safetensors_metadata.items(): + if not isinstance(key, str) or not isinstance(value, str): + converted_keys = True + try: + new_key = str(key) + new_value = str(value) + except Exception as e: + raise TypeError( + f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}" + ) + if new_key in new_safetensors_metadata: + print( + f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting." + ) + new_safetensors_metadata[new_key] = new_value + safetensors_metadata = new_safetensors_metadata + if converted_keys: + print( + f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}" + ) + + # Format is required to enable Accelerate to load the metadata + # otherwise it raises an OSError + safetensors_metadata["format"] = "pt" + + safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata) + else: + model_save_name = model_base_name + ".bin" + torch.save(self.model.state_dict(), join(save_dir, model_save_name)) + + self.model.config.save_pretrained(save_dir) + + def save_pretrained( + self, + save_dir: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + **kwargs, + ): + """alias of save_quantized""" + warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.") + self.save_quantized(save_dir, use_safetensors, safetensors_metadata) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + max_memory: Optional[dict] = None, + trust_remote_code: bool = False, + torch_dtype: torch.dtype = torch.float16, + **model_init_kwargs, + ): + if not torch.cuda.is_available(): + raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + # Parameters related to loading from Hugging Face Hub + cache_dir = model_init_kwargs.pop("cache_dir", None) + force_download = model_init_kwargs.pop("force_download", False) + resume_download = model_init_kwargs.pop("resume_download", False) + proxies = model_init_kwargs.pop("proxies", None) + local_files_only = model_init_kwargs.pop("local_files_only", False) + use_auth_token = model_init_kwargs.pop("use_auth_token", None) + revision = model_init_kwargs.pop("revision", None) + subfolder = model_init_kwargs.pop("subfolder", "") + model_init_kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + } + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs) + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + # enforce some values despite user specified + model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["trust_remote_code"] = trust_remote_code + if max_memory: + if "disk" in max_memory: + raise NotImplementedError("disk offload not support yet.") + with accelerate.init_empty_weights(): + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model.tie_weights() + + max_memory = accelerate.utils.get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + low_zero=False, + ) + model_init_kwargs["device_map"] = accelerate.infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + ) + model_init_kwargs["low_cpu_mem_usage"] = True + + del model + else: + model_init_kwargs["device_map"] = None + model_init_kwargs["low_cpu_mem_usage"] = False + + torch.cuda.empty_cache() + + merged_kwargs = {**model_init_kwargs, **cached_file_kwargs} + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs) + + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + model.eval() + + return cls(model, False) + + @classmethod + def from_quantized( + cls, + model_name_or_path: Optional[str], + model_basename: Optional[str] = None, + device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, + max_memory: Optional[dict] = None, + device: Optional[Union[str, int]] = None, + low_cpu_mem_usage: bool = False, + torch_dtype: Optional[torch.dtype] = None, + use_safetensors: bool = False, + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + + # Parameters related to loading from Hugging Face Hub + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + # == step1: prepare configs and file names == # + config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs + ) + + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + extensions = [] + if use_safetensors: + extensions.append(".safetensors") + else: + extensions += [".bin", ".pt"] + + model_name_or_path = str(model_name_or_path) + is_local = isdir(model_name_or_path) + + resolved_archive_file = None + if is_local: + model_save_name = join(model_name_or_path, model_basename) + for ext in extensions: + if isfile(model_save_name + ext): + resolved_archive_file = model_save_name + ext + break + else: # remote + for ext in extensions: + resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs) + if resolved_archive_file is not None: + break + + if resolved_archive_file is None: # Could not find a model file to use + raise FileNotFoundError(f"Could not find model in {model_name_or_path}") + + model_save_name = resolved_archive_file + + # == step2: convert model to quantized-model (replace Linear) == # + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + transformers.modeling_utils._init_weights = False + + init_contexts = [no_init_weights()] + if low_cpu_mem_usage: + init_contexts.append(accelerate.init_empty_weights(include_buffers=True)) + + with ContextManagers(init_contexts): + model = AutoModelForCausalLM.from_config( + config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype + ) + cls.create_quantized_model(model) + model.tie_weights() + + # == step3: load checkpoint to quantized-model == # + accelerate.utils.modeling.load_checkpoint_in_model( + model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True + ) + + # == step4: set seqlen == # + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + + return cls( + model, + True, + ) + + def __getattr__(self, item): + try: + return super().__getattr__(item) + except: + return getattr(self.model, item) + + +__all__ = ["BaseSmoothForCausalLM"] diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py new file mode 100644 index 000000000000..048565bfbf5e --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/linear.py @@ -0,0 +1,177 @@ +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py + +import torch +from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32 +from torch_int.functional.quantization import quantize_per_tensor_absmax + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except ImportError: + HAS_SMOOTHQUANT_CUDA = False + raise ImportError("CUDA smoothquant linear is not installed") + + +class W8A8BFP32O32LinearSiLU(torch.nn.Module): + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + if module.bias is not None: + int8_module.bias.data.copy_(module.bias.to(torch.float)) + int8_module.a = alpha + return int8_module + + +class W8A8B8O8Linear(torch.nn.Module): + # For qkv_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + self.register_buffer("b", torch.tensor(beta)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + int8_module = W8A8B8O8Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale / output_scale + int8_module.weight = int8_weight + int8_module.a = alpha + + if module.bias is not None: + int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias) + int8_module.bias = int8_bias + beta = bias_scale / output_scale + int8_module.b = beta + + return int8_module + + +class W8A8BFP32OFP32Linear(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + int8_module.a = alpha + int8_module.input_scale = input_scale + int8_module.weight_scale = weight_scale + + if module.bias is not None: + int8_module.bias = module.bias.to(torch.float32) + + return int8_module diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py new file mode 100644 index 000000000000..9c77feeb346e --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -0,0 +1,846 @@ +import math +import os +import types +from collections import defaultdict +from functools import partial +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T +from transformers import PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LLAMA_INPUTS_DOCSTRING, + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaRotaryEmbedding, + repeat_kv, + rotate_half, +) +from transformers.utils import add_start_docstrings_to_model_forward + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import ( + copy_kv_cache_to_dest, + int8_rotary_embedding_fwd, + smooth_llama_context_attn_fwd, + smooth_token_attention_fwd, +) + +from .base_model import BaseSmoothForCausalLM +from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear + + +class LLamaSmoothquantAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." + ) + + self.qk_bmm = BMM_S8T_S8N_F32T(1.0) + self.pv_bmm = BMM_S8T_S8N_S8T(1.0) + + self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) + + self.register_buffer("q_output_scale", torch.tensor([1.0])) + self.register_buffer("k_output_scale", torch.tensor([1.0])) + self.register_buffer("v_output_scale", torch.tensor([1.0])) + self.register_buffer("q_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("k_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("out_input_scale", torch.tensor([1.0])) + self.register_buffer("attn_input_scale", torch.tensor([1.0])) + + self._init_rope() + self.num_key_value_heads = num_heads + + def _init_rope(self): + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=2048, + base=10000.0, + ) + + @staticmethod + def pack( + module: LlamaAttention, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + ): + int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads) + + int8_module.attn_input_scale = torch.tensor([attn_input_scale]) + + int8_module.q_output_scale = torch.tensor([q_output_scale]) + int8_module.k_output_scale = torch.tensor([k_output_scale]) + int8_module.v_output_scale = torch.tensor([v_output_scale]) + + int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale]) + int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale]) + + int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale) + int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) + int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) + int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) + + int8_module.out_input_scale = torch.tensor([out_input_scale]) + + return int8_module + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + cos = rotary_emb[0] + sin = rotary_emb[1] + + int8_rotary_embedding_fwd( + query_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.q_output_scale.item(), + self.q_rotary_output_scale.item(), + ) + int8_rotary_embedding_fwd( + key_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.k_output_scale.item(), + self.k_rotary_output_scale.item(), + ) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len + + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + + query_states = query_states.view(-1, self.num_heads, self.head_dim) + key_states = key_states.view(-1, self.num_heads, self.head_dim) + value_states = value_states.view(-1, self.num_heads, self.head_dim) + + if infer_state.is_context_stage: + # first token generation + + # copy key and value calculated in current step to memory manager + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_states) + + smooth_llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.start_loc, + infer_state.seq_len, + q_len, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + smooth_token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, None, None + + +class LlamaLayerNormQ(torch.nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.input_scale = 1.0 + self.variance_epsilon = eps + self.register_buffer("weight", torch.ones(dim, dtype=torch.float32)) + + def forward(self, x): + ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon) + ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8) + return ln_output_int8 + + @staticmethod + def from_float(module: torch.nn.LayerNorm, output_scale: float): + assert module.weight.shape[0] == module.weight.numel() + q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon) + q_module.weight = module.weight / output_scale + return q_module + + +class LlamaSmoothquantMLP(nn.Module): + def __init__(self, intermediate_size, hidden_size): + super().__init__() + self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size) + self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size) + self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size) + self.register_buffer("down_proj_input_scale", torch.tensor([1.0])) + + @staticmethod + def pack( + mlp_module: LlamaMLP, + gate_proj_input_scale: float, + up_proj_input_scale: float, + down_proj_input_scale: float, + ): + int8_module = LlamaSmoothquantMLP( + mlp_module.intermediate_size, + mlp_module.hidden_size, + ) + + int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale) + int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale) + int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale) + int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale]) + return int8_module + + def forward( + self, + hidden_states: torch.Tensor, + ): + x_shape = hidden_states.shape + gate_out = self.gate_proj(hidden_states) + up_out = self.up_proj(hidden_states) + inter_out = gate_out * up_out + inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8) + down_out = self.down_proj(inter_out) + down_out = down_out.view(*x_shape[:-1], -1) + return down_out + + +class LlamaSmoothquantDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads) + + self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size) + self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + @staticmethod + def pack( + module: LlamaDecoderLayer, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + gate_input_scale: float, + up_input_scale: float, + down_input_scale: float, + ): + config = module.self_attn.config + int8_decoder_layer = LlamaSmoothquantDecoderLayer(config) + + int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale) + int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack( + module.self_attn, + attn_input_scale, + q_output_scale, + k_output_scale, + v_output_scale, + q_rotary_output_scale, + k_rotary_output_scale, + out_input_scale, + ) + + int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float( + module.post_attention_layernorm, gate_input_scale + ) + + int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack( + module.mlp, + gate_input_scale, + up_input_scale, + down_input_scale, + ) + + return int8_decoder_layer + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + rotary_emb=rotary_emb, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None, None + + +class LlamaApplyRotary(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (rotate_half(x) * sin) + + return x_embed + + +def llama_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states = self.q_apply_rotary(query_states, cos, sin, position_ids) + key_states = self.k_apply_rotary(key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def init_to_get_rotary(config, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + base : calculation arg + use_elem : activated when using chatglm-based models + """ + config.head_dim_ = config.hidden_size // config.num_attention_heads + if not hasattr(config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0 + + if hasattr(config, "max_sequence_length"): + max_seq_len = config.max_sequence_length + elif hasattr(config, "max_position_embeddings"): + max_seq_len = config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + try: + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula + except: + pass + + n_elem = config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + _cos_cached = torch.cos(freqs).to(torch.float) + _sin_cached = torch.sin(freqs).to(torch.float) + return _cos_cached, _sin_cached + + +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + infer_state = self.infer_state + + if past_key_values is not None: + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if infer_state.is_context_stage: + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + raise NotImplementedError("not implement gradient_checkpointing and training options ") + + if past_key_values_length == 0: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + infer_state.decode_layer_id = 0 + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + rotary_emb=(position_cos, position_sin), + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + + hidden_states = layer_outputs[0] + infer_state.decode_layer_id += 1 + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + infer_state.is_context_stage = False + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): + layer_type = "LlamaDecoderLayer" + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__(model, quantized) + + def get_act_dict( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + ): + llama_model = self.model + + llama_model.eval() + device = next(llama_model.parameters()).device + # print("model:", llama_model) + act_dict = defaultdict(dict) + + def stat_io_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + if name not in act_dict or "input" not in act_dict[name]: + act_dict[name]["input"] = x.detach().abs().max().item() + else: + act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) + if isinstance(y, tuple): + y = y[0] + if name not in act_dict or "output" not in act_dict[name]: + act_dict[name]["output"] = y.detach().abs().max().item() + else: + act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) + + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaAttention): + setattr(m, "q_apply_rotary", LlamaApplyRotary()) + setattr(m, "k_apply_rotary", LlamaApplyRotary()) + m.forward = types.MethodType(llama_decoder_layer_forward, m) + + hooks = [] + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaApplyRotary): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + if isinstance(m, torch.nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + + self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len) + + for hook in hooks: + hook.remove() + return act_dict + + def smooth_fn(self, scales, alpha=0.5): + model = self.model + for name, module in model.named_modules(): + if isinstance(module, LlamaDecoderLayer): + attn_ln = module.input_layernorm + qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + ".self_attn.q_proj"] + self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) + + def create_quantized_model(model): + llama_config = model.config + for i, layer in enumerate(model.model.layers): + model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config) + + model.model.forward = types.MethodType(llama_model_forward, model.model) + cos, sin = init_to_get_rotary(llama_config) + model.model.register_buffer("_cos_cached", cos) + model.model.register_buffer("_sin_cached", sin) + + def quantized( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + alpha=0.5, + ): + llama_model = self.model + llama_config = llama_model.config + + act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len) + + self.smooth_fn(act_scales, alpha) + + act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len) + decoder_layer_scales = [] + + for idx in range(llama_config.num_hidden_layers): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 + scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 + scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 + scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 + + scale_dict["q_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 + ) + scale_dict["k_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 + ) + + scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 + + scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 + scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 + scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + + decoder_layer_scales.append(scale_dict) + + for i, layer in enumerate(llama_model.model.layers): + orig_layer = layer + llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) + + llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) + + cos, sin = init_to_get_rotary(llama_config) + llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device)) + llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device)) diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp new file mode 100644 index 000000000000..8444272940b4 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp @@ -0,0 +1,8 @@ +#include + +#include "linear.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32, + "Linear SiLU (INT8)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu new file mode 100644 index 000000000000..a30d02a4cf42 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu @@ -0,0 +1,162 @@ +// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu + +#include "linear.h" +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = float; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + +#if CUDA_ARCH >= 800 + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; +#elif CUDA_ARCH >= 750 + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + + using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, + DefaultGemmCfg::InstructionShape, + EpilogueOp>; +#elif CUDA_ARCH >= 700 + #define USE_TORCH_SILU + using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, + DefaultGemmCfg::InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>; +#else + #error "Unsupported cuda arch" +#endif + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run"); + } +#ifdef USE_TORCH_SILU +#undef USE_TORCH_SILU + out = torch::silu(out); +#endif + return out; +} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h new file mode 100644 index 000000000000..b62a27f3f8f3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h @@ -0,0 +1,12 @@ +#include +#include + +#include +#include + +torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +); diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index f065b2100fa8..27351a686d2f 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -13,8 +13,10 @@ from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm from .gptq_triton import gptq_fused_linear_triton + from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd from .rms_norm import rmsnorm_forward from .rotary_embedding_kernel import rotary_embedding_fwd + from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd from .softmax import softmax from .token_attention_kernel import token_attention_fwd @@ -29,4 +31,7 @@ "rotary_embedding_fwd", "token_attention_fwd", "gptq_fused_linear_triton", + "int8_rotary_embedding_fwd", + "smooth_llama_context_attn_fwd", + "smooth_token_attention_fwd", ] diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py new file mode 100644 index 000000000000..537dd164d1ab --- /dev/null +++ b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py @@ -0,0 +1,117 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rotary_kernel( + q, + input_scale, + output_scale, + Cos, + Sin, + q_bs_stride, + q_h_stride, + q_d_stride, + cos_bs_stride, + cos_d_stride, + total_len, + HEAD_NUM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + current_head_index = tl.program_id(0) + current_seq_index = tl.program_id(1) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + off_q0 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range0[None, None, :] * q_d_stride + ) + off_q1 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range1[None, None, :] * q_d_stride + ) + + off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride + + q0 = tl.load( + q + off_q0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + q1 = tl.load( + q + off_q1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + + cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + + q0 = q0.to(tl.float32) * input_scale + q1 = q1.to(tl.float32) * input_scale + + out0 = (q0 * cos - q1 * sin) / output_scale + out1 = (q0 * sin + q1 * cos) / output_scale + + out0 = out0.to(tl.int8) + out1 = out1.to(tl.int8) + + tl.store( + q + off_q0, + out0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + tl.store( + q + off_q1, + out1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + + return + + +@torch.no_grad() +def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + _rotary_kernel[grid]( + q, + input_scale, + output_scale, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + total_len, + HEAD_NUM=head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + HEAD_DIM=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py new file mode 100644 index 000000000000..ee0df6a74eaa --- /dev/null +++ b/colossalai/kernel/triton/smooth_attention.py @@ -0,0 +1,652 @@ +import math + +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + """ + this function is modified from + https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + """ + + @triton.jit + def _context_flash_attention_kernel( + Q, + K, + V, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, + alibi_ptr, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + batch_id = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # get batch info + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m + + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if alibi_ptr is not None: + alibi_m = tl.load(alibi_ptr + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if alibi_ptr is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + v = v.to(tl.float16) * v_input_scale.to(tl.float16) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + + + @torch.no_grad() + def smooth_llama_context_attn_fwd( + q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len + ): + + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + BLOCK_N = 128 + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _context_flash_attention_kernel[grid]( + q, + k, + v, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + None, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_1_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @triton.jit + def _token_attn_1_alibi_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + alibi_m = tl.load(alibi + current_head) + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @torch.no_grad() + def token_attn_fwd_1( + q, + k, + attn_out, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + alibi=None, + ): + BLOCK = 32 + # shape constraints + q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] + assert q_head_dim == k_head_dim + assert k_head_dim in {16, 32, 64, 128} + sm_scale = 1.0 / (k_head_dim**0.5) + + batch, head_num = kv_cache_loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + + num_warps = 4 if k_head_dim <= 64 else 8 + num_warps = 2 + + if alibi is not None: + _token_attn_1_alibi_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + _token_attn_1_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_softmax_fwd( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + logics_head_dim_stride, + logics_batch_stride, + prob_head_dim_stride, + prob_batch_stride, + BLOCK_SIZE: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + row = tl.load( + softmax_logics + + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store( + softmax_prob_out + + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len, + ) + return + + @torch.no_grad() + def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): + BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) + batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + _token_attn_softmax_fwd[(batch, head_num)]( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + softmax_logics.stride(0), + softmax_logics.stride(1), + softmax_prob_out.stride(0), + softmax_prob_out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + @triton.jit + def _token_attn_2_kernel( + Prob, + V, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + prob_head_dim_stride, + prob_batch_stride, + v_batch_stride, + v_head_stride, + v_head_dim_stride, + attn_out_batch_stride, + attn_out_head_stride, + attn_out_head_dim_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride + p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride + v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + for start_n in range(0, current_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load( + Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_loc = tl.load( + kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0, + ) + v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) + off_o = ( + current_batch * attn_out_batch_stride + + current_head * attn_out_head_stride + + offs_d * attn_out_head_dim_stride + ) + out_ptrs = attn_out + off_o + tl.store(out_ptrs, acc) + return + + @torch.no_grad() + def token_attn_fwd_2( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + ): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = kv_cache_loc.shape[0], v.shape[1] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + _token_attn_2_kernel[grid]( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + attn_out.stride(0), + attn_out.stride(1), + attn_out.stride(2), + HEAD_DIM=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @torch.no_grad() + def smooth_token_attention_fwd( + q, + k, + v, + attn_out, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=None, + ): + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] + + att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda") + + token_attn_fwd_1( + q.view(calcu_shape1), + k, + att_m_tensor, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi, + ) + + prob = torch.empty_like(att_m_tensor) + + token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + att_m_tensor = None + token_attn_fwd_2( + prob, + v, + attn_out.view(calcu_shape1), + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + prob = None + + return diff --git a/examples/inference/smoothquant_llama.py b/examples/inference/smoothquant_llama.py new file mode 100644 index 000000000000..ce7a00aa2739 --- /dev/null +++ b/examples/inference/smoothquant_llama.py @@ -0,0 +1,69 @@ +import argparse +import os + +import torch +from datasets import load_dataset +from transformers import LlamaTokenizer + +from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM + + +def build_model_and_tokenizer(model_name): + tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512) + kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"} + model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs) + model = model.to(torch.float32) + return model, tokenizer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, help="model name") + parser.add_argument( + "--output-path", + type=str, + help="where to save the checkpoint", + ) + parser.add_argument( + "--dataset-path", + type=str, + help="location of the calibration dataset", + ) + parser.add_argument("--num-samples", type=int, default=512) + parser.add_argument("--seq-len", type=int, default=512) + args = parser.parse_args() + return args + + +@torch.no_grad() +def main(): + args = parse_args() + model_path = args.model_name + dataset_path = args.dataset_path + output_path = args.output_path + num_samples = 10 + seq_len = 512 + + model, tokenizer = build_model_and_tokenizer(model_path) + if not os.path.exists(dataset_path): + print(f"Cannot find the dataset at {args.dataset_path}") + raise FileNotFoundError + dataset = load_dataset("json", data_files=dataset_path, split="train") + + model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len) + model = model.cuda() + + model.save_quantized(output_path, model_basename="llama-7b") + + model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b") + model = model.cuda() + + generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True) + input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda") + out = model.generate(**input_tokens, **generate_kwargs) + text = tokenizer.batch_decode(out) + print("out is:", text) + + +if __name__ == "__main__": + main() diff --git a/op_builder/smoothquant.py b/op_builder/smoothquant.py new file mode 100644 index 000000000000..d562a4c4f626 --- /dev/null +++ b/op_builder/smoothquant.py @@ -0,0 +1,52 @@ +import torch + +from .builder import Builder +from .utils import append_nvcc_threads + + +class SmoothquantBuilder(Builder): + NAME = "cu_smoothquant" + PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant" + + def __init__(self): + super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "smoothquant/binding.cpp", + "smoothquant/linear.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + compute_capability = torch.cuda.get_device_capability() + cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 + + extra_cuda_flags = [ + "-v", + f"-DCUDA_ARCH={cuda_arch}", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) + + def builder(self): + try: + super().builder() + except: + warnings.warn("build smoothquant lib not successful") diff --git a/tests/test_smoothquant/test_llama_attention.py b/tests/test_smoothquant/test_llama_attention.py new file mode 100644 index 000000000000..f8c79145c952 --- /dev/null +++ b/tests/test_smoothquant/test_llama_attention.py @@ -0,0 +1,136 @@ +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.triton import int8_rotary_embedding_fwd + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +try: + from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + +import math + +import torch +from torch.nn import functional as F + + +def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): + """ + adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + """ + xq = xq.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.0] = -100000000.0 + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) + scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + + return output + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT, + reason="triton requires cuda version to be higher than 11.4 or not install torch_int", +) +def test_llama_context_attention(): + head_num = 2 + seq_len = 32 + head_dim = 64 + dtype = torch.float + hidden_size = head_num * head_dim + + smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num) + + smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8) + + qkv_weight_scale = 1.0 + + ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda") + + smooth_attn = smooth_attn.to("cuda") + + input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda") + input_scale = 1 / 20.0 + + output = torch.matmul(input.to(torch.float) * input_scale, ones) + qkv_max_out = torch.max(torch.abs(output)) / 127 + smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + + q = smooth_attn.q_proj(input) + k = smooth_attn.k_proj(input) + v = smooth_attn.v_proj(input) + + cos_shape = (seq_len, head_dim // 2) + cos = torch.ones(cos_shape, dtype=dtype, device="cuda") + sin = torch.zeros(cos_shape, dtype=dtype, device="cuda") + in_scale = torch.tensor([qkv_max_out], device="cuda") + out_scale = torch.tensor([qkv_max_out], device="cuda") + int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + + q = q.to(torch.float) * out_scale + k = k.to(torch.float) * out_scale + v = v.to(torch.float) * out_scale + torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim) + attn_out_max = torch.max(torch.abs(torch_out)) / 127 + + output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones) + smooth_attn.q_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.v_output_scale = torch.tensor(qkv_max_out) + smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.attn_output_scale = torch.tensor(attn_out_max) + smooth_attn.out_proj.a = torch.tensor([attn_out_max]) + + torch_out = ( + (torch_out / smooth_attn.attn_output_scale) + .round() + .clamp(-128, 127) + .to(torch.int8) + .view(-1, seq_len, head_num * head_dim) + ) + + torch_out = smooth_attn.out_proj(torch_out) + torch_out = torch_out.to(torch.float) + + smooth_attn = smooth_attn.to("cuda") + smooth_out, _, _ = smooth_attn(input, (cos, sin)) + smooth_out = smooth_out.to(torch.float) + + assert torch.allclose( + torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1 + ), "outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_llama_context_attention() diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py new file mode 100644 index 000000000000..236edb10cb7f --- /dev/null +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -0,0 +1,84 @@ +import warnings + +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except: + warnings.warn("CUDA smoothquant linear is not installed") + HAS_SMOOTHQUANT_CUDA = False + + +try: + from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP + + HAS_TORCH_INT = True +except: + HAS_TORCH_INT = False + warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_llama_mlp(gate_proj, up_proj, down_proj, x): + gate_out = torch.mm(x, gate_proj) + silu = torch.nn.SiLU() + gate_out = silu(gate_out) + up_out = torch.mm(x, up_proj) + + o_out = gate_out * up_out + + max_up = torch.max(torch.abs(o_out)) + min_up = torch.min(torch.abs(o_out)) + + torch_out = torch.mm(o_out, down_proj) + + return (torch_out, max_up, min_up) + + +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT, + reason="smoothquant linear not installed properly or not install torch_int", +) +def test_llama_mlp(): + hidden_size = 256 + intermediate_size = 512 + + smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size) + + smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda") + + smooth_mlp.up_proj.weight = torch.randint( + -10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda" + ) + smooth_mlp.down_proj.weight = torch.randint( + -10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda" + ) + + x = torch.ones((1, 256), dtype=torch.int8, device="cuda") + + torch_out, max_inter, min_inter = torch_llama_mlp( + smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size, + smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127, + smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127, + x.to(torch.float), + ) + + smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127) + smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size) + smooth_mlp.up_proj.a = torch.tensor(1 / 127) + smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127)) + + smooth_out = smooth_mlp(x) + + assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01) + + +if __name__ == "__main__": + test_llama_mlp() diff --git a/tests/test_smoothquant/test_smoothquant_linear.py b/tests/test_smoothquant/test_smoothquant_linear.py new file mode 100644 index 000000000000..58a0b82f6759 --- /dev/null +++ b/tests/test_smoothquant/test_smoothquant_linear.py @@ -0,0 +1,39 @@ +import warnings + +import pytest +import torch + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except: + warnings.warn("CUDA smoothquant linear is not installed") + HAS_SMOOTHQUANT_CUDA = False + + +@pytest.mark.skipif( + not HAS_SMOOTHQUANT_CUDA, + reason="smoothquant linear not installed properly", +) +def test_linear(): + a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda") + b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda") + c = torch.rand(256, dtype=torch.float, device="cuda") + + alpha = 1 / 127 + beta = 1.0 + torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c + + silu = torch.nn.SiLU() + torch_out = silu(torch_out) + + b = b.transpose(0, 1).contiguous() + cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta) + + assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02) + + +if __name__ == "__main__": + test_linear() diff --git a/tests/test_smoothquant/test_sq_rotary_embedding.py b/tests/test_smoothquant/test_sq_rotary_embedding.py new file mode 100644 index 000000000000..4cc76f00474d --- /dev/null +++ b/tests/test_smoothquant/test_sq_rotary_embedding.py @@ -0,0 +1,59 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm + + +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.triton import int8_rotary_embedding_fwd + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_rotary_emb(): + SEQ_LEN = 1 + HEAD_NUM = 32 + HEAD_DIM = 128 + dtype = torch.float + # create data + x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + cos_shape = (SEQ_LEN, HEAD_DIM // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + # forward pass + y_torch = torch_rotary_emb(x, cos, sin) + + input_scale = torch.max(torch.abs(x)) / 127 + output_scale = torch.max(torch.abs(y_torch)) / 127 + + x = x / input_scale + x = x.to(torch.int8) + + int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item()) + y_triton = x.to(torch.float) * output_scale + assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True) + + +if __name__ == "__main__": + test_rotary_emb()