[Llama3-text vLLM integration] Modify Llama3 text model (new and old codebase) forward apis for vLLM compatibility #16292

Merged 7 commits on Dec 24, 2024
96 changes: 85 additions & 11 deletions models/demos/llama3/tt/generator.py
@@ -55,6 +55,7 @@ def prefill_forward_text(self, tokens: torch.Tensor, page_table=None, kv_cache=N
), "page_table must be a torch.Tensor when passing into prefill_forward"

for user_id in range(batch):
logger.info(f"Prefilling User {user_id + 1}")
seq_len = prompt_lens[user_id]
last_token_idx = seq_len - 1

@@ -76,6 +77,8 @@ def prefill_forward_text(self, tokens: torch.Tensor, page_table=None, kv_cache=N
# Since we give unpadded_seq_len, only the tile containing the last token is returned
output_logits[user_id] = logits

logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...")

return output_logits

def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None):
@@ -162,14 +165,40 @@ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_tok
return logits

def decode_forward_text(
self,
tokens,
start_pos,
page_table=None,
kv_cache=None,
enable_trace=True,
read_from_device=True,
):
decode_kwargs = {
"current_pos": start_pos,
"tokens": tokens,
"page_table": page_table,
"kv_cache": kv_cache,
}
if enable_trace:
tt_logits = self._easy_trace_text(**decode_kwargs)
else:
tt_logits = self._decode_forward_no_trace_text(**decode_kwargs)

if read_from_device:
return self.read_decode_output(tt_logits, tokens.shape[0])
else:
return tt_logits

def _decode_forward_no_trace_text(
self,
tokens,
current_pos,
page_table=None,
kv_cache=None,
):
"""
Performs text decode step.
Returns logits
Returns tt_logits on device
"""
tt_tokens, tt_current_pos, tt_rot_mats, tt_page_table = self.model.prepare_inputs_decode(
tokens, current_pos, page_table
@@ -180,38 +209,41 @@ def decode_forward_text(
tt_current_pos,
rot_mats=tt_rot_mats,
page_table=tt_page_table,
kv_cache=kv_cache,
)

logits = self.model.process_output_decode(tt_logits)
return logits
return tt_logits

def capture_trace_text(
def _capture_trace_text(
self,
tokens,
current_pos,
page_table=None,
kv_cache=None,
):
"""
Captures a trace for the decode_forward method.
"""

# Compile run
self.decode_forward_text(tokens, current_pos, page_table)
self._decode_forward_no_trace_text(tokens, current_pos, page_table=page_table, kv_cache=kv_cache)
logger.info("Done Compiling Model")

# Get inputs ready for trace run
host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table)
host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table=page_table)

device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device)

trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
transformed_inputs = self.model.transform_decode_inputs_device(*device_inputs)
tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs)
tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs, kv_cache=kv_cache)

ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
logger.info("Done Capturing Decode Trace")

return trace_id, tt_out_trace, *device_inputs

def decode_forward_trace_text(
def _decode_forward_trace_text(
self,
trace_id,
device_inputs,
@@ -220,6 +252,9 @@ def decode_forward_trace_text(
current_pos,
page_table=None,
):
"""
Executes the trace for the decode_forward method but does not read back outputs.
"""
host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table)

device_inputs = copy_host_to_device(
@@ -229,9 +264,36 @@ def decode_forward_trace_text(

ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False)

logits = self.model.process_output_decode(tt_out_trace)
return tt_out_trace

return logits
def _easy_trace_text(
self,
tokens,
current_pos,
page_table=None,
kv_cache=None,
):
"""
Convenience wrapper for traced decode: captures the trace on the first call and replays it on subsequent calls.
"""
if not hasattr(self, "trace_id_text"):
trace_id, tt_out_trace, *device_inputs = self._capture_trace_text(
tokens, current_pos, page_table=page_table, kv_cache=kv_cache
)
self.trace_id_text = trace_id
self.trace_inputs_text = device_inputs
self.trace_output_text = tt_out_trace

trace_logits_rm = self._decode_forward_trace_text(
self.trace_id_text,
self.trace_inputs_text,
self.trace_output_text,
tokens,
current_pos,
page_table=page_table,
)

return trace_logits_rm

def _prefill_forward_single_user(
self,
@@ -325,7 +387,7 @@ def prefill_forward(
output_full_text_row_masked_out_masks = []

for user_id in range(batch):
print(f"Prefilling User {user_id}")
logger.info(f"Prefilling User {user_id + 1}")
seq_len = prompt_lens[user_id]
(
xattn_caches,
@@ -833,3 +895,15 @@ def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len):
block_size = get_block_size(kv_cache)
num_blocks = num_blocks_in_seq(prefill_len, block_size)
return page_table[:, :num_blocks]

## Destructor (releases the captured ttnn traces, if any)

def __del__(self):
if hasattr(self, "trace_id"):
ttnn.release_trace(self.mesh_device, self.trace_id)

if hasattr(self, "trace_id_text"):
ttnn.release_trace(self.mesh_device, self.trace_id_text)

if hasattr(super(LlamaGenerator, self), "__del__"):
super().__del__()
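
The refactored decode path makes tracing and host read-back opt-in. A minimal usage sketch (the `generator` instance and the `tokens`, `start_pos`, `page_table`, and `kv_cache` objects are placeholders, not code from this PR):

```python
# Traced decode that keeps the logits on device, followed by an explicit read-back.
tt_logits = generator.decode_forward_text(
    tokens,                  # token ids for the current decode step
    start_pos,               # per-user current positions
    page_table=page_table,   # optional paged-attention page table
    kv_cache=kv_cache,       # optional paged KV cache
    enable_trace=True,       # first call captures a ttnn trace; later calls replay it
    read_from_device=False,  # defer the device-to-host copy
)
logits = generator.read_decode_output(tt_logits, tokens.shape[0])
```

The trace is captured once by `_easy_trace_text` on the first traced call and released in `__del__` via `ttnn.release_trace`.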
48 changes: 48 additions & 0 deletions models/demos/llama3/tt/generator_vllm.py
@@ -6,8 +6,12 @@
import torch
import PIL
from llama_models.llama3.api.chat_format import create_vision_mask
from llama_models.llama3.api.tokenizer import Tokenizer
import ttnn

from models.demos.llama3.tt.generator import LlamaGenerator
from models.demos.llama3.tt.llama_model import TtTransformer
from models.demos.llama3.tt.model_config import LlamaOptimizations, TtModelArgs
from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
from models.utility_functions import nearest_32

@@ -106,3 +110,47 @@ def prefill_forward(
kv_cache=kv_cache,
cross_page_table=cross_page_table,
)


class TtLlamaForCausalLM(LlamaGenerator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, n_layers=None):
instruct_mode = "Instruct" in hf_config._name_or_path
max_seq_len = 131072 # TODO: modify this for different models/devices
optimizations = LlamaOptimizations.performance # TODO: maybe change to accuracy
dtype = ttnn.bfloat8_b

# Load model args, weights
model_args = TtModelArgs(
mesh_device,
instruct=instruct_mode,
max_batch_size=max_batch_size,
optimizations=optimizations,
max_seq_len=max_seq_len,
)
if n_layers is not None:
model_args.n_layers = n_layers
state_dict = model_args.load_state_dict()

tt_model = TtTransformer(
args=model_args,
mesh_device=mesh_device,
dtype=dtype,
state_dict=state_dict,
weight_cache_path=model_args.weight_cache_path(dtype),
use_paged_kv_cache=True,
)
return cls(tt_model, model_args, mesh_device)

@property
def cache_path(self):
return self.model_args.model_cache_path

def prefill_forward(self, *args, **kwargs):
return super().prefill_forward_text(*args, **kwargs)

def decode_forward(self, *args, **kwargs):
return super().decode_forward_text(*args, **kwargs)
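
For orientation, a hedged sketch of how the new vLLM-facing wrapper is driven (`hf_config` and `mesh_device` stand in for the objects vLLM supplies, and the batch size is an example value; the vLLM-side plumbing is outside this diff):

```python
# Build the TT text model from a HuggingFace config and a ttnn mesh device.
model = TtLlamaForCausalLM.initialize_vllm_model(
    hf_config,          # "Instruct" in hf_config._name_or_path enables instruct mode
    mesh_device,
    max_batch_size=32,  # example value
    n_layers=None,      # optionally truncate layers for bring-up runs
)
print(model.cache_path)  # weight-cache directory resolved from the model args

# prefill_forward and decode_forward simply delegate to the text-only
# prefill_forward_text / decode_forward_text methods on LlamaGenerator.
```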
8 changes: 4 additions & 4 deletions models/demos/llama3/tt/llama_common.py
@@ -322,12 +322,12 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True

def get_padded_prefill_len(seq_len):
"""
If seq_len is less than 32, pad to 32
If seq_len is more than 32, pad to whichever is smaller: a power of 2 or a multiple of 1024
If seq_len is less than 128, pad to 128
If seq_len is more than 128, pad to whichever is smaller: a power of 2 or a multiple of 1024
TODO: Generalize for max_mm_seq_len different from 1024
"""
if seq_len <= 32:
return 32
if seq_len <= 128:
return 128
pow_2_pad = nearest_pow_2(seq_len)
mult_1024_pad = 1024 * math.ceil(seq_len / 1024)
min_extended_pad = min(pow_2_pad, mult_1024_pad)
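
To make the revised rule concrete, here is a worked example that mirrors the logic visible above (a standalone sketch: it reimplements `nearest_pow_2` with its assumed round-up behavior and assumes the elided tail of `get_padded_prefill_len` returns the smaller of the two candidates):

```python
import math

def nearest_pow_2(n):
    # assumed behavior of the helper in llama_common: round up to the next power of two
    return 2 ** math.ceil(math.log2(n))

def padded_prefill_len(seq_len):
    if seq_len <= 128:
        return 128
    pow_2_pad = nearest_pow_2(seq_len)
    mult_1024_pad = 1024 * math.ceil(seq_len / 1024)
    return min(pow_2_pad, mult_1024_pad)

print(padded_prefill_len(100))   # 128  (below the new minimum)
print(padded_prefill_len(300))   # 512  (min of 512 and 1024)
print(padded_prefill_len(3000))  # 3072 (min of 4096 and 3072)
```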
32 changes: 22 additions & 10 deletions models/demos/llama3/tt/llama_model.py
@@ -7,6 +7,7 @@
import ttnn
import torch
import torch.nn as nn
from tqdm import tqdm
from models.demos.llama3.tt.llama_decoder import TtTransformerBlock
from models.common.rmsnorm import RMSNorm
import ttnn
@@ -28,6 +29,7 @@ def __init__(
state_dict,
weight_cache_path,
paged_attention_config=None,
use_paged_kv_cache=False,
):
super().__init__()
self.args = args
@@ -69,8 +71,9 @@ def __init__(
layer_num=i,
transformation_mats=self.trans_mats_dict,
paged_attention_config=paged_attention_config,
use_paged_kv_cache=use_paged_kv_cache,
)
for i in range(self.n_layers)
for i in tqdm(range(self.n_layers))
]
self.norm = DistributedNorm(
RMSNorm(
@@ -107,9 +110,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag

tokens = tokens.reshape(1, 1, 1, -1)
S = tokens.shape[-1]
dims = (None, -1) if self.args.is_galaxy else (None, None)
dims = (None, None) # replicate
mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)

tokens = ttnn.from_torch(
tokens,
device=self.mesh_device,
@@ -170,19 +172,20 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
Inputs are torch tensors or python types. Outputs are ttnn tensors on host.
NOTE: Tokens and current_pos are padded to batch
"""
B = tokens.shape[-1]
B = tokens.shape[0]
assert current_pos.shape[0] == B, "Batch size mismatch"
assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size"

dims = (None, -1) if self.args.is_galaxy else (None, None)
mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)

tokens = ttnn.from_torch(
tokens,
tokens.view(-1),
device=None,
dtype=ttnn.uint32,
mesh_mapper=mesh_mapper,
)
tokens = ttnn.unsqueeze_to_4D(tokens)

rot_current_pos = torch.maximum(
current_pos, torch.tensor(0, dtype=torch.int64)
@@ -240,17 +243,19 @@ def process_output_prefill(self, tt_out, last_token_idx):
)[0, 0, last_token_idx, :]
return logits

def process_output_decode(self, tt_out):
def process_output_decode(self, tt_out, B, S=1):
"""
Input is ttnn device tensor of logits. Output is torch logits tensor
"""
if self.args.num_devices > 1:
tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
tt_out = ttnn.untilize(tt_out, use_multicore=True)
if self.args.num_devices > 1:
return ttnn.to_torch(ttnn.get_device_tensors(tt_out_rm)[0]).float()
tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float()
else:
return ttnn.to_torch(tt_out_rm).float()
tt_out = ttnn.to_torch(tt_out).float()
tt_out = tt_out[:, :, :B, :].view(B, S, -1)
return tt_out

def ttnn_prefill_forward(
self,
@@ -280,7 +285,14 @@ def ttnn_prefill_forward(
kv_cache=kv_cache,
)

def ttnn_decode_forward(self, x, current_pos, rot_mats, page_table=None, kv_cache=None):
def ttnn_decode_forward(
self,
x,
current_pos,
rot_mats,
page_table=None,
kv_cache=None,
):
"""
This method will take device tensors and any other args to run forward.
It returns ttnn device tensors.
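
The reworked `process_output_decode` now returns host logits shaped by the real batch size rather than a raw device-padded tensor. A small shape sketch (`model` is a constructed `TtTransformer`; the numbers are illustrative):

```python
# tt_out holds decode logits for the padded device batch; B is the number of real users.
# With the new signature the host tensor comes back as [B, S, vocab_size].
logits = model.process_output_decode(tt_out, B=8)  # S defaults to 1
assert logits.shape[0] == 8 and logits.shape[1] == 1
```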
2 changes: 1 addition & 1 deletion models/demos/llama3/tt/llama_rope.py
@@ -132,7 +132,7 @@ def get_rot_idxs(self, position_idxs, on_host=False):
position_idxs,
dtype=ttnn.uint32,
layout=ttnn.ROW_MAJOR_LAYOUT,
mesh_mapper=ReplicateTensorToMesh(self.device) if self.num_devices > 1 else None,
mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None,
)
else: # On device
rot_idxs = ttnn.as_tensor(
9 changes: 8 additions & 1 deletion models/demos/t3000/llama2_70b/tt/generator_vllm.py
@@ -2,10 +2,10 @@

# SPDX-License-Identifier: Apache-2.0

from typing import Union
from dataclasses import dataclass
from pathlib import Path
import json
import torch

from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
from models.demos.t3000.llama2_70b.tt.llama_common import (
@@ -41,6 +41,10 @@ class TTArgs:
model_config, ckpt_dir, _, cache_path = setup_llama_env(
llama_version=llama_version,
)

mesh_rows = t3k_mesh_device.shape.num_rows
mesh_cols = t3k_mesh_device.shape.num_cols
assert mesh_rows == 2 and mesh_cols == 4, f"Invalid mesh device shape: {mesh_rows}x{mesh_cols}"
check_mesh_device(t3k_mesh_device, model_config)

# initialize arg classes
@@ -67,3 +71,6 @@ class TTArgs:
@property
def cache_path(self):
return self.tt_model.cache_path

def prefill_forward(self, tokens: torch.Tensor, page_table, kv_cache, prompt_lens):
return super().prefill_forward(tokens, 0, page_table, kv_cache, prompt_lens)