[Inference] Add Nopadding Llama Modeling (#5327)
* add nopadding llama modeling

* add nopadding_llama.py

* rm unused codes

* fix bugs in test_xine_copy.py

* fix code style
isky-cd authored Jan 30, 2024
1 parent c7c104c commit e8f0642
Showing 9 changed files with 386 additions and 49 deletions.
2 changes: 2 additions & 0 deletions colossalai/inference/config.py
@@ -32,6 +32,7 @@ class InferenceConfig:
During generation, the beam width provided as sampling parameter should be less than or equal to this value.
prefill_ratio (Optional[float]): A controlling ratio of prefill to decoding sequences in the running list; we will do a step of prefill
when the actual ratio exceeds this value.
pad_input: Whether to pad all inputs to the max length.
quant_mode (Optional[str]): Quantization mode.
revision (Optional[str]): The specific version(a branch, name, a commit id, or a tag name) of model to use.
"""
@@ -49,6 +50,7 @@ class InferenceConfig:
beam_width: int = 1
# the ratio of prefill sequences to decoding sequences; we do a prefill step once the actual value exceeds this ratio
prefill_ratio: Optional[float] = 1.2
pad_input: bool = False
quant_mode: Optional[str] = None
revision: Optional[str] = None

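The new pad_input flag only controls which input layout the engine prepares for the model. A minimal plain-PyTorch sketch of the two layouts (illustrative shapes and pad id, not the ColossalAI API):

import torch

seqs = [[5, 6, 7], [8, 9]]  # two prompts of different lengths

# pad_input=True: a rectangular [batch_size, max_len] tensor filled up with a pad token.
pad_id = 0
max_len = max(len(s) for s in seqs)
padded = torch.tensor([s + [pad_id] * (max_len - len(s)) for s in seqs])  # shape [2, 3]

# pad_input=False (the default): one flat 1D tensor of all tokens plus per-sequence
# lengths, so no pad tokens are stored or computed on.
flat_input = torch.tensor([tok for s in seqs for tok in s])  # shape [5]
lengths = torch.tensor([len(s) for s in seqs])               # tensor([3, 2])
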
14 changes: 11 additions & 3 deletions colossalai/inference/core/engine.py
@@ -57,7 +57,11 @@ def __init__(
model.to(self.dtype)

if model_policy is None:
model_policy = model_policy_map[self.model_config.model_type]()
if self.inference_config.pad_input:
model_type = "padding_" + self.model_config.model_type
else:
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]()

pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)

@@ -168,7 +172,9 @@ def add_request(

if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
"input_ids"
]

if isinstance(prompts_token_ids, list):
pass
@@ -237,7 +243,9 @@ def step(self) -> List[str]:
self.v_cache,
)

logits = logits[:, -1, :]
if self.inference_config.pad_input:
logits = logits[:, -1, :]

self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()

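A standalone sketch of the two behaviors this hunk switches on: the policy key now carries a "padding_"/"nopadding_" prefix, and the last-position slice of the logits is only needed on the padded path because the no-padding model already returns one row per sequence (hypothetical helper name and toy shapes; not the engine's API):

import torch

def resolve_policy_key(model_type: str, pad_input: bool) -> str:
    # Mirrors the branch added to the engine's __init__ above.
    return ("padding_" if pad_input else "nopadding_") + model_type

assert resolve_policy_key("llama", pad_input=True) == "padding_llama"
assert resolve_policy_key("llama", pad_input=False) == "nopadding_llama"

# Padded path: logits come back as [batch, seq_len, vocab], so step() keeps only
# the last position; the no-padding model returns a 2D [num_seqs, vocab] tensor.
padded_logits = torch.randn(2, 7, 32000)
last_logits = padded_logits[:, -1, :]  # shape [2, 32000]
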
221 changes: 221 additions & 0 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -0,0 +1,221 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
from typing import List, Optional, Tuple

import torch
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
)

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_kv_to_blocked_cache,
flash_decoding_attention,
get_xine_cache,
rotary_embedding,
)
from colossalai.logging import get_dist_logger

from flash_attn.bert_padding import index_first_axis, pad_input # noqa

logger = get_dist_logger(__name__)

try:
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")


@torch.no_grad()
def llama_causal_lm_forward(
self: LlamaForCausalLM,
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
self.model,
batch=batch,
k_caches=k_caches,
v_caches=v_caches,
)
logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1))
return logits


@torch.no_grad()
def llama_model_forward(
self: LlamaModel,
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
input_ids = batch.get_1D_inputs()
block_tables = batch.get_block_table_tensor()

sequence_lengths = batch.get_sequence_lengths()
batch_size = len(sequence_lengths)
kv_seq_len = sequence_lengths.max().item()

hidden_states = self.embed_tokens(input_ids)

cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)

if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)

for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
block_tables=block_tables,
k_cache=k_caches[layer_id],
v_cache=v_caches[layer_id],
is_prompts=batch.is_prompts,
sequence_lengths=sequence_lengths,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

if batch.is_prompts:
last_token_indexs = sequence_lengths.cumsum(dim=-1)
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
hidden_states = self.norm(hidden_states)

return hidden_states


@torch.no_grad()
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
is_prompts=is_prompts,
sequence_lengths=sequence_lengths,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states


# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def llama_attn_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view(
-1, self.num_key_value_heads, self.head_dim
)
value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view(
-1, self.num_key_value_heads, self.head_dim
)

rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])

_, _, _, block_size = k_cache.shape

if is_prompts:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)

attn_output = attn_output.view(-1, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1))

return attn_output


@torch.no_grad()
def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor):
gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1))
act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1))
tmp_out = act_out * up_proj_out
return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1))
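
In this no-padding layout, hidden_states is a flat [total_tokens, hidden_size] tensor, so llama_model_forward above selects the last token of each prompt with a cumulative-sum index instead of slicing position -1; during decoding every sequence contributes exactly one token, so no gather is needed. A small plain-PyTorch sketch of that gather (toy sizes, not the ColossalAI API):

import torch

sequence_lengths = torch.tensor([3, 2])  # two prompts: 3 tokens and 2 tokens
hidden_states = torch.randn(int(sequence_lengths.sum()), 4)  # [5, 4], no padding rows

last_token_indexs = sequence_lengths.cumsum(dim=-1)  # tensor([3, 5])
last_hidden = hidden_states[last_token_indexs - 1]   # rows 2 and 4 -> shape [2, 4]
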
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
context_attention_unpadded,
copy_kv_to_blocked_cache,
flash_decoding_attention,
get_xine_cache,
rotary_embedding,
)
from colossalai.logging import get_dist_logger
@@ -101,12 +102,7 @@ def llama_model_forward

hidden_states = self.embed_tokens(input_ids)

# When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
# cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
# sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
# cos_sin = (cos, sin)

cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)

if batch.is_prompts:
output_tensor = torch.zeros(
@@ -135,7 +131,9 @@ def llama_model_forward
sm_scale=sm_scale,
)

hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
hidden_states = self.norm(hidden_states)

return hidden_states


@@ -327,26 +325,3 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_
k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
return (q, k, v, indices)


@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
"""
Get cos and sin for the cache, and return nopad format.
Args:
lengths: shape(num_seqs,), stores lenghth of each sequence.
cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constrcuted in model.
sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constrcuted in model.
is_prompts: bool, mark if in prefill mode.
dtype: The data type of this inference process.
"""

if is_prompts:
index_arrays = [torch.arange(length) for length in lengths]
else:
index_arrays = [(length - 1).view(-1) for length in lengths]
indices = torch.cat(index_arrays, dim=-1)
cos_output = cos_cache[indices].to(dtype=dtype)
sin_output = sin_cache[indices].to(dtype=dtype)

return (cos_output, sin_output)
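
The removed get_cos_sin helper and the Triton get_xine_cache kernel that replaces it both select the rotary cos/sin rows needed for the no-padding layout; a plain-PyTorch sketch of that selection (toy sizes, not the Triton kernel itself):

import torch

lengths = torch.tensor([3, 2])
cos_cache = torch.randn(2048, 128)  # [max_rotary_position, head_dim], built in the model

# Prefill: positions 0 .. len-1 of every sequence, concatenated without padding.
prefill_idx = torch.cat([torch.arange(int(l)) for l in lengths])  # tensor([0, 1, 2, 0, 1])
cos_prefill = cos_cache[prefill_idx]                              # shape [5, 128]

# Decoding: only the current position (len - 1) of each sequence.
decode_idx = lengths - 1                                          # tensor([2, 1])
cos_decode = cos_cache[decode_idx]                                # shape [2, 128]
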
8 changes: 5 additions & 3 deletions colossalai/inference/modeling/policy/__init__.py
@@ -1,7 +1,9 @@
from .llama import LlamaModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
from .padding_llama import PaddingLlamaModelInferPolicy

model_policy_map = {
"llama": LlamaModelInferPolicy,
"padding_llama": PaddingLlamaModelInferPolicy,
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
}

__all__ = ["LlamaModelInferPolicy", "model_polic_map"]
__all__ = ["PaddingLlamaModelInferPolicy", "NoPaddingLlamaModelInferPolicy", "model_polic_map"]