[Llama3-text vLLM integration] Modify Llama3 text model (new and old codebase) forward apis for vLLM compatibility (#16292)
skhorasganiTT authored Dec 24, 2024
1 parent 8457be4 commit 9f24a71
Showing 7 changed files with 238 additions and 51 deletions.
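
The change splits text decode into host-side input preparation, on-device execution (optionally behind a ttnn trace), and output readback, so that vLLM can drive the model through prefill_forward_text and decode_forward_text. Below is a minimal sketch of how a caller might exercise the new API; only the method names and keyword arguments shown in this diff come from the source, while the generator construction, batch bookkeeping, and greedy sampling are illustrative.

# `generator` is a LlamaGenerator; `tokens`, `prompt_lens`, `page_table` and
# `kv_cache` are assumed to be prepared by the serving layer.
prefill_logits = generator.prefill_forward_text(
    tokens, page_table=page_table, kv_cache=kv_cache, prompt_lens=prompt_lens
)
next_tokens = prefill_logits.argmax(dim=-1).view(-1, 1)  # greedy sampling for illustration
current_pos = prompt_lens.clone()

for _ in range(max_new_tokens):
    # enable_trace=True captures a ttnn trace on the first decode step and replays
    # it afterwards; read_from_device=True returns torch logits instead of the
    # on-device tensor.
    logits = generator.decode_forward_text(
        next_tokens,
        current_pos,
        page_table=page_table,
        kv_cache=kv_cache,
        enable_trace=True,
        read_from_device=True,
    )
    next_tokens = logits[:, -1, :].argmax(dim=-1).view(-1, 1)
    current_pos += 1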
96 changes: 85 additions & 11 deletions models/demos/llama3/tt/generator.py
@@ -55,6 +55,7 @@ def prefill_forward_text(self, tokens: torch.Tensor, page_table=None, kv_cache=N
), "page_table must be a torch.Tensor when passing into prefill_forward"

for user_id in range(batch):
logger.info(f"Prefilling User {user_id + 1}")
seq_len = prompt_lens[user_id]
last_token_idx = seq_len - 1

@@ -76,6 +77,8 @@ def prefill_forward_text(self, tokens: torch.Tensor, page_table=None, kv_cache=N
# Since we give unpadded_seq_len, only the tile containing the last token is returned
output_logits[user_id] = logits

logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...")

return output_logits

def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None):
@@ -162,14 +165,40 @@ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_tok
return logits

def decode_forward_text(
self,
tokens,
start_pos,
page_table=None,
kv_cache=None,
enable_trace=True,
read_from_device=True,
):
decode_kwargs = {
"current_pos": start_pos,
"tokens": tokens,
"page_table": page_table,
"kv_cache": kv_cache,
}
if enable_trace:
tt_logits = self._easy_trace_text(**decode_kwargs)
else:
tt_logits = self._decode_forward_no_trace_text(**decode_kwargs)

if read_from_device:
return self.read_decode_output(tt_logits, tokens.shape[0])
else:
return tt_logits

def _decode_forward_no_trace_text(
self,
tokens,
current_pos,
page_table=None,
kv_cache=None,
):
"""
Performs text decode step.
Returns logits
Returns tt_logits on device
"""
tt_tokens, tt_current_pos, tt_rot_mats, tt_page_table = self.model.prepare_inputs_decode(
tokens, current_pos, page_table
@@ -180,38 +209,41 @@ def decode_forward_text(
tt_current_pos,
rot_mats=tt_rot_mats,
page_table=tt_page_table,
kv_cache=kv_cache,
)

logits = self.model.process_output_decode(tt_logits)
return logits
return tt_logits

def capture_trace_text(
def _capture_trace_text(
self,
tokens,
current_pos,
page_table=None,
kv_cache=None,
):
"""
Captures a trace for the decode_forward method.
"""

# Compile run
self.decode_forward_text(tokens, current_pos, page_table)
self._decode_forward_no_trace_text(tokens, current_pos, page_table=page_table, kv_cache=kv_cache)
logger.info("Done Compiling Model")

# Get inputs ready for trace run
host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table)
host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table=page_table)

device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device)

trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
transformed_inputs = self.model.transform_decode_inputs_device(*device_inputs)
tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs)
tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs, kv_cache=kv_cache)

ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
logger.info("Done Capturing Decode Trace")

return trace_id, tt_out_trace, *device_inputs

def decode_forward_trace_text(
def _decode_forward_trace_text(
self,
trace_id,
device_inputs,
@@ -220,6 +252,9 @@ def decode_forward_trace_text(
current_pos,
page_table=None,
):
"""
Executes the trace for the decode_forward method but does not read back outputs.
"""
host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table)

device_inputs = copy_host_to_device(
@@ -229,9 +264,36 @@ def decode_forward_trace_text(

ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False)

logits = self.model.process_output_decode(tt_out_trace)
return tt_out_trace

return logits
def _easy_trace_text(
self,
tokens,
current_pos,
page_table=None,
kv_cache=None,
):
"""
Tracing is easy! Just call this method and we'll handle tracing for you.
"""
if not hasattr(self, "trace_id_text"):
trace_id, tt_out_trace, *device_inputs = self._capture_trace_text(
tokens, current_pos, page_table=page_table, kv_cache=kv_cache
)
self.trace_id_text = trace_id
self.trace_inputs_text = device_inputs
self.trace_output_text = tt_out_trace

trace_logits_rm = self._decode_forward_trace_text(
self.trace_id_text,
self.trace_inputs_text,
self.trace_output_text,
tokens,
current_pos,
page_table=page_table,
)

return trace_logits_rm

def _prefill_forward_single_user(
self,
@@ -325,7 +387,7 @@ def prefill_forward(
output_full_text_row_masked_out_masks = []

for user_id in range(batch):
print(f"Prefilling User {user_id}")
logger.info(f"Prefilling User {user_id + 1}")
seq_len = prompt_lens[user_id]
(
xattn_caches,
@@ -833,3 +895,15 @@ def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len):
block_size = get_block_size(kv_cache)
num_blocks = num_blocks_in_seq(prefill_len, block_size)
return page_table[:, :num_blocks]

## Destructor (used to delete ttnn trace if exists)

def __del__(self):
if hasattr(self, "trace_id"):
ttnn.release_trace(self.mesh_device, self.trace_id)

if hasattr(self, "trace_id_text"):
ttnn.release_trace(self.mesh_device, self.trace_id_text)

if hasattr(super(LlamaGenerator, self), "__del__"):
super().__del__()
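
The _easy_trace_text / _capture_trace_text / _decode_forward_trace_text trio above implements a capture-once, replay-many pattern: the first decode call compiles the model, captures a ttnn trace, and stashes the trace id together with the device-resident inputs and output; later calls only refresh those device inputs and replay the trace, and __del__ releases it. A stripped-down sketch of the same idiom follows, using only the ttnn calls that appear above — run_decode is a placeholder for the model forward, and the device_tensors keyword form of copy_host_to_device is an assumption, not taken from this diff.

import ttnn
# copy_host_to_device is the helper used by the generator above (import path not shown in this diff).

class TracedDecode:
    def __init__(self, mesh_device, model):
        self.mesh_device = mesh_device
        self.model = model  # assumed to expose run_decode(*device_inputs) -> device tensor

    def step(self, host_inputs):
        if not hasattr(self, "trace_id"):
            # Compile run (untraced), then capture the decode step once.
            self.model.run_decode(*copy_host_to_device(host_inputs, mesh_device=self.mesh_device))
            self.device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device)
            self.trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
            self.trace_output = self.model.run_decode(*self.device_inputs)
            ttnn.end_trace_capture(self.mesh_device, self.trace_id, cq_id=0)

        # Replay: copy fresh host inputs into the tensors the trace was captured with,
        # then execute the trace without blocking.
        copy_host_to_device(host_inputs, device_tensors=self.device_inputs)
        ttnn.execute_trace(self.mesh_device, self.trace_id, cq_id=0, blocking=False)
        return self.trace_output

    def __del__(self):
        if hasattr(self, "trace_id"):
            ttnn.release_trace(self.mesh_device, self.trace_id)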
48 changes: 48 additions & 0 deletions models/demos/llama3/tt/generator_vllm.py
@@ -6,8 +6,12 @@
import torch
import PIL
from llama_models.llama3.api.chat_format import create_vision_mask
from llama_models.llama3.api.tokenizer import Tokenizer
import ttnn

from models.demos.llama3.tt.generator import LlamaGenerator
from models.demos.llama3.tt.llama_model import TtTransformer
from models.demos.llama3.tt.model_config import LlamaOptimizations, TtModelArgs
from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
from models.utility_functions import nearest_32

@@ -106,3 +110,47 @@ def prefill_forward(
kv_cache=kv_cache,
cross_page_table=cross_page_table,
)


class TtLlamaForCausalLM(LlamaGenerator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, n_layers=None):
instruct_mode = "Instruct" in hf_config._name_or_path
max_seq_len = 131072 # TODO: modify this for different models/devices
optimizations = LlamaOptimizations.performance # TODO: maybe change to accuracy
dtype = ttnn.bfloat8_b

# Load model args, weights
model_args = TtModelArgs(
mesh_device,
instruct=instruct_mode,
max_batch_size=max_batch_size,
optimizations=optimizations,
max_seq_len=max_seq_len,
)
if n_layers is not None:
model_args.n_layers = n_layers
state_dict = model_args.load_state_dict()

tt_model = TtTransformer(
args=model_args,
mesh_device=mesh_device,
dtype=dtype,
state_dict=state_dict,
weight_cache_path=model_args.weight_cache_path(dtype),
use_paged_kv_cache=True,
)
return cls(tt_model, model_args, mesh_device)

@property
def cache_path(self):
return self.model_args.model_cache_path

def prefill_forward(self, *args, **kwargs):
return super().prefill_forward_text(*args, **kwargs)

def decode_forward(self, *args, **kwargs):
return super().decode_forward_text(*args, **kwargs)
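
TtLlamaForCausalLM follows the same classmethod construction pattern as the multimodal generator above: the caller (e.g. a vLLM TT backend) builds the TT model from a HuggingFace config plus a mesh device, then drives it through prefill_forward/decode_forward, which simply delegate to the text methods of LlamaGenerator. A hedged sketch of that construction path — how the config and mesh device are obtained is outside this diff, so the first two assignments are placeholders:

from transformers import AutoConfig
from models.demos.llama3.tt.generator_vllm import TtLlamaForCausalLM

hf_config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # illustrative checkpoint
mesh_device = open_mesh_device()  # hypothetical helper returning a ttnn mesh device

model = TtLlamaForCausalLM.initialize_vllm_model(
    hf_config,
    mesh_device,
    max_batch_size=32,
    n_layers=None,  # keep all layers; a small value is handy for quick bring-up
)
print(model.cache_path)  # weight/cache directory resolved from model_args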
8 changes: 4 additions & 4 deletions models/demos/llama3/tt/llama_common.py
@@ -322,12 +322,12 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True

def get_padded_prefill_len(seq_len):
"""
If seq_len is less than 32, pad to 32
If seq_len is more than 32, pad to whichever is smaller: a power of 2 or a multiple of 1024
If seq_len is less than 128, pad to 128
If seq_len is more than 128, pad to whichever is smaller: a power of 2 or a multiple of 1024
TODO: Generalize for max_mm_seq_len different from 1024
"""
if seq_len <= 32:
return 32
if seq_len <= 128:
return 128
pow_2_pad = nearest_pow_2(seq_len)
mult_1024_pad = 1024 * math.ceil(seq_len / 1024)
min_extended_pad = min(pow_2_pad, mult_1024_pad)
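
With the minimum padded prefill length bumped from 32 to 128, short prompts now pad to 128 and longer prompts pad to the smaller of the next power of two and the next multiple of 1024. A worked sketch of the rule as visible in this hunk — the real get_padded_prefill_len may apply further adjustment below the cut-off of the diff, and nearest_pow_2 is assumed to round up:

import math

def padded_prefill_len_sketch(seq_len):
    if seq_len <= 128:
        return 128
    pow_2_pad = 2 ** math.ceil(math.log2(seq_len))     # next power of two
    mult_1024_pad = 1024 * math.ceil(seq_len / 1024)   # next multiple of 1024
    return min(pow_2_pad, mult_1024_pad)

assert padded_prefill_len_sketch(100) == 128
assert padded_prefill_len_sketch(300) == 512    # power of two (512) beats 1024
assert padded_prefill_len_sketch(3000) == 3072  # multiple of 1024 (3072) beats 4096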
32 changes: 22 additions & 10 deletions models/demos/llama3/tt/llama_model.py
@@ -7,6 +7,7 @@
import ttnn
import torch
import torch.nn as nn
from tqdm import tqdm
from models.demos.llama3.tt.llama_decoder import TtTransformerBlock
from models.common.rmsnorm import RMSNorm
import ttnn
@@ -28,6 +29,7 @@ def __init__(
state_dict,
weight_cache_path,
paged_attention_config=None,
use_paged_kv_cache=False,
):
super().__init__()
self.args = args
@@ -69,8 +71,9 @@ def __init__(
layer_num=i,
transformation_mats=self.trans_mats_dict,
paged_attention_config=paged_attention_config,
use_paged_kv_cache=use_paged_kv_cache,
)
for i in range(self.n_layers)
for i in tqdm(range(self.n_layers))
]
self.norm = DistributedNorm(
RMSNorm(
@@ -107,9 +110,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag

tokens = tokens.reshape(1, 1, 1, -1)
S = tokens.shape[-1]
dims = (None, -1) if self.args.is_galaxy else (None, None)
dims = (None, None) # replicate
mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)

tokens = ttnn.from_torch(
tokens,
device=self.mesh_device,
@@ -170,19 +172,20 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
Inputs are torch tensors or python types. Outputs are ttnn tensors on host.
NOTE: Tokens and current_pos are padded to batch
"""
B = tokens.shape[-1]
B = tokens.shape[0]
assert current_pos.shape[0] == B, "Batch size mismatch"
assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size"

dims = (None, -1) if self.args.is_galaxy else (None, None)
mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)

tokens = ttnn.from_torch(
tokens,
tokens.view(-1),
device=None,
dtype=ttnn.uint32,
mesh_mapper=mesh_mapper,
)
tokens = ttnn.unsqueeze_to_4D(tokens)

rot_current_pos = torch.maximum(
current_pos, torch.tensor(0, dtype=torch.int64)
@@ -240,17 +243,19 @@ def process_output_prefill(self, tt_out, last_token_idx):
)[0, 0, last_token_idx, :]
return logits

def process_output_decode(self, tt_out):
def process_output_decode(self, tt_out, B, S=1):
"""
Input is ttnn device tensor of logits. Output is torch logits tensor
"""
if self.args.num_devices > 1:
tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
tt_out = ttnn.untilize(tt_out, use_multicore=True)
if self.args.num_devices > 1:
return ttnn.to_torch(ttnn.get_device_tensors(tt_out_rm)[0]).float()
tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float()
else:
return ttnn.to_torch(tt_out_rm).float()
tt_out = ttnn.to_torch(tt_out).float()
tt_out = tt_out[:, :, :B, :].view(B, S, -1)
return tt_out

def ttnn_prefill_forward(
self,
Expand Down Expand Up @@ -280,7 +285,14 @@ def ttnn_prefill_forward(
kv_cache=kv_cache,
)

def ttnn_decode_forward(self, x, current_pos, rot_mats, page_table=None, kv_cache=None):
def ttnn_decode_forward(
self,
x,
current_pos,
rot_mats,
page_table=None,
kv_cache=None,
):
"""
This method will take device tensors and any other args to run forward.
It returns ttnn device tensors.
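
The llama_model.py changes tighten the decode-path shape contract: prepare_decode_inputs_host now takes tokens of shape [batch] (with batch equal to max_batch_size) instead of a row vector, and process_output_decode gathers, untilizes, and reshapes the logits to [batch, seq, vocab] on the host. A shape-only sketch of the round trip — vocab_size, page_table, and the device-side execution in the middle are assumed; only the two method signatures come from the diff:

import torch

batch = 32  # must equal model_args.max_batch_size
tokens = torch.zeros(batch, dtype=torch.int64)       # one next-token id per user
current_pos = torch.zeros(batch, dtype=torch.int64)  # per-user decode position

host_inputs = model.prepare_decode_inputs_host(tokens, current_pos, page_table=page_table)

# ... copy to device, run ttnn_decode_forward (directly or via a captured trace) ...

logits = model.process_output_decode(tt_logits, B=batch, S=1)
assert logits.shape == (batch, 1, vocab_size)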
2 changes: 1 addition & 1 deletion models/demos/llama3/tt/llama_rope.py
@@ -132,7 +132,7 @@ def get_rot_idxs(self, position_idxs, on_host=False):
position_idxs,
dtype=ttnn.uint32,
layout=ttnn.ROW_MAJOR_LAYOUT,
mesh_mapper=ReplicateTensorToMesh(self.device) if self.num_devices > 1 else None,
mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None,
)
else: # On device
rot_idxs = ttnn.as_tensor(
9 changes: 8 additions & 1 deletion models/demos/t3000/llama2_70b/tt/generator_vllm.py
@@ -2,10 +2,10 @@

# SPDX-License-Identifier: Apache-2.0

from typing import Union
from dataclasses import dataclass
from pathlib import Path
import json
import torch

from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
from models.demos.t3000.llama2_70b.tt.llama_common import (
@@ -41,6 +41,10 @@ class TTArgs:
model_config, ckpt_dir, _, cache_path = setup_llama_env(
llama_version=llama_version,
)

mesh_rows = t3k_mesh_device.shape.num_rows
mesh_cols = t3k_mesh_device.shape.num_cols
assert mesh_rows == 2 and mesh_cols == 4, f"Invalid mesh device shape: {mesh_rows}x{mesh_cols}"
check_mesh_device(t3k_mesh_device, model_config)

# initialize arg classes
@@ -67,3 +71,6 @@ class TTArgs:
@property
def cache_path(self):
return self.tt_model.cache_path

def prefill_forward(self, tokens: torch.Tensor, page_table, kv_cache, prompt_lens):
return super().prefill_forward(tokens, 0, page_table, kv_cache, prompt_lens)
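
For the older t3000 llama2_70b codebase, the added mesh check insists on a 2x4 T3K mesh before the model is built, and the new prefill_forward wrapper pins start_pos to 0 so the vLLM-facing signature (tokens, page_table, kv_cache, prompt_lens) maps onto the existing TtLlamaModelForGeneration.prefill_forward. Illustratively, with argument contents assumed and `generator` standing in for an instance of the wrapper class defined in this file (its name is not visible in the hunk):

# Hypothetical call from the serving side; only the wrapper's signature comes from this diff.
logits = generator.prefill_forward(tokens, page_table, kv_cache, prompt_lens)
# equivalent to: super().prefill_forward(tokens, 0, page_table, kv_cache, prompt_lens)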