Add support for TT Llama3 text models (1B,3B,8B,70B-new) (#48)
skhorasganiTT authored Dec 24, 2024
1 parent e3dd6f5 commit 294dd4b
Showing 5 changed files with 83 additions and 115 deletions.
42 changes: 28 additions & 14 deletions examples/offline_inference_tt.py
@@ -18,11 +18,16 @@
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.model_executor.models.mllama import MLLAMA_IMAGE_TOKEN, MLLAMA_IMAGE_TOKEN_ID

# Import and register models from tt-metal
from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
from models.demos.llama3.tt.generator_vllm import TtMllamaForConditionalGeneration
ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)
ModelRegistry.register_model("TTMllamaForConditionalGeneration", TtMllamaForConditionalGeneration)
def register_tt_models():
from models.demos.llama3.tt.generator_vllm import TtLlamaForCausalLM
# To use old version of llama70b tt-metal model, use the import below
# from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)

from models.demos.llama3.tt.generator_vllm import TtMllamaForConditionalGeneration
ModelRegistry.register_model("TTMllamaForConditionalGeneration", TtMllamaForConditionalGeneration)

register_tt_models() # Import and register models from tt-metal


def get_sample_multi_modal_llama_inputs():
@@ -46,7 +51,22 @@ def get_sample_multi_modal_llama_inputs():
return inputs


def check_tt_model_supported(model):
supported_models = [
"meta-llama/Meta-Llama-3.1-70B",
"meta-llama/Meta-Llama-3.1-70B-Instruct",
"meta-llama/Meta-Llama-3.1-8B",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.2-1B",
"meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B",
"meta-llama/Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
]
assert model in supported_models, f"Invalid model: {model}"

def run_inference(
model,
prompts_json,
max_tokens=128,
max_seqs_in_batch=32,
@@ -59,15 +79,7 @@ def run_inference(
disable_async_output_proc=False,
multi_modal=False,
):
if multi_modal:
model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
if os.environ.get("MESH_DEVICE") is None:
os.environ["MESH_DEVICE"] = "N300"
else:
assert os.environ["MESH_DEVICE"] in ["N300", "T3K_LINE"], "Invalid MESH_DEVICE for multi-modal inference"
else:
model = "meta-llama/Meta-Llama-3.1-70B"
os.environ["MESH_DEVICE"] = "T3K_RING"
check_tt_model_supported(model)

# LLM args
engine_kw_args = {
@@ -216,6 +228,7 @@ async def generate_tokens_async(llm : MQLLMEngineClient, prompts, sampling_param

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-70B", help="Model name")
parser.add_argument("--prompts_json", type=str, default="tt_metal/prompts.json", help="Path to JSON file containing prompts")
parser.add_argument("--measure_perf", action="store_true", help="Measure performance")
parser.add_argument("--perf_prompt_len", type=int, default=128, help="Length of dummy prompts for performance measurement")
@@ -230,6 +243,7 @@ async def generate_tokens_async(llm : MQLLMEngineClient, prompts, sampling_param
args = parser.parse_args()

run_inference(
args.model,
args.prompts_json,
measure_perf=args.measure_perf,
perf_prompt_len=args.perf_prompt_len,
24 changes: 5 additions & 19 deletions examples/server_example_tt.py
@@ -1,35 +1,21 @@
import argparse
import os
import sys
import runpy

from vllm import ModelRegistry
from offline_inference_tt import register_tt_models, check_tt_model_supported

# Import and register models from tt-metal
from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
from models.demos.llama3.tt.generator_vllm import TtMllamaForConditionalGeneration
ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)
ModelRegistry.register_model("TTMllamaForConditionalGeneration", TtMllamaForConditionalGeneration)
register_tt_models() # Import and register models from tt-metal


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--multi_modal", action="store_true", help="Run multi-modal inference with Llama3.2-11b")
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-70B", help="Model name")
args = parser.parse_args()

if args.multi_modal:
model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
if os.environ.get("MESH_DEVICE") is None:
os.environ["MESH_DEVICE"] = "N300"
else:
assert os.environ["MESH_DEVICE"] in ["N300", "T3K_LINE"], "Invalid MESH_DEVICE for multi-modal inference"
sys.argv.remove("--multi_modal") # remove the flag for the API server
else:
model = "meta-llama/Meta-Llama-3.1-70B"
os.environ["MESH_DEVICE"] = "T3K_RING"
check_tt_model_supported(args.model)

sys.argv.extend([
"--model", model,
"--model", args.model,
"--block_size", "64",
"--max_num_seqs", "32",
"--max_model_len", "131072",
31 changes: 19 additions & 12 deletions tt_metal/README.md
@@ -35,7 +35,8 @@ Git-checkout the following branches in each repo separately:
To run Meta-Llama-3.1/3.2, it is required to have access to the model on Hugging Face. To gain access:
1. Request access on Hugging Face:
- Llama-3.1: [https://huggingface.co/meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B)
- Llama-3.2: [https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
- Llama-3.2: [https://huggingface.co/meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
- Llama-3.2-Vision: [https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
2. Once you have received access, create and copy your access token from the settings tab on Hugging Face.
3. Run this code in python and paste your access token:
```python
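# A minimal sketch of the collapsed snippet, assuming the standard huggingface_hub API:
from huggingface_hub import login

login()  # paste your access token when prompted
```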
@@ -46,40 +47,46 @@ To run Meta-Llama-3.1/3.2, it is required to have access to the model on Hugging
## Preparing the tt-metal models

1. Ensure that `$PYTHONPATH` contains the path to tt-metal (should already have been done when installing tt-metal); a quick check is shown after this list
2. For the desired model, follow the setup instructions (if any) for the corresponding tt-metal demo. E.g. For Llama-3.1-70B, follow the [demo instructions](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/t3000/llama3_70b) for preparing the weights and environment variables.
2. For the desired model, follow the setup instructions (if any) for the corresponding tt-metal demo. E.g. For Llama-3.1/3.2, follow the [demo instructions](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/llama3) for preparing the weights and environment variables.
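
A quick way to verify step 1 is to try importing one of the tt-metal model modules used by the examples (a hedged sketch, not part of the repo):

```python
# Check that tt-metal is importable; module path taken from examples/offline_inference_tt.py.
try:
    import models.demos.llama3.tt.generator_vllm  # noqa: F401
except ImportError as exc:
    raise SystemExit(f"tt-metal models not found on PYTHONPATH: {exc}")
```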

## Running the offline inference example

### Llama-3.1-70B
### Llama-3.1/3.2 Text Models (1B, 3B, 8B, 70B)

To generate tokens for sample prompts:
To generate tokens (Llama70B) for sample prompts (with batch size 32):
```python
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py
MESH_DEVICE=T3K_LINE WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py
```

To measure performance for a single batch (with the default prompt length of 128 tokens):
To measure performance (Llama70B) for a single batch of 32 prompts (with the default prompt length of 128 tokens):
```python
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --measure_perf
MESH_DEVICE=T3K_LINE WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --measure_perf
```

**Note**: By default, the inference example will run with Llama-3.1-70B. To run with Llama-3.1-8B, Llama-3.2-1B, or Llama-3.2-3B, ensure that the appropriate environment variables are set as per the [demo instructions](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/llama3), then set `MESH_DEVICE=<device>` (valid options for `<device>` are `N150`, `N300`, or `T3K_LINE`) and one of the following (see the sketch after this list):
- Llama-3.1-8B: `--model "meta-llama/Meta-Llama-3.1-8B"`
- Llama-3.2-1B: `--model "meta-llama/Llama-3.2-1B"`
- Llama-3.2-3B: `--model "meta-llama/Llama-3.2-3B"`
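
For example, the 1B model on a single N150 can also be driven programmatically through `run_inference` from `examples/offline_inference_tt.py` (a hedged sketch; the environment values are assumptions, not defaults of the script):

```python
import os

# Assumed device/arch settings for Llama-3.2-1B on one N150.
os.environ.setdefault("MESH_DEVICE", "N150")
os.environ.setdefault("WH_ARCH_YAML", "wormhole_b0_80_arch_eth_dispatch.yaml")

from offline_inference_tt import run_inference  # importing also registers the TT models

run_inference("meta-llama/Llama-3.2-1B", "tt_metal/prompts.json", max_tokens=128)
```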

### Llama-3.2-11B-Vision-Instruct

To generate tokens for sample prompts:
```python
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 8
MESH_DEVICE=N300 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --model "meta-llama/Llama-3.2-11B-Vision-Instruct" --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 8
```

To measure performance for a single batch (with the default prompt length of 128 tokens):
```python
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --measure_perf --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 4
MESH_DEVICE=N300 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --model "meta-llama/Llama-3.2-11B-Vision-Instruct" --measure_perf --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 4
```

**Note**: By default, the multi-modal inference example will run with `MESH_DEVICE=N300`. To run on T3000, set `MESH_DEVICE=T3K_LINE` and `--max_seqs_in_batch 32 --num_repeat_prompts 16`.
**Note**: To run on T3000, set `MESH_DEVICE=T3K_LINE` and `--max_seqs_in_batch 32 --num_repeat_prompts 16`.

## Running the server example

```python
VLLM_RPC_TIMEOUT=100000 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/server_example_tt.py
VLLM_RPC_TIMEOUT=100000 MESH_DEVICE=T3K_LINE WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/server_example_tt.py
```

**Note**: By default, the server will run with Llama-3.1-70B. To run with Llama-3.2-11B-Vision-Instruct instead, add `--multi_modal`.
**Note**: By default, the server will run with Llama-3.1-70B. To run with other Llama versions, set `MESH_DEVICE` and `--model` as described in [Running the offline inference example](#running-the-offline-inference-example).

91 changes: 24 additions & 67 deletions vllm/worker/tt_model_runner.py
@@ -119,7 +119,6 @@ def __init__(
self.block_size = cache_config.block_size

self.trace_mode = trace_mode # whether to use ttnn tracing for model execution
self.execute_trace_kwargs = None # kw args for trace execution (populated during first decode execution)

self.cached_step_outputs: List[torch.Tensor] = [] # Only used for multi-step execution

@@ -448,81 +447,48 @@ def _send_prev_step_async_out(self, model_input: TTModelInput, step_idx):
def _execute_model_single_step(self, model_input: TTModelInput, kv_caches: List[torch.Tensor], is_decode, async_out_proc_per_trace=False, step_idx=0):
execute_model_kwargs = {
"tokens": model_input.input_tokens,
"start_pos": model_input.input_positions,
"page_table": model_input.block_tables,
"kv_cache": kv_caches,
"prompt_lens": model_input.prompt_lens,
**(model_input.multi_modal_kwargs or {}),
}
if not is_decode:
execute_model_kwargs["prompt_lens"] = model_input.prompt_lens
else:
execute_model_kwargs["start_pos"] = model_input.input_positions
if model_input.cross_block_tables is not None:
execute_model_kwargs["cross_page_table"] = model_input.cross_block_tables

if not self.model_config.is_encoder_decoder_model: # Forward for decoder-only
if self.trace_mode and is_decode: # Trace mode for decode
# Remove prompt_lens from execute_model_kwargs since it's not used for decode
execute_model_kwargs.pop("prompt_lens")

# Capture trace for the first decode execution
if self.execute_trace_kwargs is None:
logger.info("Capturing trace for first decode execution")
trace_id, tt_inp, rot_idxs_tt, cache_idxs_tt, tt_logits, tt_page_table = self.model.capture_trace(
**execute_model_kwargs
)
self.execute_trace_kwargs = {
"trace_id": trace_id,
"tt_inp": tt_inp,
"rot_idxs_tt": rot_idxs_tt,
"cache_idxs_tt": cache_idxs_tt,
"tt_logits": tt_logits,
"tt_page_table": tt_page_table,
"read_from_device": False,
}

# Remove kv_cache from execute_model_kwargs since it doesn't need to be copied to device for trace execution
execute_model_kwargs.pop("kv_cache")

tt_logits = self.model.decode_forward_trace(
**execute_model_kwargs, **self.execute_trace_kwargs
)
if async_out_proc_per_trace:
# trigger output processor on host while device is executing next step
self._send_prev_step_async_out(model_input, step_idx)
logits = self.model.read_forward_trace(tt_logits, model_input.unpadded_batch_size)
else: # prefill or non-traced decode
logits = self.model.forward(**execute_model_kwargs) # [batch_size, seq_len, vocab_size]
else: # Forward for encoder-decoder (may need to be updated for future models)
# TODO: remove different forward calls once TT models can manage intermediate outputs internally
if not is_decode:
# Remove start_pos from execute_model_kwargs since it's not used for prefill
execute_model_kwargs.pop("start_pos")

logits, cross_attention_masks, full_text_row_masked_out_mask = self.model.prefill_forward(**execute_model_kwargs)

# Save encoder-decoder data for use in subsequent decode steps
if not is_decode:
outputs = self.model.prefill_forward(**execute_model_kwargs)

if self.model_config.is_encoder_decoder_model:
# Save encoder-decoder data for use in subsequent decode steps (may need to be updated for future models)
logits, cross_attention_masks, full_text_row_masked_out_mask = outputs
if self.cached_enc_dec_data is None:
self.cached_enc_dec_data = {}
for i, seq_id in enumerate(model_input.seq_groups):
enc_dec_data = {"cross_attention_masks": cross_attention_masks[i],
"full_text_row_masked_out_mask": full_text_row_masked_out_mask[i]}
self.cached_enc_dec_data[seq_id] = enc_dec_data
else:
# Remove prompt_lens from execute_model_kwargs since it's not used for decode
execute_model_kwargs.pop("prompt_lens")
logits = outputs # [batch_size, seq_len, vocab_size]
else:
if self.model_config.is_encoder_decoder_model:
# Use encoder-decoder data from prefill step
cross_attention_masks = [self.cached_enc_dec_data[seq_id]["cross_attention_masks"] for seq_id in model_input.seq_groups]
full_text_row_masked_out_mask = [self.cached_enc_dec_data[seq_id]["full_text_row_masked_out_mask"] for seq_id in model_input.seq_groups]

enc_dec_kwargs = {"cross_attention_masks": cross_attention_masks,
"full_text_row_masked_out_mask": full_text_row_masked_out_mask}

tt_logits = self.model.decode_forward(
**execute_model_kwargs, **enc_dec_kwargs, enable_trace=self.trace_mode, read_from_device=False
)
if async_out_proc_per_trace:
# trigger output processor on host while device is executing next step
self._send_prev_step_async_out(model_input, step_idx)
logits = self.model.read_decode_output(tt_logits, model_input.unpadded_batch_size)
else:
enc_dec_kwargs = {}

tt_logits = self.model.decode_forward(
**execute_model_kwargs, **enc_dec_kwargs, enable_trace=self.trace_mode, read_from_device=False
)
if async_out_proc_per_trace:
# trigger output processor on host while device is executing next step
self._send_prev_step_async_out(model_input, step_idx)
logits = self.model.read_decode_output(tt_logits, model_input.unpadded_batch_size)

# Note: for other devices, vLLM applies vllm.model_executor.layers.logits_processor::LogitsProcessor::_apply_logits_processors on logits, we don't use this
# Note: for other devices, vLLM applies vllm.model_executor.layers.sampler::Sampler for sampling tokens, we don't use this
@@ -540,13 +506,4 @@ def _sample_tokens(self, logits, tt_sampling_params : TTSamplingParams):
p=tt_sampling_params.top_p,
k=tt_sampling_params.top_k,
temperature=tt_sampling_params.temperature
)

## Destructor (used to delete ttnn trace if using trace mode)

def __del__(self):
if self.trace_mode and self.execute_trace_kwargs is not None:
self.model.delete_trace(self.execute_trace_kwargs["trace_id"])

if hasattr(super(TTModelRunner, self), '__del__'):
super().__del__()
)
10 changes: 7 additions & 3 deletions vllm/worker/tt_worker.py
@@ -237,7 +237,11 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
appended to.
"""
# TODO: Add proper implementation which runs profiling on TT devices
max_tokens_all_users = 131072 # Note: includes num vision tokens for multi-modal
if ("meta-llama/Meta-Llama-3.1-8B" in self.model_config.model and
len(self.device_config.device.get_devices()) == 1): # Llama8B on N150
max_tokens_all_users = 65536
else:
max_tokens_all_users = 131072 # Note: includes num vision tokens for multi-modal
num_tt_blocks = math.ceil(max_tokens_all_users / self.cache_config.block_size)
num_tt_blocks = int(num_tt_blocks * 1.01) # Add 1% to account for vLLM's watermark_blocks
num_cpu_blocks = 0
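
As a worked example of the block budget above (illustrative only; `block_size=64` is taken from the server example's `--block_size` argument):

```python
import math

block_size = 64
num_tt_blocks = math.ceil(131072 / block_size)   # 2048 blocks for the default token budget
num_tt_blocks = int(num_tt_blocks * 1.01)        # 2068 after the ~1% watermark headroom

# Llama-3.1-8B on a single N150 uses the smaller 65536-token budget:
assert int(math.ceil(65536 / block_size) * 1.01) == 1034
```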
@@ -411,7 +415,7 @@ def _get_dispatch_core_config(self, device_params):
def _open_mesh_device(self):
num_devices_available = len(ttnn.get_device_ids())

mesh_grid_dict = {"N150": (1, 1), "N300": (1, 2), "T3K_LINE": (1, 8), "T3K_RING": (2, 4)}
mesh_grid_dict = {"N150": (1, 1), "N300": (1, 2), "T3K_LINE": (1, 8), "T3K_RING": (2, 4), "TG": (8, 4)}
mesh_device = os.environ.get("MESH_DEVICE")
if mesh_device is not None:
assert mesh_device in mesh_grid_dict, f"Invalid MESH_DEVICE: {mesh_device}"
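
Each mesh name maps to a `(rows, cols)` grid, so the new `TG` entry implies an 8 x 4 mesh of 32 devices (a minimal sketch, not repo code):

```python
mesh_grid_dict = {"N150": (1, 1), "N300": (1, 2), "T3K_LINE": (1, 8), "T3K_RING": (2, 4), "TG": (8, 4)}

rows, cols = mesh_grid_dict["TG"]
print(rows * cols)  # 32 devices required for a TG mesh
```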
@@ -425,7 +429,7 @@ def _open_mesh_device(self):
assert f"Requested mesh grid shape {mesh_grid} is larger than number of available devices {num_devices_available}"

if self.trace_mode:
device_params = {"trace_region_size": 14227456} # TODO: make this configurable
device_params = {"trace_region_size": 23887872} # TODO: make this configurable
else:
device_params = {}
mesh_device = ttnn.open_mesh_device(
