diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 5014821d0caf..f54555857957 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -32,6 +32,7 @@ class InferenceConfig: During generation, the beam width provided as sampling parameter should be less than or equivalent to this value. prefill_ratio (Optional[float]): A controling ratio for prefill and decoding in running list, we will do a step of prefill when the actual value exceeds this ratio. + pad_input: Whether to pad all inputs to the max length. quant_mode (Optional[str]): Quantization mode. revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use. """ @@ -49,6 +50,7 @@ class InferenceConfig: beam_width: int = 1 # the ratio of prefill sequences to decoding sequences, we do prefill step once the actual value exceeds ratio prefill_ratio: Optional[float] = 1.2 + pad_input: bool = False quant_mode: Optional[str] = None revision: Optional[str] = None diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index a9686f07c8d6..7b21d1750fb4 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -57,7 +57,11 @@ def __init__( model.to(self.dtype) if model_policy is None: - model_policy = model_policy_map[self.model_config.model_type]() + if self.inference_config.pad_input: + model_type = "padding_" + self.model_config.model_type + else: + model_type = "nopadding_" + self.model_config.model_type + model_policy = model_policy_map[model_type]() pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size) @@ -168,7 +172,9 @@ def add_request( if prompts_token_ids is None: assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." - prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"] + prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ + "input_ids" + ] if isinstance(prompts_token_ids, list): pass @@ -237,7 +243,9 @@ def step(self) -> List[str]: self.v_cache, ) - logits = logits[:, -1, :] + if self.inference_config.pad_input: + logits = logits[:, -1, :] + self.request_handler.search_tokens(self.generation_config, logits) finished_sequences = self.request_handler.update() diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py new file mode 100644 index 000000000000..3a81a97f7a2e --- /dev/null +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -0,0 +1,221 @@ +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py +from typing import List, Optional, Tuple + +import torch +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaMLP, + LlamaModel, +) + +from colossalai.inference.flash_decoding_utils import FDIntermTensors +from colossalai.inference.struct import BatchInfo +from colossalai.kernel.triton import ( + context_attention_unpadded, + copy_kv_to_blocked_cache, + flash_decoding_attention, + get_xine_cache, + rotary_embedding, +) +from colossalai.logging import get_dist_logger + +from flash_attn.bert_padding import index_first_axis, pad_input # noqa + +logger = get_dist_logger(__name__) + +try: + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.") + + +@torch.no_grad() +def llama_causal_lm_forward( + self: LlamaForCausalLM, + batch: BatchInfo = None, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, +): + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + hidden_states = llama_model_forward( + self.model, + batch=batch, + k_caches=k_caches, + v_caches=v_caches, + ) + logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1)) + return logits + + +@torch.no_grad() +def llama_model_forward( + self: LlamaModel, + batch: BatchInfo = None, + k_caches: List[torch.Tensor] = None, + v_caches: List[torch.Tensor] = None, +): + input_ids = batch.get_1D_inputs() + block_tables = batch.get_block_table_tensor() + + sequence_lengths = batch.get_sequence_lengths() + batch_size = len(sequence_lengths) + kv_seq_len = sequence_lengths.max().item() + + hidden_states = self.embed_tokens(input_ids) + + cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts) + + if batch.is_prompts: + output_tensor = torch.zeros( + (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device + ) + else: + output_tensor = torch.zeros( + (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device + ) + sm_scale = 1.0 / (batch.head_dim**0.5) + + for layer_id, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + block_tables=block_tables, + k_cache=k_caches[layer_id], + v_cache=v_caches[layer_id], + is_prompts=batch.is_prompts, + sequence_lengths=sequence_lengths, + kv_seq_len=kv_seq_len, + cos_sin=cos_sin, + fd_inter_tensor=batch.fd_inter_tensor, + output_tensor=output_tensor, + sm_scale=sm_scale, + ) + + if batch.is_prompts: + last_token_indexs = sequence_lengths.cumsum(dim=-1) + hidden_states = hidden_states[last_token_indexs - 1].contiguous() + hidden_states = self.norm(hidden_states) + + return hidden_states + + +@torch.no_grad() +def llama_decoder_layer_forward( + self: LlamaDecoderLayer, + hidden_states: torch.Tensor, + block_tables: torch.Tensor = None, + k_cache: torch.Tensor = None, + v_cache: torch.Tensor = None, + is_prompts: bool = True, + sequence_lengths: torch.Tensor = None, + kv_seq_len: int = 0, + cos_sin: Tuple[torch.Tensor] = None, + fd_inter_tensor: FDIntermTensors = None, + output_tensor: torch.Tensor = None, + sm_scale: int = None, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + block_tables=block_tables, + k_cache=k_cache, + v_cache=v_cache, + is_prompts=is_prompts, + sequence_lengths=sequence_lengths, + kv_seq_len=kv_seq_len, + cos_sin=cos_sin, + fd_inter_tensor=fd_inter_tensor, + output_tensor=output_tensor, + sm_scale=sm_scale, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward +@torch.no_grad() +def llama_attn_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + block_tables: torch.Tensor = None, + k_cache: torch.Tensor = None, + v_cache: torch.Tensor = None, + is_prompts: bool = True, + sequence_lengths: torch.Tensor = None, + kv_seq_len: int = 0, + cos_sin: Tuple[torch.Tensor] = None, + fd_inter_tensor: FDIntermTensors = None, + output_tensor: torch.Tensor = None, + sm_scale: int = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim) + key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view( + -1, self.num_key_value_heads, self.head_dim + ) + value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view( + -1, self.num_key_value_heads, self.head_dim + ) + + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + + _, _, _, block_size = k_cache.shape + + if is_prompts: + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + ) + else: + copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + attn_output = flash_decoding_attention( + q=query_states, + k_cache=k_cache, + v_cache=v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + max_seq_len_in_batch=kv_seq_len, + output=output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + sm_scale=sm_scale, + ) + attn_output = attn_output.squeeze(1) + + attn_output = attn_output.view(-1, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(-1, self.hidden_size) + attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1)) + + return attn_output + + +@torch.no_grad() +def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor): + gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1)) + act_out = torch.nn.functional.silu(gate_proj_out, inplace=True) + up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1)) + tmp_out = act_out * up_proj_out + return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1)) diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/padding_llama.py similarity index 90% rename from colossalai/inference/modeling/models/llama.py rename to colossalai/inference/modeling/models/padding_llama.py index 3e38905451fe..fb66360f5a6d 100644 --- a/colossalai/inference/modeling/models/llama.py +++ b/colossalai/inference/modeling/models/padding_llama.py @@ -11,6 +11,7 @@ context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention, + get_xine_cache, rotary_embedding, ) from colossalai.logging import get_dist_logger @@ -101,12 +102,7 @@ def llama_model_forward( hidden_states = self.embed_tokens(input_ids) - # When testing, the performance of get_xine_cache is lower than that of get_cos_sin. - # cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts) - # sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts) - # cos_sin = (cos, sin) - - cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype) + cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts) if batch.is_prompts: output_tensor = torch.zeros( @@ -135,7 +131,9 @@ def llama_model_forward( sm_scale=sm_scale, ) + hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous() hidden_states = self.norm(hidden_states) + return hidden_states @@ -327,26 +325,3 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_ k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) return (q, k, v, indices) - - -@torch.no_grad() -def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype): - """ - Get cos and sin for the cache, and return nopad format. - Args: - lengths: shape(num_seqs,), stores lenghth of each sequence. - cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constrcuted in model. - sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constrcuted in model. - is_prompts: bool, mark if in prefill mode. - dtype: The data type of this inference process. - """ - - if is_prompts: - index_arrays = [torch.arange(length) for length in lengths] - else: - index_arrays = [(length - 1).view(-1) for length in lengths] - indices = torch.cat(index_arrays, dim=-1) - cos_output = cos_cache[indices].to(dtype=dtype) - sin_output = sin_cache[indices].to(dtype=dtype) - - return (cos_output, sin_output) diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py index 1009939416ed..9477cd957418 100644 --- a/colossalai/inference/modeling/policy/__init__.py +++ b/colossalai/inference/modeling/policy/__init__.py @@ -1,7 +1,9 @@ -from .llama import LlamaModelInferPolicy +from .nopadding_llama import NoPaddingLlamaModelInferPolicy +from .padding_llama import PaddingLlamaModelInferPolicy model_policy_map = { - "llama": LlamaModelInferPolicy, + "padding_llama": PaddingLlamaModelInferPolicy, + "nopadding_llama": NoPaddingLlamaModelInferPolicy, } -__all__ = ["LlamaModelInferPolicy", "model_polic_map"] +__all__ = ["PaddingLlamaModelInferPolicy", "NoPaddingLlamaModelInferPolicy", "model_polic_map"] diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py new file mode 100644 index 000000000000..3eaa59f74cdd --- /dev/null +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -0,0 +1,107 @@ +from functools import partial + +import torch +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaFlashAttention2, + LlamaForCausalLM, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + LlamaSdpaAttention, +) + +from colossalai.inference.modeling.models.nopadding_llama import ( + llama_attn_forward, + llama_causal_lm_forward, + llama_decoder_layer_forward, + llama_model_forward, + nopad_mlp, +) +from colossalai.inference.utils import init_to_get_rotary + +# import colossalai +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +try: + from colossalai.kernel.triton import rms_layernorm + + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +def get_triton_rmsnorm_forward(): + if HAS_TRITON_RMSNORM: + + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon) + + return _triton_rmsnorm_forward + else: + return None + + +class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + self.shard_config._infer() + + infer_forward = llama_causal_lm_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaForCausalLM + ) + + infer_forward = llama_model_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = llama_decoder_layer_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) + + infer_forward = nopad_mlp + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaMLP) + + infer_forward = llama_attn_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) + + infer_forward = llama_attn_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaFlashAttention2 + ) + + infer_forward = llama_attn_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaSdpaAttention + ) + + infer_forward = None + if HAS_TRITON_RMSNORM: + infer_forward = get_triton_rmsnorm_forward() + + if infer_forward is not None: + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) + + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model diff --git a/colossalai/inference/modeling/policy/llama.py b/colossalai/inference/modeling/policy/padding_llama.py similarity index 98% rename from colossalai/inference/modeling/policy/llama.py rename to colossalai/inference/modeling/policy/padding_llama.py index 514c274adb99..0c83189f8d6b 100644 --- a/colossalai/inference/modeling/policy/llama.py +++ b/colossalai/inference/modeling/policy/padding_llama.py @@ -11,7 +11,7 @@ LlamaSdpaAttention, ) -from colossalai.inference.modeling.models.llama import ( +from colossalai.inference.modeling.models.padding_llama import ( llama_attn_forward, llama_causal_lm_forward, llama_decoder_layer_forward, @@ -43,7 +43,7 @@ def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): return None -class LlamaModelInferPolicy(LlamaForCausalLMPolicy): +class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): def __init__(self) -> None: super().__init__() diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index feb50da9923c..22b5b5a3ab2f 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -358,21 +358,16 @@ def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]: Flattening the input tokens. """ input_list = [] - input_len_list = [] assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first." for seq in self.sequences_set: if self.is_prompts: input_list.extend(seq.input_token_id) - input_len_list.append(seq.sentence_len) else: input_list.append(seq.output_token_id[-1]) - input_len_list.append(1) - return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor( - input_len_list, dtype=torch.int, device=self.device - ) + return torch.tensor(input_list, dtype=torch.long, device=self.device) def get_sequence_lengths(self): """ @@ -401,7 +396,9 @@ def get_attn_mask(self) -> torch.Tensor: past_values.append(seq.input_token_id + seq.output_token_id) max_seq_len = max(len(sub_list) for sub_list in past_values) - attn_mask = _make_tensor_with_pad(past_values, max_seq_len, 0, dtype=torch.int, device=self.device) + attn_mask = _make_tensor_with_pad( + past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device + ) return attn_mask.ne(padding_id).long() diff --git a/tests/test_infer_ops/triton/test_xine_copy.py b/tests/test_infer_ops/triton/test_xine_copy.py index da2720659032..c19be5abe338 100644 --- a/tests/test_infer_ops/triton/test_xine_copy.py +++ b/tests/test_infer_ops/triton/test_xine_copy.py @@ -2,7 +2,6 @@ import torch from packaging import version -from colossalai.inference.modeling.models.llama import get_cos_sin from colossalai.kernel.triton import get_xine_cache try: @@ -16,6 +15,29 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +@torch.no_grad() +def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype): + """ + Get cos and sin for the cache, and return nopad format. + Args: + lengths: shape(num_seqs,), stores lenghth of each sequence. + cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constrcuted in model. + sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constrcuted in model. + is_prompts: bool, mark if in prefill mode. + dtype: The data type of this inference process. + """ + + if is_prompts: + index_arrays = [torch.arange(length) for length in lengths] + else: + index_arrays = [(length - 1).view(-1) for length in lengths] + indices = torch.cat(index_arrays, dim=-1) + cos_output = cos_cache[indices].to(dtype=dtype) + sin_output = sin_cache[indices].to(dtype=dtype) + + return (cos_output, sin_output) + + @pytest.mark.parametrize("BATCH_SIZE", [4]) @pytest.mark.parametrize("MAX_SEQ_LEN", [64]) @pytest.mark.parametrize("HEAD_DIM", [64]) @@ -23,15 +45,18 @@ def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype): MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda") + sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda") lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda") # prefill - cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=True, dtype=dtype) - cos = get_xine_cache(lengths, cos_cache, is_prompts=True) + cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype) + cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True) assert torch.allclose(cos, cos_ref) + assert torch.allclose(sin, sin_ref) # decoding - ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=False, dtype=dtype) - cos = get_xine_cache(lengths, cos_cache, is_prompts=False) + ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype) + cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False) assert torch.allclose(cos, ncos_ref) + assert torch.allclose(sin, sin_ref) configs = [