From 9f24a7131625735e68b105193636cb135675dfb2 Mon Sep 17 00:00:00 2001 From: Salar Hosseini <159165450+skhorasganiTT@users.noreply.github.com> Date: Tue, 24 Dec 2024 17:54:15 -0500 Subject: [PATCH] [Llama3-text vLLM integration] Modify Llama3 text model (new and old codebase) forward apis for vLLM compatibility (#16292) --- models/demos/llama3/tt/generator.py | 96 ++++++++++++++++--- models/demos/llama3/tt/generator_vllm.py | 48 ++++++++++ models/demos/llama3/tt/llama_common.py | 8 +- models/demos/llama3/tt/llama_model.py | 32 +++++-- models/demos/llama3/tt/llama_rope.py | 2 +- .../t3000/llama2_70b/tt/generator_vllm.py | 9 +- .../t3000/llama2_70b/tt/llama_generation.py | 94 +++++++++++++----- 7 files changed, 238 insertions(+), 51 deletions(-) diff --git a/models/demos/llama3/tt/generator.py b/models/demos/llama3/tt/generator.py index fc213d9ed2b..2a4372360e1 100644 --- a/models/demos/llama3/tt/generator.py +++ b/models/demos/llama3/tt/generator.py @@ -55,6 +55,7 @@ def prefill_forward_text(self, tokens: torch.Tensor, page_table=None, kv_cache=N ), "page_table must be a torch.Tensor when passing into prefill_forward" for user_id in range(batch): + logger.info(f"Prefilling User {user_id + 1}") seq_len = prompt_lens[user_id] last_token_idx = seq_len - 1 @@ -76,6 +77,8 @@ def prefill_forward_text(self, tokens: torch.Tensor, page_table=None, kv_cache=N # Since we give unpadded_seq_len, only the tile containing the last token is returned output_logits[user_id] = logits + logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...") + return output_logits def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None): @@ -162,14 +165,40 @@ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_tok return logits def decode_forward_text( + self, + tokens, + start_pos, + page_table=None, + kv_cache=None, + enable_trace=True, + read_from_device=True, + ): + decode_kwargs = { + "current_pos": start_pos, + "tokens": tokens, + "page_table": page_table, + "kv_cache": kv_cache, + } + if enable_trace: + tt_logits = self._easy_trace_text(**decode_kwargs) + else: + tt_logits = self._decode_forward_no_trace_text(**decode_kwargs) + + if read_from_device: + return self.read_decode_output(tt_logits, tokens.shape[0]) + else: + return tt_logits + + def _decode_forward_no_trace_text( self, tokens, current_pos, page_table=None, + kv_cache=None, ): """ Performs text decode step. - Returns logits + Returns tt_logits on device """ tt_tokens, tt_current_pos, tt_rot_mats, tt_page_table = self.model.prepare_inputs_decode( tokens, current_pos, page_table @@ -180,38 +209,41 @@ def decode_forward_text( tt_current_pos, rot_mats=tt_rot_mats, page_table=tt_page_table, + kv_cache=kv_cache, ) - logits = self.model.process_output_decode(tt_logits) - return logits + return tt_logits - def capture_trace_text( + def _capture_trace_text( self, tokens, current_pos, page_table=None, + kv_cache=None, ): """ Captures a trace for the decode_forward method. """ # Compile run - self.decode_forward_text(tokens, current_pos, page_table) + self._decode_forward_no_trace_text(tokens, current_pos, page_table=page_table, kv_cache=kv_cache) + logger.info("Done Compiling Model") # Get inputs ready for trace run - host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table) + host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table=page_table) device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device) trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) transformed_inputs = self.model.transform_decode_inputs_device(*device_inputs) - tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs) + tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs, kv_cache=kv_cache) ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) + logger.info("Done Capturing Decode Trace") return trace_id, tt_out_trace, *device_inputs - def decode_forward_trace_text( + def _decode_forward_trace_text( self, trace_id, device_inputs, @@ -220,6 +252,9 @@ def decode_forward_trace_text( current_pos, page_table=None, ): + """ + Executes the trace for the decode_forward method but does not read back outputs. + """ host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table) device_inputs = copy_host_to_device( @@ -229,9 +264,36 @@ def decode_forward_trace_text( ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) - logits = self.model.process_output_decode(tt_out_trace) + return tt_out_trace - return logits + def _easy_trace_text( + self, + tokens, + current_pos, + page_table=None, + kv_cache=None, + ): + """ + Tracing is easy! Just call this method and we'll handle tracing for you. + """ + if not hasattr(self, "trace_id_text"): + trace_id, tt_out_trace, *device_inputs = self._capture_trace_text( + tokens, current_pos, page_table=page_table, kv_cache=kv_cache + ) + self.trace_id_text = trace_id + self.trace_inputs_text = device_inputs + self.trace_output_text = tt_out_trace + + trace_logits_rm = self._decode_forward_trace_text( + self.trace_id_text, + self.trace_inputs_text, + self.trace_output_text, + tokens, + current_pos, + page_table=page_table, + ) + + return trace_logits_rm def _prefill_forward_single_user( self, @@ -325,7 +387,7 @@ def prefill_forward( output_full_text_row_masked_out_masks = [] for user_id in range(batch): - print(f"Prefilling User {user_id}") + logger.info(f"Prefilling User {user_id + 1}") seq_len = prompt_lens[user_id] ( xattn_caches, @@ -833,3 +895,15 @@ def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len): block_size = get_block_size(kv_cache) num_blocks = num_blocks_in_seq(prefill_len, block_size) return page_table[:, :num_blocks] + + ## Destructor (used to delete ttnn trace if exists) + + def __del__(self): + if hasattr(self, "trace_id"): + ttnn.release_trace(self.mesh_device, self.trace_id) + + if hasattr(self, "trace_id_text"): + ttnn.release_trace(self.mesh_device, self.trace_id_text) + + if hasattr(super(LlamaGenerator, self), "__del__"): + super().__del__() diff --git a/models/demos/llama3/tt/generator_vllm.py b/models/demos/llama3/tt/generator_vllm.py index 7989aba9547..7ab11947416 100644 --- a/models/demos/llama3/tt/generator_vllm.py +++ b/models/demos/llama3/tt/generator_vllm.py @@ -6,8 +6,12 @@ import torch import PIL from llama_models.llama3.api.chat_format import create_vision_mask +from llama_models.llama3.api.tokenizer import Tokenizer +import ttnn from models.demos.llama3.tt.generator import LlamaGenerator +from models.demos.llama3.tt.llama_model import TtTransformer +from models.demos.llama3.tt.model_config import LlamaOptimizations, TtModelArgs from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model from models.utility_functions import nearest_32 @@ -106,3 +110,47 @@ def prefill_forward( kv_cache=kv_cache, cross_page_table=cross_page_table, ) + + +class TtLlamaForCausalLM(LlamaGenerator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, n_layers=None): + instruct_mode = "Instruct" in hf_config._name_or_path + max_seq_len = 131072 # TODO: modify this for different models/devices + optimizations = LlamaOptimizations.performance # TODO: maybe change to accuracy + dtype = ttnn.bfloat8_b + + # Load model args, weights + model_args = TtModelArgs( + mesh_device, + instruct=instruct_mode, + max_batch_size=max_batch_size, + optimizations=optimizations, + max_seq_len=max_seq_len, + ) + if n_layers is not None: + model_args.n_layers = n_layers + state_dict = model_args.load_state_dict() + + tt_model = TtTransformer( + args=model_args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + use_paged_kv_cache=True, + ) + return cls(tt_model, model_args, mesh_device) + + @property + def cache_path(self): + return self.model_args.model_cache_path + + def prefill_forward(self, *args, **kwargs): + return super().prefill_forward_text(*args, **kwargs) + + def decode_forward(self, *args, **kwargs): + return super().decode_forward_text(*args, **kwargs) diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index 87f746c37f7..deac589b212 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -322,12 +322,12 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True def get_padded_prefill_len(seq_len): """ - If seq_len is less than 32, pad to 32 - If seq_len is more than 32, pad to whichever is smaller: a power of 2 or a multiple of 1024 + If seq_len is less than 128, pad to 128 + If seq_len is more than 128, pad to whichever is smaller: a power of 2 or a multiple of 1024 TODO: Generalize for max_mm_seq_len different from 1024 """ - if seq_len <= 32: - return 32 + if seq_len <= 128: + return 128 pow_2_pad = nearest_pow_2(seq_len) mult_1024_pad = 1024 * math.ceil(seq_len / 1024) min_extended_pad = min(pow_2_pad, mult_1024_pad) diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index c33451bc4a7..14c9bdb0cf9 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -7,6 +7,7 @@ import ttnn import torch import torch.nn as nn +from tqdm import tqdm from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.common.rmsnorm import RMSNorm import ttnn @@ -28,6 +29,7 @@ def __init__( state_dict, weight_cache_path, paged_attention_config=None, + use_paged_kv_cache=False, ): super().__init__() self.args = args @@ -69,8 +71,9 @@ def __init__( layer_num=i, transformation_mats=self.trans_mats_dict, paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, ) - for i in range(self.n_layers) + for i in tqdm(range(self.n_layers)) ] self.norm = DistributedNorm( RMSNorm( @@ -107,9 +110,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag tokens = tokens.reshape(1, 1, 1, -1) S = tokens.shape[-1] - dims = (None, -1) if self.args.is_galaxy else (None, None) + dims = (None, None) # replicate mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape) - tokens = ttnn.from_torch( tokens, device=self.mesh_device, @@ -170,7 +172,7 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): Inputs are torch tensors or python types. Outputs are ttnn tensors on host. NOTE: Tokens and current_pos are padded to batch """ - B = tokens.shape[-1] + B = tokens.shape[0] assert current_pos.shape[0] == B, "Batch size mismatch" assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size" @@ -178,11 +180,12 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape) tokens = ttnn.from_torch( - tokens, + tokens.view(-1), device=None, dtype=ttnn.uint32, mesh_mapper=mesh_mapper, ) + tokens = ttnn.unsqueeze_to_4D(tokens) rot_current_pos = torch.maximum( current_pos, torch.tensor(0, dtype=torch.int64) @@ -240,17 +243,19 @@ def process_output_prefill(self, tt_out, last_token_idx): )[0, 0, last_token_idx, :] return logits - def process_output_decode(self, tt_out): + def process_output_decode(self, tt_out, B, S=1): """ Input is ttnn device tensor of logits. Output is torch logits tensor """ if self.args.num_devices > 1: tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) - tt_out_rm = ttnn.untilize(tt_out, use_multicore=True) + tt_out = ttnn.untilize(tt_out, use_multicore=True) if self.args.num_devices > 1: - return ttnn.to_torch(ttnn.get_device_tensors(tt_out_rm)[0]).float() + tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() else: - return ttnn.to_torch(tt_out_rm).float() + tt_out = ttnn.to_torch(tt_out).float() + tt_out = tt_out[:, :, :B, :].view(B, S, -1) + return tt_out def ttnn_prefill_forward( self, @@ -280,7 +285,14 @@ def ttnn_prefill_forward( kv_cache=kv_cache, ) - def ttnn_decode_forward(self, x, current_pos, rot_mats, page_table=None, kv_cache=None): + def ttnn_decode_forward( + self, + x, + current_pos, + rot_mats, + page_table=None, + kv_cache=None, + ): """ This method will take device tensors and any other args to run forward. It returns ttnn device tensors. diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py index 29843d49683..06406a4eb2d 100644 --- a/models/demos/llama3/tt/llama_rope.py +++ b/models/demos/llama3/tt/llama_rope.py @@ -132,7 +132,7 @@ def get_rot_idxs(self, position_idxs, on_host=False): position_idxs, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.device) if self.num_devices > 1 else None, + mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None, ) else: # On device rot_idxs = ttnn.as_tensor( diff --git a/models/demos/t3000/llama2_70b/tt/generator_vllm.py b/models/demos/t3000/llama2_70b/tt/generator_vllm.py index 86dbb12f25b..3855efcb8e5 100644 --- a/models/demos/t3000/llama2_70b/tt/generator_vllm.py +++ b/models/demos/t3000/llama2_70b/tt/generator_vllm.py @@ -2,10 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Union from dataclasses import dataclass from pathlib import Path import json +import torch from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration from models.demos.t3000.llama2_70b.tt.llama_common import ( @@ -41,6 +41,10 @@ class TTArgs: model_config, ckpt_dir, _, cache_path = setup_llama_env( llama_version=llama_version, ) + + mesh_rows = t3k_mesh_device.shape.num_rows + mesh_cols = t3k_mesh_device.shape.num_cols + assert mesh_rows == 2 and mesh_cols == 4, f"Invalid mesh device shape: {mesh_rows}x{mesh_cols}" check_mesh_device(t3k_mesh_device, model_config) # initialize arg classes @@ -67,3 +71,6 @@ class TTArgs: @property def cache_path(self): return self.tt_model.cache_path + + def prefill_forward(self, tokens: torch.Tensor, page_table, kv_cache, prompt_lens): + return super().prefill_forward(tokens, 0, page_table, kv_cache, prompt_lens) diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py index f0de257f730..0aee8f7bf77 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_generation.py +++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py @@ -153,12 +153,12 @@ def decode_forward_trace( # Run TT model ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) if read_from_device: - logits = self.read_forward_trace(tt_logits, unpadded_batch=batch) + logits = self.read_decode_output(tt_logits, unpadded_batch=batch) return logits else: return tt_logits - def read_forward_trace(self, tt_logits, unpadded_batch=None): + def read_decode_output(self, tt_logits, unpadded_batch=None): updated_tt_logits = ttnn.from_device(tt_logits) logits = self._process_logits(updated_tt_logits) @@ -169,34 +169,71 @@ def read_forward_trace(self, tt_logits, unpadded_batch=None): return logits - def decode_forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None): + def decode_forward( + self, + tokens: torch.Tensor, + start_pos: int, + page_table=None, + kv_cache=None, + enable_trace=False, + read_from_device=True, + ): batch = tokens.shape[0] - # Get inputs on device - tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.tt_model.prepare_device_inputs_decode( - tokens, start_pos, mode="decode", page_table=page_table - ) - - tt_logits = self.tt_model( - tt_inp_emb, - rot_mat, - start_pos, - cache_idxs=cache_idxs_tt, - page_table=tt_page_table, - kv_cache=kv_cache, - mode="decode", - ) + if not enable_trace: + # Get inputs on device + tt_inp_emb, start_pos, rot_mat, cache_idxs_tt, tt_page_table = self.tt_model.prepare_device_inputs_decode( + tokens, start_pos, mode="decode", page_table=page_table + ) - # del tt_inp_emb - # del rot_mat + tt_logits = self.tt_model( + tt_inp_emb, + rot_mat, + start_pos, + cache_idxs=cache_idxs_tt, + page_table=tt_page_table, + kv_cache=kv_cache, + mode="decode", + ) + else: + tt_logits = self._easy_trace(tokens, start_pos, page_table, kv_cache) - logits = self._process_logits(tt_logits) + if read_from_device: + return self.read_decode_output(tt_logits, unpadded_batch=batch) + else: + return tt_logits - logits = logits.permute(2, 1, 0, 3).squeeze().unsqueeze(1) # [batch, 1, vocab_size] - logits = logits[:batch] # Remove padded users - # del tt_logits + def _easy_trace(self, tokens, start_pos, page_table=None, kv_cache=None): + """ + Tracing is easy! Just call this method and we'll handle tracing for you. + """ + if not hasattr(self, "trace_id"): + trace_id, tt_inp, rot_idxs_tt, cache_idxs_tt, tt_logits, tt_page_table = self.capture_trace( + tokens, start_pos, page_table=page_table, kv_cache=kv_cache + ) + self.trace_id = trace_id + self.trace_inputs = { + "tt_inp": tt_inp, + "rot_idxs_tt": rot_idxs_tt, + "cache_idxs_tt": cache_idxs_tt, + "tt_page_table": tt_page_table, + } + self.trace_output = tt_logits + + trace_logits_rm = self.decode_forward_trace( + tokens, + start_pos, + self.trace_id, + self.trace_inputs["tt_inp"], + self.trace_inputs["rot_idxs_tt"], + self.trace_inputs["cache_idxs_tt"], + self.trace_output, + page_table=page_table, + tt_page_table=self.trace_inputs["tt_page_table"], + read_from_device=False, + ) - return logits + return trace_logits_rm def prefill_forward_single_user( self, tokens: torch.Tensor, start_pos: int, user_id: int, last_token_idx=None, page_table=None, kv_cache=None @@ -361,6 +398,15 @@ def _process_logits(self, tt_logits): ) return logits[..., : self.params.vocab_size].float() + ## Destructor (used to delete ttnn trace if exists) + + def __del__(self): + if hasattr(self, "trace_id"): + self.delete_trace(self.trace_id) + + if hasattr(super(TtLlamaModelForGeneration, self), "__del__"): + super().__del__() + def _get_prefill_user_page_table(page_table, kv_cache, prefill_len): # Ensure page_table is not padded with extra blocks for paged_fill_cache to work properly