diff --git a/examples/offline_inference_tt.py b/examples/offline_inference_tt.py
index e1c06b442bb35..e7841f686a92f 100644
--- a/examples/offline_inference_tt.py
+++ b/examples/offline_inference_tt.py
@@ -18,11 +18,16 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.model_executor.models.mllama import MLLAMA_IMAGE_TOKEN, MLLAMA_IMAGE_TOKEN_ID
 
-# Import and register models from tt-metal
-from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
-from models.demos.llama3.tt.generator_vllm import TtMllamaForConditionalGeneration
-ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)
-ModelRegistry.register_model("TTMllamaForConditionalGeneration", TtMllamaForConditionalGeneration)
+def register_tt_models():
+    from models.demos.llama3.tt.generator_vllm import TtLlamaForCausalLM
+    # To use old version of llama70b tt-metal model, use the import below
+    # from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
+    ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)
+
+    from models.demos.llama3.tt.generator_vllm import TtMllamaForConditionalGeneration
+    ModelRegistry.register_model("TTMllamaForConditionalGeneration", TtMllamaForConditionalGeneration)
+
+register_tt_models()  # Import and register models from tt-metal
 
 
 def get_sample_multi_modal_llama_inputs():
@@ -46,7 +51,22 @@ def get_sample_multi_modal_llama_inputs():
     return inputs
 
 
+def check_tt_model_supported(model):
+    supported_models = [
+        "meta-llama/Meta-Llama-3.1-70B",
+        "meta-llama/Meta-Llama-3.1-70B-Instruct",
+        "meta-llama/Meta-Llama-3.1-8B",
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "meta-llama/Llama-3.2-1B",
+        "meta-llama/Llama-3.2-1B-Instruct",
+        "meta-llama/Llama-3.2-3B",
+        "meta-llama/Llama-3.2-3B-Instruct",
+        "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    ]
+    assert model in supported_models, f"Invalid model: {model}"
+
 def run_inference(
+    model,
     prompts_json,
     max_tokens=128,
     max_seqs_in_batch=32,
@@ -59,15 +79,7 @@ def run_inference(
     disable_async_output_proc=False,
     multi_modal=False,
 ):
-    if multi_modal:
-        model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
-        if os.environ.get("MESH_DEVICE") is None:
-            os.environ["MESH_DEVICE"] = "N300"
-        else:
-            assert os.environ["MESH_DEVICE"] in ["N300", "T3K_LINE"], "Invalid MESH_DEVICE for multi-modal inference"
-    else:
-        model = "meta-llama/Meta-Llama-3.1-70B"
-        os.environ["MESH_DEVICE"] = "T3K_RING"
+    check_tt_model_supported(model)
 
     # LLM args
     engine_kw_args = {
@@ -216,6 +228,7 @@ async def generate_tokens_async(llm : MQLLMEngineClient, prompts, sampling_param
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-70B", help="Model name")
     parser.add_argument("--prompts_json", type=str, default="tt_metal/prompts.json", help="Path to JSON file containing prompts")
     parser.add_argument("--measure_perf", action="store_true", help="Measure performance")
     parser.add_argument("--perf_prompt_len", type=int, default=128, help="Length of dummy prompts for performance measurement")
@@ -230,6 +243,7 @@ async def generate_tokens_async(llm : MQLLMEngineClient, prompts, sampling_param
     args = parser.parse_args()
 
     run_inference(
+        args.model,
         args.prompts_json,
         measure_perf=args.measure_perf,
         perf_prompt_len=args.perf_prompt_len,
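Since registration and validation now live in importable helpers, other entry points can reuse them instead of repeating the tt-metal imports (`server_example_tt.py` below does exactly this). A minimal sketch of that reuse, run from the `examples/` directory; the prompt path and generation settings here are illustrative, not taken from the diff:

```python
# Hypothetical standalone driver, run from examples/ so offline_inference_tt is importable
# (server_example_tt.py relies on the same layout).
import os

from offline_inference_tt import (check_tt_model_supported, register_tt_models,
                                  run_inference)

register_tt_models()  # registers TTLlamaForCausalLM / TTMllamaForConditionalGeneration with vLLM

model = "meta-llama/Llama-3.2-1B"
check_tt_model_supported(model)  # raises AssertionError for anything outside the supported list

# MESH_DEVICE is no longer set implicitly inside run_inference, so the caller picks the mesh.
os.environ.setdefault("MESH_DEVICE", "N150")
run_inference(model, "tt_metal/prompts.json", max_tokens=128)
```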
diff --git a/examples/server_example_tt.py b/examples/server_example_tt.py
index 3e818d3cd30f2..952c9c8a60fdd 100644
--- a/examples/server_example_tt.py
+++ b/examples/server_example_tt.py
@@ -1,35 +1,21 @@
 import argparse
-import os
 import sys
 import runpy
 
-from vllm import ModelRegistry
+from offline_inference_tt import register_tt_models, check_tt_model_supported
 
-# Import and register models from tt-metal
-from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
-from models.demos.llama3.tt.generator_vllm import TtMllamaForConditionalGeneration
-ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)
-ModelRegistry.register_model("TTMllamaForConditionalGeneration", TtMllamaForConditionalGeneration)
+register_tt_models()  # Import and register models from tt-metal
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--multi_modal", action="store_true", help="Run multi-modal inference with Llama3.2-11b")
+    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-70B", help="Model name")
     args = parser.parse_args()
 
-    if args.multi_modal:
-        model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
-        if os.environ.get("MESH_DEVICE") is None:
-            os.environ["MESH_DEVICE"] = "N300"
-        else:
-            assert os.environ["MESH_DEVICE"] in ["N300", "T3K_LINE"], "Invalid MESH_DEVICE for multi-modal inference"
-        sys.argv.remove("--multi_modal")  # remove the flag for the API server
-    else:
-        model = "meta-llama/Meta-Llama-3.1-70B"
-        os.environ["MESH_DEVICE"] = "T3K_RING"
+    check_tt_model_supported(args.model)
 
     sys.argv.extend([
-        "--model", model,
+        "--model", args.model,
         "--block_size", "64",
         "--max_num_seqs", "32",
         "--max_model_len", "131072",
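The server example keeps only model validation for itself and forwards everything else to the wrapped vLLM API server by appending to `sys.argv` before `runpy` executes it. A stripped-down sketch of that hand-off, assuming the OpenAI-compatible entry point `vllm.entrypoints.openai.api_server` (the actual runpy target is not shown in this hunk):

```python
# Sketch of the argparse -> sys.argv -> runpy hand-off; the wrapped module is assumed
# to be vLLM's OpenAI-compatible server and may differ from what the script really runs.
import argparse
import runpy
import sys

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-70B", help="Model name")
args = parser.parse_args()

# argparse does not consume sys.argv, so flags appended here are re-parsed by the
# wrapped entry point rather than by this script.
sys.argv.extend([
    "--model", args.model,
    "--block_size", "64",
    "--max_num_seqs", "32",
])

runpy.run_module("vllm.entrypoints.openai.api_server", run_name="__main__")
```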
diff --git a/tt_metal/README.md b/tt_metal/README.md
index d39b35f2d2584..7492a21740ceb 100644
--- a/tt_metal/README.md
+++ b/tt_metal/README.md
@@ -35,7 +35,8 @@ Git-checkout the following branches in each repo separately:
 To run Meta-Llama-3.1/3.2, it is required to have access to the model on Hugging Face. To gain access:
 1. Request access on Hugging Face:
    - Llama-3.1: [https://huggingface.co/meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B)
-   - Llama-3.2: [https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
+   - Llama-3.2: [https://huggingface.co/meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
+   - Llama-3.2-Vision: [https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
 2. Once you have received access, create and copy your access token from the settings tab on Hugging Face.
 3. Run this code in python and paste your access token:
 ```python
@@ -46,40 +47,46 @@ To run Meta-Llama-3.1/3.2, it is required to have access to the model on Hugging
 
 ## Preparing the tt-metal models
 
 1. Ensure that `$PYTHONPATH` contains the path to tt-metal (should already have been done when installing tt-metal)
-2. For the desired model, follow the setup instructions (if any) for the corresponding tt-metal demo. E.g. For Llama-3.1-70B, follow the [demo instructions](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/t3000/llama3_70b) for preparing the weights and environment variables.
+2. For the desired model, follow the setup instructions (if any) for the corresponding tt-metal demo. E.g. For Llama-3.1/3.2, follow the [demo instructions](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/llama3) for preparing the weights and environment variables.
 
 ## Running the offline inference example
 
-### Llama-3.1-70B
+### Llama-3.1/3.2 Text Models (1B, 3B, 8B, 70B)
 
-To generate tokens for sample prompts:
+To generate tokens (Llama70B) for sample prompts (with batch size 32):
 ```python
-WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py
+MESH_DEVICE=T3K_LINE WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py
 ```
 
-To measure performance for a single batch (with the default prompt length of 128 tokens):
+To measure performance (Llama70B) for a single batch of 32 prompts (with the default prompt length of 128 tokens):
 ```python
-WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --measure_perf
+MESH_DEVICE=T3K_LINE WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --measure_perf
 ```
 
+**Note**: By default, the inference example will run with Llama-3.1-70B. To run with Llama-3.1-8B, Llama-3.2-1B, or Llama-3.2-3B, ensure that the appropriate environment variables are set as per the [demo instructions](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/llama3), then set `MESH_DEVICE` (valid options are `N150`, `N300`, or `T3K_LINE`) and add one of the following:
+- Llama-3.1-8B: `--model "meta-llama/Meta-Llama-3.1-8B"`
+- Llama-3.2-1B: `--model "meta-llama/Llama-3.2-1B"`
+- Llama-3.2-3B: `--model "meta-llama/Llama-3.2-3B"`
+
 ### Llama-3.2-11B-Vision-Instruct
 
 To generate tokens for sample prompts:
 ```python
-WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 8
+MESH_DEVICE=N300 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --model "meta-llama/Llama-3.2-11B-Vision-Instruct" --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 8
 ```
 
 To measure performance for a single batch (with the default prompt length of 128 tokens):
 ```python
-WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --measure_perf --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 4
+MESH_DEVICE=N300 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --model "meta-llama/Llama-3.2-11B-Vision-Instruct" --measure_perf --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 4
 ```
 
-**Note**: By default, the multi-modal inference example will run with `MESH_DEVICE=N300`. To run on T3000, set `MESH_DEVICE=T3K_LINE` and `--max_seqs_in_batch 32 --num_repeat_prompts 16`.
+**Note**: To run on T3000, set `MESH_DEVICE=T3K_LINE` and `--max_seqs_in_batch 32 --num_repeat_prompts 16`.
 
 ## Running the server example
 
 ```python
-VLLM_RPC_TIMEOUT=100000 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/server_example_tt.py
+VLLM_RPC_TIMEOUT=100000 MESH_DEVICE=T3K_LINE WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/server_example_tt.py
 ```
 
-**Note**: By default, the server will run with Llama-3.1-70B. To run with Llama-3.2-11B-Vision-Instruct instead, add `--multi_modal`.
+**Note**: By default, the server will run with Llama-3.1-70B. To run with other Llama versions, set `MESH_DEVICE` and `--model` as described in [Running the offline inference example](#running-the-offline-inference-example).
+
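For a concrete non-default setup, the smallest text model can be pinned to a single chip by combining one of the `--model` values above with `MESH_DEVICE=N150`. A minimal sketch of launching that from Python, assuming the Llama-3.2-1B weights and demo environment variables are already prepared:

```python
# Equivalent of:
#   MESH_DEVICE=N150 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml \
#   python examples/offline_inference_tt.py --model "meta-llama/Llama-3.2-1B"
import os
import subprocess

env = dict(os.environ)
env["MESH_DEVICE"] = "N150"  # N300 or T3K_LINE would also work for the small text models
env["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

subprocess.run(
    ["python", "examples/offline_inference_tt.py",
     "--model", "meta-llama/Llama-3.2-1B"],
    env=env,
    check=True,
)
```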
diff --git a/vllm/worker/tt_model_runner.py b/vllm/worker/tt_model_runner.py
index 04725abd17271..88f8103a8de3f 100644
--- a/vllm/worker/tt_model_runner.py
+++ b/vllm/worker/tt_model_runner.py
@@ -119,7 +119,6 @@ def __init__(
         self.block_size = cache_config.block_size
 
         self.trace_mode = trace_mode  # whether to use ttnn tracing for model execution
-        self.execute_trace_kwargs = None  # kw args for trace execution (populated during first decode execution)
 
         self.cached_step_outputs: List[torch.Tensor] = []  # Only used for multi-step execution
 
@@ -448,57 +447,23 @@ def _send_prev_step_async_out(self, model_input: TTModelInput, step_idx):
 
     def _execute_model_single_step(self, model_input: TTModelInput, kv_caches: List[torch.Tensor], is_decode, async_out_proc_per_trace=False, step_idx=0):
         execute_model_kwargs = {
             "tokens": model_input.input_tokens,
-            "start_pos": model_input.input_positions,
             "page_table": model_input.block_tables,
             "kv_cache": kv_caches,
-            "prompt_lens": model_input.prompt_lens,
             **(model_input.multi_modal_kwargs or {}),
         }
+        if not is_decode:
+            execute_model_kwargs["prompt_lens"] = model_input.prompt_lens
+        else:
+            execute_model_kwargs["start_pos"] = model_input.input_positions
         if model_input.cross_block_tables is not None:
             execute_model_kwargs["cross_page_table"] = model_input.cross_block_tables
 
-        if not self.model_config.is_encoder_decoder_model:  # Forward for decoder-only
-            if self.trace_mode and is_decode:  # Trace mode for decode
-                # Remove prompt_lens from execute_model_kwargs since it's not used for decode
-                execute_model_kwargs.pop("prompt_lens")
-
-                # Capture trace for the first decode execution
-                if self.execute_trace_kwargs is None:
-                    logger.info("Capturing trace for first decode execution")
-                    trace_id, tt_inp, rot_idxs_tt, cache_idxs_tt, tt_logits, tt_page_table = self.model.capture_trace(
-                        **execute_model_kwargs
-                    )
-                    self.execute_trace_kwargs = {
-                        "trace_id": trace_id,
-                        "tt_inp": tt_inp,
-                        "rot_idxs_tt": rot_idxs_tt,
-                        "cache_idxs_tt": cache_idxs_tt,
-                        "tt_logits": tt_logits,
-                        "tt_page_table": tt_page_table,
-                        "read_from_device": False,
-                    }
-
-                # Remove kv_cache from execute_model_kwargs since it doesn't need to be copied to device for trace execution
-                execute_model_kwargs.pop("kv_cache")
-
-                tt_logits = self.model.decode_forward_trace(
-                    **execute_model_kwargs, **self.execute_trace_kwargs
-                )
-                if async_out_proc_per_trace:
-                    # trigger output processor on host while device is executing next step
-                    self._send_prev_step_async_out(model_input, step_idx)
-                logits = self.model.read_forward_trace(tt_logits, model_input.unpadded_batch_size)
-            else:  # prefill or non-traced decode
-                logits = self.model.forward(**execute_model_kwargs)  # [batch_size, seq_len, vocab_size]
-        else:  # Forward for encoder-decoder (may need to be updated for future models)
-            # TODO: remove different forward calls once TT models can manage intermediate outputs internally
-            if not is_decode:
-                # Remove start_pos from execute_model_kwargs since it's not used for prefill
-                execute_model_kwargs.pop("start_pos")
-
-                logits, cross_attention_masks, full_text_row_masked_out_mask = self.model.prefill_forward(**execute_model_kwargs)
-
-                # Save encoder-decoder data for use in subsequent decode steps
+        if not is_decode:
+            outputs = self.model.prefill_forward(**execute_model_kwargs)
+
+            if self.model_config.is_encoder_decoder_model:
+                # Save encoder-decoder data for use in subsequent decode steps (may need to be updated for future models)
+                logits, cross_attention_masks, full_text_row_masked_out_mask = outputs
                 if self.cached_enc_dec_data is None:
                     self.cached_enc_dec_data = {}
                 for i, seq_id in enumerate(model_input.seq_groups):
@@ -506,23 +471,24 @@ def _execute_model_single_step(self, model_input: TTModelInput, kv_caches: List[
                                     "full_text_row_masked_out_mask": full_text_row_masked_out_mask[i]}
                     self.cached_enc_dec_data[seq_id] = enc_dec_data
             else:
-                # Remove prompt_lens from execute_model_kwargs since it's not used for decode
-                execute_model_kwargs.pop("prompt_lens")
-
+                logits = outputs  # [batch_size, seq_len, vocab_size]
+        else:
+            if self.model_config.is_encoder_decoder_model:
                 # Use encoder-decoder data from prefill step
                 cross_attention_masks = [self.cached_enc_dec_data[seq_id]["cross_attention_masks"] for seq_id in model_input.seq_groups]
                 full_text_row_masked_out_mask = [self.cached_enc_dec_data[seq_id]["full_text_row_masked_out_mask"] for seq_id in model_input.seq_groups]
-
                 enc_dec_kwargs = {"cross_attention_masks": cross_attention_masks,
                                   "full_text_row_masked_out_mask": full_text_row_masked_out_mask}
-
-                tt_logits = self.model.decode_forward(
-                    **execute_model_kwargs, **enc_dec_kwargs, enable_trace=self.trace_mode, read_from_device=False
-                )
-                if async_out_proc_per_trace:
-                    # trigger output processor on host while device is executing next step
-                    self._send_prev_step_async_out(model_input, step_idx)
-                logits = self.model.read_decode_output(tt_logits, model_input.unpadded_batch_size)
+            else:
+                enc_dec_kwargs = {}
+
+            tt_logits = self.model.decode_forward(
+                **execute_model_kwargs, **enc_dec_kwargs, enable_trace=self.trace_mode, read_from_device=False
+            )
+            if async_out_proc_per_trace:
+                # trigger output processor on host while device is executing next step
+                self._send_prev_step_async_out(model_input, step_idx)
+            logits = self.model.read_decode_output(tt_logits, model_input.unpadded_batch_size)
 
         # Note: for other devices, vLLM applies vllm.model_executor.layers.logits_processor::LogitsProcessor::_apply_logits_processors on logits, we don't use this
         # Note: for other devices, vLLM applies vllm.model_executor.layers.sampler::Sampler for sampling tokens, we don't use this
@@ -540,13 +506,4 @@ def _sample_tokens(self, logits, tt_sampling_params : TTSamplingParams):
             p=tt_sampling_params.top_p,
             k=tt_sampling_params.top_k,
             temperature=tt_sampling_params.temperature
-        )
-
-    ## Destructor (used to delete ttnn trace if using trace mode)
-
-    def __del__(self):
-        if self.trace_mode and self.execute_trace_kwargs is not None:
-            self.model.delete_trace(self.execute_trace_kwargs["trace_id"])
-
-        if hasattr(super(TTModelRunner, self), '__del__'):
-            super().__del__()
\ No newline at end of file
+        )
\ No newline at end of file
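The rewritten `_execute_model_single_step` drops the explicit trace plumbing (`capture_trace`, `decode_forward_trace`, the `__del__` cleanup) and reduces to one prefill/decode split: prefill passes `prompt_lens` to `prefill_forward`, decode passes `start_pos` to `decode_forward` and leaves tracing to the model via `enable_trace`. A simplified, stubbed mirror of that dispatch shape (not the real TT generator interface):

```python
# Stubbed illustration of the new prefill/decode dispatch; "model" is any object exposing
# prefill_forward / decode_forward / read_decode_output, not the real TT generator class.
from typing import Any, Dict


def build_kwargs(is_decode: bool, tokens, positions, prompt_lens, page_table, kv_cache) -> Dict[str, Any]:
    kwargs = {"tokens": tokens, "page_table": page_table, "kv_cache": kv_cache}
    if not is_decode:
        kwargs["prompt_lens"] = prompt_lens   # prefill-only input
    else:
        kwargs["start_pos"] = positions       # decode-only input
    return kwargs


def single_step(model, is_decode: bool, is_encoder_decoder: bool, trace_mode: bool, **inputs):
    kwargs = build_kwargs(is_decode, **inputs)
    if not is_decode:
        outputs = model.prefill_forward(**kwargs)
        # Encoder-decoder prefill also returns cross-attention state that is cached for decode.
        return outputs[0] if is_encoder_decoder else outputs
    # Decode: one shared path; tracing is delegated to the model via enable_trace.
    tt_logits = model.decode_forward(**kwargs, enable_trace=trace_mode, read_from_device=False)
    return model.read_decode_output(tt_logits)
```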
diff --git a/vllm/worker/tt_worker.py b/vllm/worker/tt_worker.py
index 9d9b9c8d5caec..3ba2f5cb29bcc 100644
--- a/vllm/worker/tt_worker.py
+++ b/vllm/worker/tt_worker.py
@@ -237,7 +237,11 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         appended to.
         """
         # TODO: Add proper implementation which runs profiling on TT devices
-        max_tokens_all_users = 131072  # Note: includes num vision tokens for multi-modal
+        if ("meta-llama/Meta-Llama-3.1-8B" in self.model_config.model and
+            len(self.device_config.device.get_devices()) == 1):  # Llama8B on N150
+            max_tokens_all_users = 65536
+        else:
+            max_tokens_all_users = 131072  # Note: includes num vision tokens for multi-modal
         num_tt_blocks = math.ceil(max_tokens_all_users / self.cache_config.block_size)
         num_tt_blocks = int(num_tt_blocks * 1.01)  # Add 1% to account for vLLM's watermark_blocks
         num_cpu_blocks = 0
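With the examples' `--block_size 64`, the token budget above translates into KV-cache pages as follows (worked numbers only; at runtime the block size comes from `cache_config`):

```python
import math

block_size = 64                  # --block_size 64 in the examples
max_tokens_all_users = 131072    # 65536 for Llama-3.1-8B on a single device (N150)

num_tt_blocks = math.ceil(max_tokens_all_users / block_size)  # 2048 pages
num_tt_blocks = int(num_tt_blocks * 1.01)                     # 2068, +1% for vLLM's watermark blocks
print(num_tt_blocks, "TT blocks,", 0, "CPU blocks")
```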
""" # TODO: Add proper implementation which runs profiling on TT devices - max_tokens_all_users = 131072 # Note: includes num vision tokens for multi-modal + if ("meta-llama/Meta-Llama-3.1-8B" in self.model_config.model and + len(self.device_config.device.get_devices()) == 1): # Llama8B on N150 + max_tokens_all_users = 65536 + else: + max_tokens_all_users = 131072 # Note: includes num vision tokens for multi-modal num_tt_blocks = math.ceil(max_tokens_all_users / self.cache_config.block_size) num_tt_blocks = int(num_tt_blocks * 1.01) # Add 1% to account for vLLM's watermark_blocks num_cpu_blocks = 0 @@ -411,7 +415,7 @@ def _get_dispatch_core_config(self, device_params): def _open_mesh_device(self): num_devices_available = len(ttnn.get_device_ids()) - mesh_grid_dict = {"N150": (1, 1), "N300": (1, 2), "T3K_LINE": (1, 8), "T3K_RING": (2, 4)} + mesh_grid_dict = {"N150": (1, 1), "N300": (1, 2), "T3K_LINE": (1, 8), "T3K_RING": (2, 4), "TG": (8, 4)} mesh_device = os.environ.get("MESH_DEVICE") if mesh_device is not None: assert mesh_device in mesh_grid_dict, f"Invalid MESH_DEVICE: {mesh_device}" @@ -425,7 +429,7 @@ def _open_mesh_device(self): assert f"Requested mesh grid shape {mesh_grid} is larger than number of available devices {num_devices_available}" if self.trace_mode: - device_params = {"trace_region_size": 14227456} # TODO: make this configurable + device_params = {"trace_region_size": 23887872} # TODO: make this configurable else: device_params = {} mesh_device = ttnn.open_mesh_device(