Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for TT Llama3 text models (1B,3B,8B,70B-new) #48

Merged
merged 10 commits into from
Dec 24, 2024
42 changes: 28 additions & 14 deletions examples/offline_inference_tt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.model_executor.models.mllama import MLLAMA_IMAGE_TOKEN, MLLAMA_IMAGE_TOKEN_ID

# Import and register models from tt-metal
from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
from models.demos.llama3.tt.generator_vllm import TtMllamaForConditionalGeneration
ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)
ModelRegistry.register_model("TTMllamaForConditionalGeneration", TtMllamaForConditionalGeneration)
def register_tt_models():
    """Register the tt-metal Llama model classes with vLLM's ``ModelRegistry``.

    The tt-metal imports are performed inside this function, so they are
    resolved when the function is called rather than where it is defined.
    Each class is imported and then registered under its architecture name,
    in order: the text model first, then the multi-modal model.
    """
    from models.demos.llama3.tt.generator_vllm import TtLlamaForCausalLM
    # To use old version of llama70b tt-metal model, use the import below
    # from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
    text_entry = ("TTLlamaForCausalLM", TtLlamaForCausalLM)
    ModelRegistry.register_model(*text_entry)

    from models.demos.llama3.tt.generator_vllm import TtMllamaForConditionalGeneration
    vision_entry = ("TTMllamaForConditionalGeneration", TtMllamaForConditionalGeneration)
    ModelRegistry.register_model(*vision_entry)

register_tt_models() # Import and register models from tt-metal


def get_sample_multi_modal_llama_inputs():
Expand All @@ -46,7 +51,22 @@ def get_sample_multi_modal_llama_inputs():
return inputs


def check_tt_model_supported(model):
    """Validate that ``model`` is one of the model names supported on TT devices.

    Args:
        model: Hugging Face model identifier, e.g.
            ``"meta-llama/Meta-Llama-3.1-70B"``.

    Raises:
        AssertionError: If ``model`` is not in the supported list. The error
            is raised explicitly (instead of via an ``assert`` statement) so
            the check is still enforced when Python runs with ``-O``.
    """
    supported_models = (
        "meta-llama/Meta-Llama-3.1-70B",
        "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "meta-llama/Meta-Llama-3.1-8B",
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "meta-llama/Llama-3.2-1B",
        "meta-llama/Llama-3.2-1B-Instruct",
        "meta-llama/Llama-3.2-3B",
        "meta-llama/Llama-3.2-3B-Instruct",
        "meta-llama/Llama-3.2-11B-Vision-Instruct",
    )
    if model not in supported_models:
        # Same exception type and message as the previous assert-based check,
        # so existing callers that catch AssertionError keep working.
        raise AssertionError(f"Invalid model: {model}")

def run_inference(
model,
prompts_json,
max_tokens=128,
max_seqs_in_batch=32,
Expand All @@ -59,15 +79,7 @@ def run_inference(
disable_async_output_proc=False,
multi_modal=False,
):
if multi_modal:
model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
if os.environ.get("MESH_DEVICE") is None:
os.environ["MESH_DEVICE"] = "N300"
else:
assert os.environ["MESH_DEVICE"] in ["N300", "T3K_LINE"], "Invalid MESH_DEVICE for multi-modal inference"
else:
model = "meta-llama/Meta-Llama-3.1-70B"
os.environ["MESH_DEVICE"] = "T3K_RING"
check_tt_model_supported(model)

# LLM args
engine_kw_args = {
Expand Down Expand Up @@ -216,6 +228,7 @@ async def generate_tokens_async(llm : MQLLMEngineClient, prompts, sampling_param

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-70B", help="Model name")
parser.add_argument("--prompts_json", type=str, default="tt_metal/prompts.json", help="Path to JSON file containing prompts")
parser.add_argument("--measure_perf", action="store_true", help="Measure performance")
parser.add_argument("--perf_prompt_len", type=int, default=128, help="Length of dummy prompts for performance measurement")
Expand All @@ -230,6 +243,7 @@ async def generate_tokens_async(llm : MQLLMEngineClient, prompts, sampling_param
args = parser.parse_args()

run_inference(
args.model,
args.prompts_json,
measure_perf=args.measure_perf,
perf_prompt_len=args.perf_prompt_len,
Expand Down
24 changes: 5 additions & 19 deletions examples/server_example_tt.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,21 @@
import argparse
import os
import sys
import runpy

from vllm import ModelRegistry
from offline_inference_tt import register_tt_models, check_tt_model_supported

# Import and register models from tt-metal
from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
from models.demos.llama3.tt.generator_vllm import TtMllamaForConditionalGeneration
ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)
ModelRegistry.register_model("TTMllamaForConditionalGeneration", TtMllamaForConditionalGeneration)
register_tt_models() # Import and register models from tt-metal


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--multi_modal", action="store_true", help="Run multi-modal inference with Llama3.2-11b")
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-70B", help="Model name")
args = parser.parse_args()

if args.multi_modal:
model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
if os.environ.get("MESH_DEVICE") is None:
os.environ["MESH_DEVICE"] = "N300"
else:
assert os.environ["MESH_DEVICE"] in ["N300", "T3K_LINE"], "Invalid MESH_DEVICE for multi-modal inference"
sys.argv.remove("--multi_modal") # remove the flag for the API server
else:
model = "meta-llama/Meta-Llama-3.1-70B"
os.environ["MESH_DEVICE"] = "T3K_RING"
check_tt_model_supported(args.model)

sys.argv.extend([
"--model", model,
"--model", args.model,
"--block_size", "64",
"--max_num_seqs", "32",
"--max_model_len", "131072",
Expand Down
31 changes: 19 additions & 12 deletions tt_metal/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ Git-checkout the following branches in each repo separately:
To run Meta-Llama-3.1/3.2, it is required to have access to the model on Hugging Face. To gain access:
1. Request access on Hugging Face:
- Llama-3.1: [https://huggingface.co/meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B)
- Llama-3.2: [https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
- Llama-3.2: [https://huggingface.co/meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
skhorasganiTT marked this conversation as resolved.
Show resolved Hide resolved
- Llama-3.2-Vision: [https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
2. Once you have received access, create and copy your access token from the settings tab on Hugging Face.
3. Run this code in python and paste your access token:
```python
Expand All @@ -46,40 +47,46 @@ To run Meta-Llama-3.1/3.2, it is required to have access to the model on Hugging
## Preparing the tt-metal models

1. Ensure that `$PYTHONPATH` contains the path to tt-metal (should already have been done when installing tt-metal)
2. For the desired model, follow the setup instructions (if any) for the corresponding tt-metal demo. E.g. For Llama-3.1-70B, follow the [demo instructions](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/t3000/llama3_70b) for preparing the weights and environment variables.
2. For the desired model, follow the setup instructions (if any) for the corresponding tt-metal demo. E.g. For Llama-3.1/3.2, follow the [demo instructions](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/llama3) for preparing the weights and environment variables.

## Running the offline inference example

### Llama-3.1-70B
### Llama-3.1/3.2 Text Models (1B, 3B, 8B, 70B)

To generate tokens for sample prompts:
To generate tokens (Llama70B) for sample prompts (with batch size 32):
```python
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py
skhorasganiTT marked this conversation as resolved.
Show resolved Hide resolved
MESH_DEVICE=T3K_LINE WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py
```

To measure performance for a single batch (with the default prompt length of 128 tokens):
To measure performance (Llama70B) for a single batch of 32 prompts (with the default prompt length of 128 tokens):
```python
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --measure_perf
MESH_DEVICE=T3K_LINE WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --measure_perf
```

**Note**: By default, the inference example will run with Llama-3.1-70B. To run with Llama-3.1-8B, Llama-3.2-1B, or Llama-3.2-3B, ensure that the appropriate environment variables are set as per the [demo instructions](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/llama3), then set `MESH_DEVICE=<device>` (valid options for `<device>` are `N150`, `N300`, or `T3K_LINE`) and one of the following:
- Llama-3.1-8B: `--model "meta-llama/Meta-Llama-3.1-8B"`
- Llama-3.2-1B: `--model "meta-llama/Llama-3.2-1B"`
- Llama-3.2-3B: `--model "meta-llama/Llama-3.2-3B"`

### Llama-3.2-11B-Vision-Instruct

To generate tokens for sample prompts:
```python
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 8
MESH_DEVICE=N300 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --model "meta-llama/Llama-3.2-11B-Vision-Instruct" --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 8
```

To measure performance for a single batch (with the default prompt length of 128 tokens):
```python
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --measure_perf --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 4
MESH_DEVICE=N300 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py --model "meta-llama/Llama-3.2-11B-Vision-Instruct" --measure_perf --multi_modal --max_seqs_in_batch 16 --num_repeat_prompts 4
```

**Note**: By default, the multi-modal inference example will run with `MESH_DEVICE=N300`. To run on T3000, set `MESH_DEVICE=T3K_LINE` and `--max_seqs_in_batch 32 --num_repeat_prompts 16`.
**Note**: To run on T3000, set `MESH_DEVICE=T3K_LINE` and `--max_seqs_in_batch 32 --num_repeat_prompts 16`.

## Running the server example

```python
VLLM_RPC_TIMEOUT=100000 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/server_example_tt.py
VLLM_RPC_TIMEOUT=100000 MESH_DEVICE=T3K_LINE WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/server_example_tt.py
```

**Note**: By default, the server will run with Llama-3.1-70B. To run with Llama-3.2-11B-Vision-Instruct instead, add `--multi_modal`.
**Note**: By default, the server will run with Llama-3.1-70B. To run with other Llama versions, set `MESH_DEVICE` and `--model` as described in [Running the offline inference example](#running-the-offline-inference-example).

91 changes: 24 additions & 67 deletions vllm/worker/tt_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def __init__(
self.block_size = cache_config.block_size

self.trace_mode = trace_mode # whether to use ttnn tracing for model execution
self.execute_trace_kwargs = None # kw args for trace execution (populated during first decode execution)

self.cached_step_outputs: List[torch.Tensor] = [] # Only used for multi-step execution

Expand Down Expand Up @@ -448,81 +447,48 @@ def _send_prev_step_async_out(self, model_input: TTModelInput, step_idx):
def _execute_model_single_step(self, model_input: TTModelInput, kv_caches: List[torch.Tensor], is_decode, async_out_proc_per_trace=False, step_idx=0):
execute_model_kwargs = {
"tokens": model_input.input_tokens,
"start_pos": model_input.input_positions,
"page_table": model_input.block_tables,
"kv_cache": kv_caches,
"prompt_lens": model_input.prompt_lens,
**(model_input.multi_modal_kwargs or {}),
}
if not is_decode:
execute_model_kwargs["prompt_lens"] = model_input.prompt_lens
else:
execute_model_kwargs["start_pos"] = model_input.input_positions
if model_input.cross_block_tables is not None:
execute_model_kwargs["cross_page_table"] = model_input.cross_block_tables

if not self.model_config.is_encoder_decoder_model: # Forward for decoder-only
if self.trace_mode and is_decode: # Trace mode for decode
# Remove prompt_lens from execute_model_kwargs since it's not used for decode
execute_model_kwargs.pop("prompt_lens")

# Capture trace for the first decode execution
if self.execute_trace_kwargs is None:
logger.info("Capturing trace for first decode execution")
trace_id, tt_inp, rot_idxs_tt, cache_idxs_tt, tt_logits, tt_page_table = self.model.capture_trace(
**execute_model_kwargs
)
self.execute_trace_kwargs = {
"trace_id": trace_id,
"tt_inp": tt_inp,
"rot_idxs_tt": rot_idxs_tt,
"cache_idxs_tt": cache_idxs_tt,
"tt_logits": tt_logits,
"tt_page_table": tt_page_table,
"read_from_device": False,
}

# Remove kv_cache from execute_model_kwargs since it doesn't need to be copied to device for trace execution
execute_model_kwargs.pop("kv_cache")

tt_logits = self.model.decode_forward_trace(
**execute_model_kwargs, **self.execute_trace_kwargs
)
if async_out_proc_per_trace:
# trigger output processor on host while device is executing next step
self._send_prev_step_async_out(model_input, step_idx)
logits = self.model.read_forward_trace(tt_logits, model_input.unpadded_batch_size)
else: # prefill or non-traced decode
logits = self.model.forward(**execute_model_kwargs) # [batch_size, seq_len, vocab_size]
else: # Forward for encoder-decoder (may need to be updated for future models)
# TODO: remove different forward calls once TT models can manage intermediate outputs internally
if not is_decode:
# Remove start_pos from execute_model_kwargs since it's not used for prefill
execute_model_kwargs.pop("start_pos")

logits, cross_attention_masks, full_text_row_masked_out_mask = self.model.prefill_forward(**execute_model_kwargs)

# Save encoder-decoder data for use in subsequent decode steps
if not is_decode:
outputs = self.model.prefill_forward(**execute_model_kwargs)

if self.model_config.is_encoder_decoder_model:
# Save encoder-decoder data for use in subsequent decode steps (may need to be updated for future models)
logits, cross_attention_masks, full_text_row_masked_out_mask = outputs
if self.cached_enc_dec_data is None:
self.cached_enc_dec_data = {}
for i, seq_id in enumerate(model_input.seq_groups):
enc_dec_data = {"cross_attention_masks": cross_attention_masks[i],
"full_text_row_masked_out_mask": full_text_row_masked_out_mask[i]}
self.cached_enc_dec_data[seq_id] = enc_dec_data
else:
# Remove prompt_lens from execute_model_kwargs since it's not used for decode
execute_model_kwargs.pop("prompt_lens")
logits = outputs # [batch_size, seq_len, vocab_size]
else:
if self.model_config.is_encoder_decoder_model:
# Use encoder-decoder data from prefill step
cross_attention_masks = [self.cached_enc_dec_data[seq_id]["cross_attention_masks"] for seq_id in model_input.seq_groups]
full_text_row_masked_out_mask = [self.cached_enc_dec_data[seq_id]["full_text_row_masked_out_mask"] for seq_id in model_input.seq_groups]

enc_dec_kwargs = {"cross_attention_masks": cross_attention_masks,
"full_text_row_masked_out_mask": full_text_row_masked_out_mask}

tt_logits = self.model.decode_forward(
**execute_model_kwargs, **enc_dec_kwargs, enable_trace=self.trace_mode, read_from_device=False
)
if async_out_proc_per_trace:
# trigger output processor on host while device is executing next step
self._send_prev_step_async_out(model_input, step_idx)
logits = self.model.read_decode_output(tt_logits, model_input.unpadded_batch_size)
else:
enc_dec_kwargs = {}

tt_logits = self.model.decode_forward(
**execute_model_kwargs, **enc_dec_kwargs, enable_trace=self.trace_mode, read_from_device=False
)
if async_out_proc_per_trace:
# trigger output processor on host while device is executing next step
self._send_prev_step_async_out(model_input, step_idx)
logits = self.model.read_decode_output(tt_logits, model_input.unpadded_batch_size)

# Note: for other devices, vLLM applies vllm.model_executor.layers.logits_processor::LogitsProcessor::_apply_logits_processors on logits, we don't use this
# Note: for other devices, vLLM applies vllm.model_executor.layers.sampler::Sampler for sampling tokens, we don't use this
Expand All @@ -540,13 +506,4 @@ def _sample_tokens(self, logits, tt_sampling_params : TTSamplingParams):
p=tt_sampling_params.top_p,
k=tt_sampling_params.top_k,
temperature=tt_sampling_params.temperature
)

## Destructor (used to delete ttnn trace if using trace mode)

def __del__(self):
if self.trace_mode and self.execute_trace_kwargs is not None:
self.model.delete_trace(self.execute_trace_kwargs["trace_id"])

if hasattr(super(TTModelRunner, self), '__del__'):
super().__del__()
)
10 changes: 7 additions & 3 deletions vllm/worker/tt_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,11 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
appended to.
"""
# TODO: Add proper implementation which runs profiling on TT devices
max_tokens_all_users = 131072 # Note: includes num vision tokens for multi-modal
if ("meta-llama/Meta-Llama-3.1-8B" in self.model_config.model and
len(self.device_config.device.get_devices()) == 1): # Llama8B on N150
max_tokens_all_users = 65536
else:
max_tokens_all_users = 131072 # Note: includes num vision tokens for multi-modal
num_tt_blocks = math.ceil(max_tokens_all_users / self.cache_config.block_size)
num_tt_blocks = int(num_tt_blocks * 1.01) # Add 1% to account for vLLM's watermark_blocks
num_cpu_blocks = 0
Expand Down Expand Up @@ -411,7 +415,7 @@ def _get_dispatch_core_config(self, device_params):
def _open_mesh_device(self):
num_devices_available = len(ttnn.get_device_ids())

mesh_grid_dict = {"N150": (1, 1), "N300": (1, 2), "T3K_LINE": (1, 8), "T3K_RING": (2, 4)}
mesh_grid_dict = {"N150": (1, 1), "N300": (1, 2), "T3K_LINE": (1, 8), "T3K_RING": (2, 4), "TG": (8, 4)}
mesh_device = os.environ.get("MESH_DEVICE")
if mesh_device is not None:
assert mesh_device in mesh_grid_dict, f"Invalid MESH_DEVICE: {mesh_device}"
Expand All @@ -425,7 +429,7 @@ def _open_mesh_device(self):
assert f"Requested mesh grid shape {mesh_grid} is larger than number of available devices {num_devices_available}"

if self.trace_mode:
device_params = {"trace_region_size": 14227456} # TODO: make this configurable
device_params = {"trace_region_size": 23887872} # TODO: make this configurable
else:
device_params = {}
mesh_device = ttnn.open_mesh_device(
Expand Down