From 1f7eccfb0fc57daf2c3c868a86900412e538f8c5 Mon Sep 17 00:00:00 2001 From: Miguel Tairum <150826086+mtairum@users.noreply.github.com> Date: Wed, 4 Dec 2024 10:59:50 +0000 Subject: [PATCH] Llama3 model family now supports batch-32, long context (up to 128k) and paged attention (#15327) Co-authored-by: avoraTT Co-authored-by: Stuti Raizada Co-authored-by: kpaigwar --- .github/workflows/t3000-demo-tests-impl.yaml | 2 +- .../workflows/t3000-frequent-tests-impl.yaml | 4 +- .../t3000-model-perf-tests-impl.yaml | 2 - .github/workflows/t3000-unit-tests-impl.yaml | 4 +- models/demos/llama3/README.md | 61 ++- models/demos/llama3/demo/demo.py | 427 ++++++++++++------ models/demos/llama3/demo/input_data_long.json | 6 - .../llama3/demo/input_data_long_128k.json | 6 + .../llama3/demo/input_data_long_32k.json | 7 + .../llama3/demo/input_data_long_64k.json | 7 + .../demos/llama3/demo/simple_vision_demo.py | 2 - models/demos/llama3/lt | 23 +- ..._llama_cross_attention_transformer_text.py | 2 - .../multimodal/test_llama_cross_block.py | 2 - .../tests/test_interleaved_to_sharded.py | 9 - .../demos/llama3/tests/test_llama_accuracy.py | 158 +++++-- .../llama3/tests/test_llama_attention.py | 175 +++++-- .../tests/test_llama_attention_prefill.py | 134 +++++- .../demos/llama3/tests/test_llama_decoder.py | 134 ++++-- .../tests/test_llama_decoder_prefill.py | 110 ++++- .../llama3/tests/test_llama_embedding.py | 15 +- models/demos/llama3/tests/test_llama_mlp.py | 21 +- models/demos/llama3/tests/test_llama_model.py | 205 +++++++-- .../llama3/tests/test_llama_model_prefill.py | 162 +++++-- models/demos/llama3/tests/test_llama_perf.py | 203 --------- .../demos/llama3/tests/test_llama_rms_norm.py | 21 +- models/demos/llama3/tests/test_lm_head.py | 8 +- models/demos/llama3/tt/llama_attention.py | 102 +++-- models/demos/llama3/tt/llama_common.py | 108 +++-- models/demos/llama3/tt/llama_decoder.py | 21 +- models/demos/llama3/tt/llama_model.py | 9 +- models/demos/llama3/tt/llama_rope.py | 168 +++++++ models/demos/llama3/tt/model_config.py | 44 +- tests/scripts/run_performance.sh | 15 - .../t3000/run_t3000_model_perf_tests.sh | 56 --- .../misc/test_rotary_embedding_llama.py | 65 +-- .../test_rotary_embedding_llama_fused_qk.py | 8 +- 37 files changed, 1633 insertions(+), 873 deletions(-) delete mode 100644 models/demos/llama3/demo/input_data_long.json create mode 100644 models/demos/llama3/demo/input_data_long_128k.json create mode 100644 models/demos/llama3/demo/input_data_long_32k.json create mode 100644 models/demos/llama3/demo/input_data_long_64k.json delete mode 100644 models/demos/llama3/tests/test_llama_perf.py create mode 100644 models/demos/llama3/tt/llama_rope.py diff --git a/.github/workflows/t3000-demo-tests-impl.yaml b/.github/workflows/t3000-demo-tests-impl.yaml index 9ad4ab1b818..f71636bdb15 100644 --- a/.github/workflows/t3000-demo-tests-impl.yaml +++ b/.github/workflows/t3000-demo-tests-impl.yaml @@ -16,7 +16,7 @@ jobs: test-group: [ { name: "t3k_falcon40b_tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 50, owner_id: U053W15B6JF}, #Djordje Ivanovic { name: "t3k_llama3_tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 30, owner_id: U03PUAKE719}, # Miguel Tairum - { name: "t3k_llama3_vision_tests", arch: wormhole_b0, cmd: run_t3000_llama3_vision_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich + # { name: "t3k_llama3_vision_tests", arch: wormhole_b0, cmd: run_t3000_llama3_vision_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: 
"t3k_llama3_70b_tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k_falcon7b_tests", arch: wormhole_b0, cmd: run_t3000_falcon7b_tests, timeout: 90, owner_id: U05RWH3QUPM}, #Salar Hosseini { name: "t3k_mixtral_tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 50, owner_id: U03PUAKE719}, # Miguel Tairum diff --git a/.github/workflows/t3000-frequent-tests-impl.yaml b/.github/workflows/t3000-frequent-tests-impl.yaml index fde2ede1652..542e85187c6 100644 --- a/.github/workflows/t3000-frequent-tests-impl.yaml +++ b/.github/workflows/t3000-frequent-tests-impl.yaml @@ -18,8 +18,8 @@ jobs: { name: "t3k ethernet tests", arch: wormhole_b0, cmd: run_t3000_ethernet_tests, timeout: 60, owner_id: ULMEPM2MA}, #Sean Nijjar { name: "t3k trace stress tests", arch: wormhole_b0, cmd: run_t3000_trace_stress_tests, timeout: 120, owner_id: U03NG0A5ND7}, #Aditya Saigal { name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 120, owner_id: U04S2UV6L8N}, #Sofija Jovic - { name: "t3k llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich - { name: "t3k n300 mesh llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich + # { name: "t3k llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich + # { name: "t3k n300 mesh llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k llama3 tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 45, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k llama2_70b tests", arch: wormhole_b0, cmd: run_t3000_llama2_70b_tests, timeout: 45, owner_id: U03FJB5TM5Y}, #Colman Glagovich # { name: "t3k llama3_70b tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 45, owner_id: U03FJB5TM5Y}, #Colman Glagovich # FIXME issue #14934 diff --git a/.github/workflows/t3000-model-perf-tests-impl.yaml b/.github/workflows/t3000-model-perf-tests-impl.yaml index efa896ad8a2..387a18d15a2 100644 --- a/.github/workflows/t3000-model-perf-tests-impl.yaml +++ b/.github/workflows/t3000-model-perf-tests-impl.yaml @@ -18,8 +18,6 @@ jobs: { name: "t3k LLM falcon7b model perf tests", model: "falcon7b", model-type: "LLM", arch: wormhole_b0, cmd: run_t3000_falcon7b_tests, timeout: 75, owner_id: U05RWH3QUPM}, # Salar Hosseini { name: "t3k LLM mixtral model perf tests", model: "mixtral", model-type: "LLM", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 75, owner_id: U03PUAKE719}, # Miguel Tairum { name: "t3k LLM llama2-70B model perf tests", model: "llama2-70b", model-type: "LLM", arch: wormhole_b0, cmd: run_t3000_llama2_70b_tests, timeout: 75, owner_id: U03FJB5TM5Y}, # Colman Glagovich - { name: "t3k LLM llama3-70B model perf tests", model: "llama3-70b", model-type: "LLM", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 60, owner_id: U03FJB5TM5Y}, # Colman Glagovich - { name: "t3k LLM llama3 model perf tests", model: "llama3", model-type: "LLM", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 60, owner_id: U03PUAKE719}, # Miguel Tairum { name: "t3k LLM falcon40b model perf tests", model: "falcon40b", model-type: "LLM", arch: wormhole_b0, cmd: 
run_t3000_falcon40b_tests, timeout: 75, owner_id: U053W15B6JF}, # Djordje Ivanovic { name: "t3k CNN resnet50 model perf tests", model: "resnet50", model-type: "CNN", arch: wormhole_b0, cmd: run_t3000_resnet50_tests, timeout: 75, owner_id: U013121KDH9}, # Austin Ho { name: "t3k CCL perf tests", arch: wormhole_b0, cmd: run_t3000_ccl_all_gather_perf_tests && run_t3000_ccl_reduce_scatter_perf_tests, timeout: 75, tracy: true, owner_id: ULMEPM2MA}, # Sean Nijjar diff --git a/.github/workflows/t3000-unit-tests-impl.yaml b/.github/workflows/t3000-unit-tests-impl.yaml index 303de478fd7..f05ee8e7810 100644 --- a/.github/workflows/t3000-unit-tests-impl.yaml +++ b/.github/workflows/t3000-unit-tests-impl.yaml @@ -20,8 +20,8 @@ jobs: { name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 30, owner_id: U053W15B6JF}, #Djordje Ivanovic { name: "t3k llama3-small tests", arch: wormhole_b0, cmd: run_t3000_llama3-small_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k llama3.2-11b tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz - { name: "t3k llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich - { name: "t3k n300 mesh llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich + # { name: "t3k llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich + # { name: "t3k n300 mesh llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k llama3.1-70b tests", arch: wormhole_b0, cmd: run_t3000_llama3.1-70b_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k grok tests", arch: wormhole_b0, cmd: run_t3000_grok_tests, timeout: 30, owner_id: U03HY7MK4BT}, #Mark O'Connor diff --git a/models/demos/llama3/README.md b/models/demos/llama3/README.md index aff98ce5239..1f8e0ad401b 100644 --- a/models/demos/llama3/README.md +++ b/models/demos/llama3/README.md @@ -14,6 +14,23 @@ All the above llama models (with the exception of 70B due to its large size) are - N300 (2-chips) - T3000 (8-chips) +Below is an updated table with max prefill context-length support for our demo. These were tested in both accuracy and performance modes. + +The main reason a long context length does not fit on device is lack of memory. Any exceptions are marked in the table. + +| | N150 | N300 | T3K | TG | +|---------------|---------------|-----------------|----------------|-------------| +| Llama3.2-1B | 64k tokens | 64k tokens | 64k tokens [1] | TBD | +| Llama3.2-3B | 32k tokens | 32k tokens [1] | 64k tokens [1] | TBD | +| Llama3.1-8B | 16k tokens | 64k tokens | 128k tokens | TBD | +| Llama3.2-11B | 16k tokens | 64k tokens | 128k tokens | TBD | +| Llama3.1-70B | Not supported | Not supported | 32k tokens [2] | 128k tokens | + +[1] For these configurations, running context lengths greater than those specified in the table will generate a bad repetitive output.
+ +[2] Although longer prefill context-lengths are not supported due to model size and available memory, you can still decode (generate) tokens up to a maximum of 128k tokens. + + ## How to Run ### Download the weights @@ -67,30 +84,50 @@ $LLAMA_DIR/T3K # For T3000 ### Run the demo -The current demo is setup for a single user (batch=1) that loads a prompt file (around 128 tokens), prefills the encoded prompt and then runs decode for 120 iterations. +The Llama3 demo includes 3 main modes of operation and is fully parametrized to support other configurations. + +- `batch-1`: Runs a small prompt for a single user +- `batch-32`: Runs a small prompt for a batch of 32 users +- `long-context`: Runs a large prompt (64k tokens) for a single user -The demo is also parametrized to run for 1 or 3 continuous batch of users, i.e. to simulate multiple users generating text one after another. +If you want to provide your own demo configuration, please take a look at the pytest parametrize calls in `models/demos/llama3/demo/demo.py`. For convenience we list all the supported params below: -The input prompts are based on the general or instruct (fine-tuned) weights. The prompts are included in the demo folder `models/demos/llama3/demo`. +- `input_prompts (string)`: input json file with prompts to process. See `models/demos/llama3/demo/*.json` for a list of input files +- `instruct (bool)`: Whether to use Llama instruct weights or general weights +- `repeat_batches (int)`: Number of consecutive batches of users to run (default: 1) +- `max_seq_len (int)`: Maximum context length supported by the model (refer to the table above) +- `batch_size (int)`: Number of users in a batch (supports batch sizes of 1/2/4/8/16/32) +- `max_generated_tokens (int)`: Maximum number of tokens to generate for each user (note that users will stop generation before this limit if they reach an EOS token) +- `paged_attention (bool)`: Whether to use paged attention or default attention (vLLM support (WIP) requires paged attention) +- `page_params (dict)`: Page parameters for paged attention - [`block_size`, `max_num_blocks`]. For smaller context lengths use `block_size=32` and `max_num_blocks=1024`; for larger contexts use `block_size=64` and `max_num_blocks=2048` +- `sampling_params (dict)`: Sampling parameters for decoding - [`temperature`, `top_p`]. If temperature is set to 0, argmax (greedy decode) is used. +- `optimization (LlamaOptimizations)`: Optimization level to use for the model [`performance`, `accuracy`] + +Please note that when using `argmax` with `batch_size > 1`, or `top-p` sampling with any batch size, these ops will run on host, since they are not yet fully supported on device. A decrease in performance is expected when these configurations are enabled (see the host-side sampling sketch below). When running the demo, do not forget to set the `$LLAMA_DIR` environment variable to the path of the corresponding Llama3 model weights. +Additionally, we support the use of a fake device. This enables running a demo configured for a smaller chip count on a larger multi-chip device. +Supported devices: [`N150`, `N300`, `T3K`, `TG`]. + +Example: `export FAKE_DEVICE=N150` will enable running a single-chip demo on a multi-chip system.
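As a reference for the host fallback mentioned above, here is a minimal sketch of host-side temperature/top-p (nucleus) sampling. It is illustrative only: the function name `sample_top_p_host` is hypothetical, and the actual `sample_host` helper in `llama_common.py` may differ in signature and details.

```
import torch

def sample_top_p_host(logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
    # logits: [batch, vocab]. temperature == 0 means greedy (argmax) decoding.
    if temperature == 0:
        return torch.argmax(logits, dim=-1)
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens whose cumulative probability (excluding themselves) already exceeds top_p
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_sorted).squeeze(-1)
```

With the demo's default `sampling_params` of `{"temperature": 0, "top_p": 0.08}` this path reduces to greedy argmax; any non-zero temperature enables nucleus sampling.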
+ ``` # Examples of how to run the demo for any supported Llama3 models -# Run a single continuous batch with instruct weights -pytest models/demos/llama3/demo/demo.py -k 'instruct and 1_batch' +# Batch-1 +pytest models/demos/llama3/demo/demo.py -k "performance and batch-1" -# Run 2 continuous batches with general weights -pytest models/demos/llama3/demo/demo.py -k 'general and 2_batch' +# Batch-32 +pytest models/demos/llama3/demo/demo.py -k "performance and batch-32" + +# Long-context +pytest models/demos/llama3/demo/demo.py -k "performance and long" ``` -By default we run the models in `LlamaOptimizations.performance` mode. You can override this by setting the `optimizations` argument in the demo. To compare the two on a long prompt, you can run: +The above examples are run in `LlamaOptimizations.performance` mode. +You can override this by setting the `optimizations` argument in the demo. To use instead the accuracy mode you can call the above tests with `-k "accuracy and ..."` instead of performance. -``` -pytest models/demos/llama3/demo/demo.py -k 'long-performance' -pytest models/demos/llama3/demo/demo.py -k 'long-accuracy' -``` ### Expected performance and accuracy diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py index 94fd07d35e5..a828830d330 100644 --- a/models/demos/llama3/demo/demo.py +++ b/models/demos/llama3/demo/demo.py @@ -15,23 +15,27 @@ from pathlib import Path import hashlib +from models.utility_functions import nearest_32 from models.demos.llama3.tt.llama_common import ( - get_single_rot_mat, get_prefill_rot_mat, get_rot_transformation_mat, HostEmbedding, encode_prompt_llama_instruct, + PagedAttentionConfig, + sample_host, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer +from models.demos.llama3.tt.model_config import TtModelArgs from models.perf.benchmarking_utils import BenchmarkProfiler from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf from models.demos.llama3.tt.model_config import LlamaOptimizations -def load_and_cache_context(context_url, cache_dir): +def load_and_cache_context(context_url, cache_dir, max_length=None): cache_file = cache_dir / hashlib.md5(context_url.encode()).hexdigest() if cache_file.exists(): @@ -53,11 +57,16 @@ def load_and_cache_context(context_url, cache_dir): logger.error(f"Error fetching context from URL: {context_url}. 
Error: {str(e)}") context_text = "" + # Clip the context to the max length provided + if max_length: + context_text = context_text[:max_length] + logger.info(f"Clipped the context text to {max_length} characters") + return context_text # load from json, return as a list -def load_inputs(user_input, batch): +def load_inputs(user_input, batch, instruct_mode): if isinstance(user_input, str): with open(user_input, "r") as f: user_input = json.load(f) @@ -69,8 +78,18 @@ def load_inputs(user_input, batch): for i in range(batch): prompt = user_input[i]["prompt"] if "context" in user_input[i]: - context_text = load_and_cache_context(user_input[i]["context"], cache_dir) - prompt = context_text + "\n\n" + prompt + if "max_length" in user_input[i]: # Clip the context to the max length provided + context_text = load_and_cache_context( + user_input[i]["context"], cache_dir, max_length=user_input[i]["max_length"] + ) + else: + context_text = load_and_cache_context(user_input[i]["context"], cache_dir) + if instruct_mode: + prompt = ( + "```" + context_text + "```\n\n" + prompt + ) # Add the markdown block to the context to comply with the prompt + else: + prompt = context_text in_prompt.append(prompt) return in_prompt @@ -154,14 +173,18 @@ def preprocess_inputs_prefill( def run_llama3_demo( user_input, - batch_size, - single_layer, mesh_device, + max_seq_len, + batch_size, + num_batches, + paged_attention, + paged_attention_config, + max_generated_tokens, + optimizations, + sampling_params, instruct_mode, is_ci_env, - num_batches, print_to_file, - optimizations, ): # Creat batch output file timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") @@ -170,10 +193,9 @@ def run_llama3_demo( os.chmod(output_directory, 0o755) output_filename = f"{output_directory}/demo_user_output_{timestamp}.txt" - # This module requires the env paths above for CI runs - from models.demos.llama3.tt.model_config import TtModelArgs - dtype = ttnn.bfloat8_b + assert batch_size <= 32, "Max batch size currently supported is 32" + assert max_seq_len <= 128 * 1024, "Max sequence length must be less than 128k tokens" # We disregard any warmup iteration for profiling, in favour of just measuring compile time on the first iteration N_warmup_iter = {"inference_prefill": 0, "inference_decode": 0} @@ -188,7 +210,7 @@ def run_llama3_demo( if len(user_input) == 1: input_prompts = user_input * batch_size else: - input_prompts = load_inputs(user_input, batch_size) + input_prompts = load_inputs(user_input, batch_size, instruct_mode) profiler.end("loading_inputs") # Generate the batched prompts (rotate the inputs between the users, for each batch) @@ -198,17 +220,91 @@ def run_llama3_demo( batch_prompts.append([input_prompts[(j + i) % len(input_prompts)] for j in range(len(input_prompts))]) # Load model args, weights, and tokenizer - model_args = TtModelArgs(mesh_device, instruct=instruct_mode, optimizations=optimizations) + model_args = TtModelArgs( + mesh_device, + instruct=instruct_mode, + max_batch_size=batch_size, + optimizations=optimizations, + max_seq_len=max_seq_len, + ) tokenizer = Tokenizer(model_args.tokenizer_path) - if single_layer: - model_args.n_layers = 1 + # Check max sequence length compatibility with model and architecture. 
Refer to README for more information + llama_model_name = model_args.model_name # ["3.2-1B", "3.2-3B", "3.1-8B", "3.2-11B", "3.1-70B"] + tt_device_name = model_args.device_name # ["N150", "N300", "T3K", "TG"] + + if llama_model_name == "3.2-1B": + assert ( + max_seq_len <= 64 * 1024 + ), "Llama3.2-1B only supports a max context length of 64k tokens across all architectures" + if llama_model_name == "3.2-3B": + if tt_device_name == "N150": + assert max_seq_len <= 32 * 1024, "N150 only supports a max context length of 32k tokens for Llama3.2-3B" + elif tt_device_name == "N300": + assert max_seq_len <= 64 * 1024, "N300 only supports a max context length of 64k tokens for Llama3.2-3B" + else: # T3K and TG + assert max_seq_len <= 64 * 1024, "T3K only supports a max context length of 64k tokens for Llama3.2-3B" + if llama_model_name in ["3.1-8B", "3.2-11B"]: + if tt_device_name == "N150": + assert ( + max_seq_len <= 16 * 1024 + ), "N150 only supports a max context length of 16k tokens for Llama3.1-8B and Llama3.2-11B" + elif tt_device_name == "N300": + assert ( + max_seq_len <= 64 * 1024 + ), "N300 only supports a max context length of 64k tokens for Llama3.1-8B and Llama3.2-11B" + else: # T3K and TG + assert ( + max_seq_len <= 128 * 1024 + ), "T3K only supports a max context length of 128k tokens for Llama3.1-8B and Llama3.2-11B" + if llama_model_name == "3.1-70B": + assert tt_device_name in ["T3K", "TG"], "Llama3.1-70B is only supported on T3K or TG" logger.info("Loading weights...") profiler.start("weight_loading") state_dict = model_args.load_state_dict() profiler.end("weight_loading") + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( + mesh_device, + batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, + ) + transformation_mats_decode = rope_setup.get_trans_mats() + + transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim) + transformation_mats_prefill = ttnn.from_torch( + transformation_mats_prefill_torch, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill} + + page_table_tt = None + + if paged_attention: + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + # Load TTNN Llama3.1 model logger.info("Loading weights to device...") profiler.start("loading_weights_to_device") @@ -218,6 +314,8 @@ def run_llama3_demo( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) tt_embd = TtLlamaEmbedding( mesh_device=mesh_device, @@ -232,7 +330,6 @@ def run_llama3_demo( profiler.end("loading_weights_to_device") logger.info("Finished loading weights to device.") - max_generated_tokens = 100 # Maximum number of tokens to generate per user num_tokens_generated_decode = [] 
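    # Illustrative note on the paged KV cache configured above (descriptive comment; the numbers
    # assume the README's suggested page_params rather than values enforced here): the page table
    # maps each user's virtual blocks to shuffled physical blocks, giving every user
    # max_num_blocks // batch_size blocks of block_size tokens each. With block_size=32 and
    # max_num_blocks=1024 the cache covers 32 * 1024 = 32768 tokens in total, while the
    # block_size=64 / max_num_blocks=2048 setting suggested for long contexts covers
    # 64 * 2048 = 131072 (128k) tokens.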
logger.info("Starting inference...") @@ -252,6 +349,12 @@ def run_llama3_demo( instruct_mode, max_generated_tokens, ) + + max_encoded_prompt_len = max(len(p) for p in encoded_prompts) + assert ( + max_generated_tokens + max_encoded_prompt_len <= max_seq_len + ), f"Prompt prefill tokens ({max_encoded_prompt_len}) + maximum number of decoded iterations ({max_generated_tokens}) needs to be <= than max_seq_len ({max_seq_len})" + # Prefill embeddings are on host since we need to mask out the tokens after the prefill length after embeddings are computed pt_prefill_input = [embd(input_tokens_prefill_pt[b]).view(1, prefill_lens[b], -1) for b in range(batch_size)] profiler.end(f"preprocess_prefill_inputs", iteration=batch_idx) @@ -265,18 +368,6 @@ def run_llama3_demo( logger.info(f"Starting prefill...") - profiler.start(f"prepare_rot_mat_for_prefill", iteration=batch_idx) - transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.from_torch( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - profiler.end(f"prepare_rot_mat_for_prefill", iteration=batch_idx) - # Do not count the first user for prefill time and instead log it as compile time num_users_generated_prefill = batch_size - 1 if batch_size > 1 else 1 @@ -302,11 +393,11 @@ def run_llama3_demo( tt_out = tt_model( prefill_input, - None, # Current position - rot_mats_prefill, - transformation_mats, + current_pos=None, + rot_mats=rot_mats_prefill, user_id=batch_id, mode="prefill", + page_table=page_table_tt, get_last_token=((decoding_pos[batch_id] - 1) // 32) * 32, ) @@ -320,11 +411,11 @@ def run_llama3_demo( ttnn.deallocate(tt_out) tt_out = tt_model( prefill_input, - None, # Current position - rot_mats_prefill, - transformation_mats, + current_pos=None, + rot_mats=rot_mats_prefill, user_id=batch_id, mode="prefill", + page_table=page_table_tt, get_last_token=((decoding_pos[batch_id] - 1) // 32) * 32, ) @@ -345,8 +436,11 @@ def run_llama3_demo( profiler.start(f"prepare_first_decode_token_{batch_idx}") pt_out_batched = torch.stack(pt_out, dim=-2) pt_out_batched = torch.argmax(pt_out_batched, dim=-1) + # Pad the output tensor to be tile sized tt_out_tok = ttnn.from_torch( - torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0), + torch.nn.functional.pad( + pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 32 - len(pt_out_batched)), "constant", 0 + ), device=mesh_device, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), dtype=ttnn.uint32, @@ -363,32 +457,37 @@ def run_llama3_demo( logger.info("Starting decode...") - profiler.start(f"get_single_rot_mat_decode_{batch_idx}") - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=decoding_pos[0] - 2, - ) - profiler.end(f"get_single_rot_mat_decode_{batch_idx}") + # Set sampling mode + argmax_on_device = False if (batch_size > 1 or sampling_params["temperature"] != 0) else True # Create events profiler.start(f"compile_trace_{batch_idx}") op_event = ttnn.create_event(mesh_device) write_event = ttnn.create_event(mesh_device) - current_pos = ttnn.from_torch( - torch.tensor(decoding_pos, dtype=torch.int32), + # Initial positions + current_pos = torch.tensor([decoding_pos[b] for b in range(batch_size)]) + + current_pos_tensor = ttnn.from_torch( + current_pos, device=mesh_device, - 
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) + # Get cos/sin matrices for the current position of each user + rot_mats, rot_mat_idxs = rope_setup.get_rot_mats(current_pos, return_rot_idxs=True) # Compile logger.info(f"Compiling model trace...") decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) if tt_model.args.num_devices > 1: tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) ttnn.deallocate(tt_out) @@ -396,11 +495,21 @@ def run_llama3_demo( tt_out_gathered = tt_out tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True) ttnn.deallocate(tt_out_gathered) - tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok) - ttnn.deallocate(tt_out_rm) - new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat) - ttnn.plus_one(current_pos) + if argmax_on_device: + tt_out_tok = ttnn.argmax( # FIXME When ttnn.argmax supports multicore, avoid falling back to host + tt_out_rm, dim=3, use_multicore=False if batch_size > 1 else True, output_tensor=tt_out_tok + ) + ttnn.deallocate(tt_out_rm) + else: + tt_out_tok_reset, _ = sample_host( + tt_out_rm, + mesh_device, + temperature=sampling_params["temperature"], + top_p=sampling_params["top_p"], + on_host=True, + ) + ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok) + ttnn.plus_one(current_pos_tensor) profiler.end(f"compile_trace_{batch_idx}") # Capture Trace @@ -410,7 +519,14 @@ def run_llama3_demo( decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) + rot_mats = rope_setup.get_rot_mats(rot_mat_idxs) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) if tt_model.args.num_devices > 1: tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) ttnn.deallocate(tt_out) @@ -418,28 +534,36 @@ def run_llama3_demo( tt_out_gathered = tt_out tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True) ttnn.deallocate(tt_out_gathered) - tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok) - ttnn.deallocate(tt_out_rm) - new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat) - ttnn.plus_one(current_pos) + if argmax_on_device: + tt_out_tok = ttnn.argmax( + tt_out_rm, dim=3, use_multicore=False if batch_size > 1 else True, output_tensor=tt_out_tok + ) # FIXME Multicore is not compatible with batch > 1 + ttnn.deallocate(tt_out_rm) + ttnn.plus_one(current_pos_tensor) + # ttnn.plus_one(rot_mat_idxs) # FIXME <- This won't work since embedding requires uint32 and plus_one only works for int32 ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) # Reset the decoding position for the proper run of the model current_pos_reset = ttnn.from_torch( - torch.tensor(decoding_pos, dtype=torch.int32), + current_pos, dtype=ttnn.int32, 
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if tt_model.args.num_devices > 1 else None, ) tt_out_tok_reset = ttnn.from_torch( - torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0), + torch.nn.functional.pad( + pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 32 - len(pt_out_batched)), "constant", 0 + ), + # torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 30), "constant", 0), dtype=ttnn.uint32, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if tt_model.args.num_devices > 1 else None, ) - ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos) + # Reset the current position and output token tensors for the real decode run + ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos_tensor) ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok) + rot_mat_idxs_reset = rope_setup.get_rot_idxs(current_pos, on_host=True) + ttnn.copy_host_to_device_tensor(rot_mat_idxs_reset, rot_mat_idxs) profiler.end(f"capture_trace_{batch_idx}") @@ -463,17 +587,37 @@ def run_llama3_demo( ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=True) ttnn.record_event(0, op_event) + # Update current pos and mat idxs on host and send to device + # TODO This is required for now since we cannot ttnn.plus_one(rot_mat_idxs) while it being uint32. + # If this tensor is int32, it won't be supported by ttnn.embedding + current_pos += 1 + rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos, on_host=True) + ttnn.copy_host_to_device_tensor(rot_mat_idxs_updated, rot_mat_idxs) + # Write to host ttnn.wait_for_event(1, op_event) - tt_output_torch = ttnn.to_torch( - tt_out_tok.cpu(blocking=True, cq_id=1), mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1) - )[0, 0, 0, :batch_size] + if argmax_on_device: + tt_output_torch = ttnn.to_torch( + tt_out_tok.cpu(blocking=True, cq_id=1), mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1) + )[0, 0, 0, :batch_size] + else: + tt_out_tok_reset, tt_output_torch = sample_host( + tt_out_rm, + mesh_device, + temperature=sampling_params["temperature"], + top_p=sampling_params["top_p"], + on_host=True, + ) + tt_output_torch = tt_output_torch[0, 0, 0, :batch_size] + ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok) ttnn.record_event(1, write_event) # Save output token to print out later for user in range(batch_size): user_tok = tt_output_torch[user].tolist() - if user_tok != 28803 and user_done[user] == False: # Stop saving the ouput after hitting the EOS token + if ( + user_tok != 128009 and user_done[user] == False + ): # Stop saving the ouput after hitting the eos token (<|eot_id|>) (128009) all_outputs[user].append(user_tok) else: user_done[user] = True @@ -515,19 +659,6 @@ def run_llama3_demo( iteration += 1 - # Reset rotation matrix every 100 iterations - profiler.start(f"reset_rot_mat_{iteration-1}", iteration=batch_idx) - if iteration % 100 == 0: - current_rot_mat_reset, rot_matrix_reset = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=decoding_pos[0] + iteration, - on_host=True, - ) - ttnn.copy_host_to_device_tensor(current_rot_mat_reset, current_rot_mat) - profiler.end(f"reset_rot_mat_{iteration-1}", iteration=batch_idx) - # Upper limit of generated tokens for each user (to avoid infinite generation in case eos is not seen) if iteration >= max_generated_tokens: users_decoding = False @@ -578,7 +709,7 @@ def run_llama3_demo( compile_decode_time = 
profiler.get_duration("compile_decode") inference_prefill_time = profiler.get_duration("inference_prefill") inference_decode_time = profiler.get_duration("inference_decode") - log_printing_time = sum(profiler.get_duration(f"log_printing_iter_{i}") for i in range(max_generated_tokens)) + log_printing_time = sum(profiler.get_duration(f"log_printing_iter_{i}") for i in range(total_tokens_generated)) log_saving_file_time = profiler.get_duration(f"log_saving_file") # Correct the inference decode time to remove the time spent on compile (1st iteration) and log_printing (at the end of every iteration) @@ -602,13 +733,10 @@ def run_llama3_demo( "loading_inputs": profiler.get_duration("loading_inputs"), "weight_loading": profiler.get_duration("weight_loading"), "prepare_first_decode_token": profiler.get_duration("prepare_first_decode_token_0"), - "get_single_rot_mat_decode": profiler.get_duration("get_single_rot_mat_decode_0"), # Only for batch 0 "preprocess_prefill_inputs": profiler.get_duration("preprocess_prefill_inputs"), "loading_weights_to_device": profiler.get_duration("loading_weights_to_device"), - "prepare_rot_mat_for_prefill": profiler.get_duration("prepare_rot_mat_for_prefill"), "compile_trace": profiler.get_duration("compile_trace_0"), # Only for batch 0 "capture_trace": profiler.get_duration("capture_trace_0"), # Only for batch 0 - "reset_rot_mat": sum(profiler.get_duration(f"reset_rot_mat_{i}") for i in range(max_generated_tokens)), "Total compile time": compile_prefill_time + compile_decode_time, "Full demo runtime": profiler.get_duration("run"), } @@ -620,7 +748,7 @@ def run_llama3_demo( logger.info(f"Decode compile time: {round(measurements['compile_decode'], 4)}s") logger.info(f"Prefill inference time per user: {round(inference_prefill_time/num_users_generated_prefill, 4)}s") logger.info( - f"Total Decode inference time ({max_generated_tokens-1} iterations): {round(measurements['inference_decode'], 4)}s" + f"Total Decode inference time ({total_tokens_generated-1} iterations): {round(measurements['inference_decode'], 4)}s" ) logger.info("") logger.info(f"Time to first token: {round(measurements['prefill_time_to_token']* 1000, 2)}ms") @@ -691,8 +819,7 @@ def run_llama3_demo( } # Save benchmark data for CI dashboard - # if is_ci_env: - if True: + if is_ci_env: benchmark_data = create_benchmark_data(profiler, measurements, N_warmup_iter, targets) benchmark_data.prep_csvs( profiler, @@ -703,48 +830,71 @@ def run_llama3_demo( batch_size=batch_size, input_sequence_length=prefill_seq_len, output_sequence_length=1, - # config_params=, - # precision=, ) +# List of supported Parameters for demo.py +# +# input_prompts (string): input json file with prompts to process. 
See models/demos/llama3/demo/*.json for list of input files +# instruct (bool): Whether to use instruct weights or general weights +# repeat_batches (int): Number of consecutive batches of users to run (default: 1) +# max_seq_len (int): Maximum context length supported by the model (Llama3.1 and Llama3.2 models have a maximum context length of 128k, i.e., 128 * 1024) +# batch_size (int): Number of users in a batch (Supports 1/2/4/8/16/32 batches) +# max_generated_tokens (int): Maximum number of tokens to generate for each user (Note that the users will stop generation before this limit if they reach a EoS token) +# paged_attention (bool): Whether to use paged attention or default attention (vLLM requires paged attention) +# page_params (dict): Page parameters for paged attention (block_size, max_num_blocks) For smaller context lengths use block_size=32 and max_num_blocks=1024, for larger context use block_size=64 and max_num_blocks=2048 +# sampling_params (dict): Sampling parameters for decoding (temperature, top_p). If temperature is set to 0, argmax (greedy decode) is used. +# +# optimization (LlamaOptimizations): Optimization level to use for the model (performance or accuracy) +# FAKE_DEVICE (str): Fake device to use for testing (N150, N300, T3K, TG). Usage: `export FAKE_DEVICE=N150`, will enable running a single-chip demo on a multi-chip system. @pytest.mark.parametrize( - "input_prompts, instruct_weights, num_batches, single_layer, optimizations", + "input_prompts, instruct, repeat_batches, max_seq_len, batch_size, max_generated_tokens, paged_attention, page_params, sampling_params", [ - ("models/demos/llama3/demo/input_data_prefill_128.json", False, 1, False, LlamaOptimizations.performance), - ("models/demos/llama3/demo/input_data_prefill_128.json", False, 2, False, LlamaOptimizations.performance), - ( - "models/demos/llama3/demo/input_data_questions_prefill_128.json", - True, - 1, - False, - LlamaOptimizations.performance, + ( # Batch-1 run (Latency) - single user, small prompt + "models/demos/llama3/demo/input_data_questions_prefill_128.json", # input_prompts + True, # instruct mode + 1, # repeat_batches + 1024, # max_seq_len + 1, # batch_size + 200, # max_generated_tokens + True, # paged_attention + {"page_block_size": 32, "page_max_num_blocks": 1024}, # page_params # TODO This will be serviced by vLLM + {"temperature": 0, "top_p": 0.08}, # sampling_params (argmax) ), - ( - "models/demos/llama3/demo/input_data_questions_prefill_128.json", - True, - 2, - False, - LlamaOptimizations.performance, + ( # Batch-32 run (Throughput) - 32 users, small prompt + "models/demos/llama3/demo/input_data_questions_prefill_128.json", # input_prompts + True, # instruct mode + 1, # repeat_batches + 1024, # max_seq_len + 32, # batch_size + 200, # max_generated_tokens + True, # paged_attention + {"page_block_size": 32, "page_max_num_blocks": 1024}, # page_params # TODO This will be serviced by vLLM + {"temperature": 0, "top_p": 0.08}, # sampling_params (argmax) ), - ("models/demos/llama3/demo/input_data_long.json", True, 1, False, LlamaOptimizations.performance), - ("models/demos/llama3/demo/input_data_long.json", True, 1, False, LlamaOptimizations.accuracy), - ( - "models/demos/llama3/demo/input_data_questions_prefill_128.json", - True, - 1, - True, - LlamaOptimizations.performance, + ( # Long-context run - Single user, long prompt (adapted to the model being used and architecture) + "models/demos/llama3/demo/input_data_long_64k.json", # input_prompts + True, # instruct mode + 1, # repeat_batches + 
64 * 1024, # max_seq_len + 1, # batch_size + 200, # max_generated_tokens + False, # paged_attention + {"page_block_size": 64, "page_max_num_blocks": 2048}, # page_params # TODO This will be serviced by vLLM + {"temperature": 0, "top_p": 0.08}, # sampling_params (argmax) ), ], ids=[ - "general_weights-1_batch", - "general_weights-2_batch", - "instruct_weights-1_batch", - "instruct_weights-2_batch", - "instruct_weights-long-performance", - "instruct_weights-long-accuracy", - "single_layer", + "batch-1", # latency + "batch-32", # throughput + "long-context", # max-length + ], +) +@pytest.mark.parametrize( + "optimizations", + [ + LlamaOptimizations.performance, + LlamaOptimizations.accuracy, ], ) @pytest.mark.parametrize("device_params", [{"trace_region_size": 23887872, "num_command_queues": 2}], indirect=True) @@ -758,29 +908,46 @@ def run_llama3_demo( indirect=True, ) def test_llama_demo( + input_prompts, + instruct, + repeat_batches, + max_seq_len, + batch_size, + max_generated_tokens, + paged_attention, + page_params, + sampling_params, + optimizations, mesh_device, use_program_cache, - input_prompts, - instruct_weights, is_ci_env, - num_batches, - single_layer, - optimizations, reset_seeds, ): - if is_ci_env and (instruct_weights == False or "long" in input_prompts or single_layer == True): - pytest.skip("CI demo test only runs instruct weights to reduce CI pipeline load (both are supported)") + if is_ci_env and ("long" in input_prompts or optimizations == LlamaOptimizations.accuracy): + pytest.skip("Do not run the 'long-context' or accuracy tests on CI to reduce load") mesh_device.enable_async(True) + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + else: + paged_attention_config = None + return run_llama3_demo( user_input=input_prompts, - batch_size=1, - single_layer=single_layer, mesh_device=mesh_device, - instruct_mode=instruct_weights, + max_seq_len=max_seq_len, + batch_size=batch_size, + num_batches=repeat_batches, + paged_attention=paged_attention, + paged_attention_config=paged_attention_config, + max_generated_tokens=max_generated_tokens, + optimizations=optimizations, + sampling_params=sampling_params, + instruct_mode=instruct, is_ci_env=is_ci_env, - num_batches=num_batches, print_to_file=False, - optimizations=optimizations, ) diff --git a/models/demos/llama3/demo/input_data_long.json b/models/demos/llama3/demo/input_data_long.json deleted file mode 100644 index 33e9d6b48c7..00000000000 --- a/models/demos/llama3/demo/input_data_long.json +++ /dev/null @@ -1,6 +0,0 @@ -[ - { - "prompt": "Take three specific quotes from the above reference text (marked in a ```markdown block```) to summarize by means of metaphor what it can teach us about the value of AI in the modern age. The first quote should be from the first third of the text, the second quote from the middle and the final quote from the last third of the text. Be brief and exact in your answer.", - "context": "https://www.gutenberg.org/cache/epub/84/pg84.txt" - } -] diff --git a/models/demos/llama3/demo/input_data_long_128k.json b/models/demos/llama3/demo/input_data_long_128k.json new file mode 100644 index 00000000000..9dbadbde754 --- /dev/null +++ b/models/demos/llama3/demo/input_data_long_128k.json @@ -0,0 +1,6 @@ +[ + { + "prompt": "Explicitly state the quotes directly taken from the book inside double quotes like this: \n A. < add quote> \n Metaphor: \n B. < add quote> \n Metaphor: \n C. 
< add quote> \n Metaphor: \n with the metaphors after each quote. Double-check that the quotes are from the text specified above and that the metaphors relate to AI. End your answer after the 3 quotes / metaphors are finished.", + "context": "https://www.gutenberg.org/cache/epub/84/pg84.txt" + } +] diff --git a/models/demos/llama3/demo/input_data_long_32k.json b/models/demos/llama3/demo/input_data_long_32k.json new file mode 100644 index 00000000000..c0faca0b1eb --- /dev/null +++ b/models/demos/llama3/demo/input_data_long_32k.json @@ -0,0 +1,7 @@ +[ + { + "prompt": "Explicitly state the quotes directly taken from the book inside double quotes like this: \n A. < add quote> \n Metaphor: \n B. < add quote> \n Metaphor: \n C. < add quote> \n Metaphor: \n with the metaphors after each quote. Double-check that the quotes are from the text specified above and that the metaphors relate to AI. End your answer after the 3 quotes / metaphors are finished.", + "context": "https://www.gutenberg.org/cache/epub/84/pg84.txt", + "max_length": 130000 + } +] diff --git a/models/demos/llama3/demo/input_data_long_64k.json b/models/demos/llama3/demo/input_data_long_64k.json new file mode 100644 index 00000000000..c6c841582cc --- /dev/null +++ b/models/demos/llama3/demo/input_data_long_64k.json @@ -0,0 +1,7 @@ +[ + { + "prompt": "Explicitly state the quotes directly taken from the book inside double quotes like this: \n A. < add quote> \n Metaphor: \n B. < add quote> \n Metaphor: \n C. < add quote> \n Metaphor: \n with the metaphors after each quote. Double-check that the quotes are from the text specified above and that the metaphors relate to AI. End your answer after the 3 quotes / metaphors are finished.", + "context": "https://www.gutenberg.org/cache/epub/84/pg84.txt", + "max_length": 260000 + } +] diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index 673c3bc5a73..b4946c3eecf 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -49,8 +49,6 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn tt_model_args = TtModelArgs(mesh_device, max_batch_size=max_batch_size) # limit length or we'll run out of space tt_model_args.max_seq_len = max_seq_len - tt_model_args.kv_seq_len = max_seq_len - tt_model_args.sliding_window = max_seq_len checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True) model = CrossAttentionTransformer( mesh_device, diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index 23a8fb33f15..594568609ba 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -389,12 +389,20 @@ def main(stdscr): commands = parse_list(command_input, allow_space=False) # Generate combinations (reordered) + # Ignore invalid combinations: + # - 11b and 11b-b models on n150 device + # - 70b model on n150 and n300 devices + # - Vision commands on non-vision (11b) models combinations = [ (c, m, d) for c in commands for m in models for d in devices - if not ((m in ["11b", "11b-b"] and d == "n150") or (m == "70b" and d in ["n150", "n300"])) + if not ( + (m in ["11b", "11b-b"] and d == "n150") + or (m == "70b" and d in ["n150", "n300"]) + or ("vision" in c and m not in ["11b", "11b-b"]) + ) ] # Create output entries @@ -723,8 +731,9 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): # Define command shortcuts command_shortcuts = { - "demo": "pytest 
models/demos/llama3/demo/demo.py -k instruct_weights-1", - "demo-1layer": "pytest models/demos/llama3/demo/demo.py -k single_layer", + "demo": "pytest models/demos/llama3/demo/demo.py -k performance-batch-1", + "demo-32": "pytest models/demos/llama3/demo/demo.py -k performance-batch-32", + "demo-long": "pytest models/demos/llama3/demo/demo.py -k long", "attention": "pytest models/demos/llama3/tests/test_llama_attention.py", "attention-prefill": "pytest models/demos/llama3/tests/test_llama_attention_prefill.py", "mlp": "pytest models/demos/llama3/tests/test_llama_mlp.py", @@ -733,9 +742,10 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): "decoder": "pytest models/demos/llama3/tests/test_llama_decoder.py", "decoder-prefill": "pytest models/demos/llama3/tests/test_llama_decoder_prefill.py", "lm-head": "pytest models/demos/llama3/tests/test_lm_head.py", - "model": "pytest models/demos/llama3/tests/test_llama_model.py -k performance-full", - "model-quick": "pytest models/demos/llama3/tests/test_llama_model.py -k performance-quick", - "model-prefill": "pytest models/demos/llama3/tests/test_llama_model_prefill.py -k performance", + "model": "pytest models/demos/llama3/tests/test_llama_model.py -k 'performance-128 and full'", + "model-quick": "pytest models/demos/llama3/tests/test_llama_model.py -k 'performance-128 and quick'", + "model-prefill": "pytest models/demos/llama3/tests/test_llama_model_prefill.py -k performance-4096", + # Vision tests (require 11B weights) "vision-mlp": "pytest models/demos/llama3/tests/multimodal/test_llama_image_mlp.py", "vision-attn": "pytest models/demos/llama3/tests/multimodal/test_llama_image_attention.py", "vision-block": "pytest models/demos/llama3/tests/multimodal/test_llama_image_block.py", @@ -749,7 +759,6 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): "vision-encoder": "pytest models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py", "vision-text-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py", "vision-vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py", - "perf": "pytest models/demos/llama3/tests/test_llama_perf.py -k 1024", "accuracy": "pytest models/demos/llama3/tests/test_llama_accuracy.py -k performance", } diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 172531645c9..7448601b8ce 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -63,8 +63,6 @@ def test_llama_cross_attention_transformer_text_inference( model_args = TtModelArgs(mesh_device, max_batch_size=batch) # Limit the max seqlen to 4k to avoid OOM on host model_args.max_seq_len = 4096 - model_args.kv_seq_len = model_args.max_seq_len - model_args.sliding_window = model_args.max_seq_len state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 1b0013c78ee..96637e5090c 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -46,8 +46,6 @@ def 
test_llama_cross_attention_transformer_block_inference( model_args = TtModelArgs(mesh_device, max_batch_size=batch) # Limit the max seqlen to 4k to avoid OOM on host model_args.max_seq_len = 4096 - model_args.kv_seq_len = model_args.max_seq_len - model_args.sliding_window = model_args.max_seq_len state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names diff --git a/models/demos/llama3/tests/test_interleaved_to_sharded.py b/models/demos/llama3/tests/test_interleaved_to_sharded.py index 6a0a7bb9fc2..b69d7d2459b 100644 --- a/models/demos/llama3/tests/test_interleaved_to_sharded.py +++ b/models/demos/llama3/tests/test_interleaved_to_sharded.py @@ -8,7 +8,6 @@ import ttnn from models.demos.llama3.tt.llama_common import ( precompute_freqs, - get_single_rot_mat, ) from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs @@ -51,14 +50,6 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds): generation_length = 10 all_tests_pass = True - # pre-compute the rotational embedding matrix and send to device - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=0, - ) - # Initialize TT model tt_model = TtTransformerBlock( args=model_args, diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py index 8ce5648a8a8..2ae973a907d 100644 --- a/models/demos/llama3/tests/test_llama_accuracy.py +++ b/models/demos/llama3/tests/test_llama_accuracy.py @@ -8,13 +8,14 @@ import os import ttnn from models.demos.llama3.tt.llama_common import ( - get_single_rot_mat, get_prefill_rot_mat, get_rot_transformation_mat, HostEmbedding, + PagedAttentionConfig, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.demos.llama3.demo.demo import preprocess_inputs_prefill from pathlib import Path @@ -58,8 +59,10 @@ def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: Ll @torch.no_grad() @pytest.mark.timeout(900) -@pytest.mark.parametrize("prefill_len", [512]) -@pytest.mark.parametrize("decode_len", [128]) +@pytest.mark.parametrize( + "prefill_len, decode_len, max_seq_len", # Max seqlen should be at least prefill_len + decode_len + ((512, 128, 1024),), +) @pytest.mark.parametrize( "mesh_device", [ @@ -76,13 +79,47 @@ def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: Ll pytest.param(LlamaOptimizations.performance, id="performance"), ], ) -def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cache, reset_seeds, optimizations): +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_tt_model_accuracy( + prefill_len, + decode_len, + max_seq_len, + batch_size, + paged_attention, + page_params, + optimizations, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b 
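    # Descriptive comment: this test prefills `prefill_len` tokens of the reference text, then decodes
    # `decode_len` positions one token at a time, comparing each generated token against the
    # precomputed reference top-1/top-5 predictions used below; the pass thresholds for the given
    # model, device, and optimization level come from get_accuracy_thresholds() above.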
mesh_device.enable_async(True) # Load model args and tokenizer - model_args = TtModelArgs(mesh_device, optimizations=optimizations) + model_args = TtModelArgs( + mesh_device, optimizations=optimizations, max_batch_size=batch_size, max_seq_len=max_seq_len + ) + tokenizer = Tokenizer(model_args.tokenizer_path) # Load state_dict for TT model @@ -104,6 +141,51 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac N = prefill_len + decode_len input_ids = reference_tokens[:, : N + 1] # Shape [1, N+1] + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( + mesh_device, + model_args.max_batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, + ) + transformation_mats_decode = rope_setup.get_trans_mats() + + transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim) + transformation_mats_prefill = ttnn.from_torch( + transformation_mats_prefill_torch, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill} + + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + # Initialize TT model tt_model = TtTransformer( args=model_args, @@ -111,6 +193,8 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) # Initialize embedding embd = HostEmbedding(model_args) @@ -138,18 +222,9 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac pt_prefill_input = [embd(input_tokens_prefill_pt[b]).view(1, prefill_lens[b], -1) for b in range(1)] # Pre-compute the rotational embedding matrix and send to device - rot_mats = get_prefill_rot_mat( + rot_mats_prefill = get_prefill_rot_mat( model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=prefill_lens[0] ) - transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.from_torch( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) prefill_input = model_args.prepare_inputs_ttnn_prefill( pt_prefill_input[batch_id], @@ -157,11 +232,11 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac tt_out = tt_model( prefill_input, - None, # Current position - rot_mats, - transformation_mats, + current_pos=None, + rot_mats=rot_mats_prefill, user_id=batch_id, mode="prefill", + 
page_table=page_table_tt, get_last_token=((decoding_pos[batch_id] - 1) // 32) * 32, ) @@ -169,19 +244,19 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac logger.info(f"Starting decode...") generation_start_pos = prefill_len generation_length = decode_len - current_pos = ttnn.from_torch( - torch.tensor([generation_start_pos]), + + # Initial positions + decoding_pos = [generation_start_pos] * model_args.max_batch_size + current_pos = torch.tensor([decoding_pos[b] for b in range(model_args.max_batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, device=mesh_device, dtype=ttnn.int32, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=max(0, generation_start_pos - 1), - ) + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) # Print table header logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Actual':<15}{'Top 5 Predictions':<75}") @@ -206,7 +281,13 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) # Run TT model - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) if tt_model.args.num_devices > 1: tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) @@ -215,23 +296,20 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac tt_out_gathered = tt_out tt_out_rm = ttnn.untilize(tt_out_gathered, use_multicore=True) ttnn.deallocate(tt_out_gathered) - tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True) + tt_out_tok = ttnn.argmax( + tt_out_rm, + dim=3, + use_multicore=True if model_args.max_batch_size == 1 else False, + ) tt_argmax_token = ttnn.to_torch(tt_out_tok, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ 0, 0, 0, 0 ] ttnn.deallocate(tt_out_rm) - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - ttnn.plus_one(current_pos) - - # Reset rotation matrix every 100 iterations - if i % 100 == 0: # Doing this every 100 iterations as in demo takes top5 from 99% -> - current_rot_mat, rot_matrix_reset = get_single_rot_mat( - model_args.head_dim, - mesh_device, - model_args.num_devices, - start_pos=generation_start_pos + i, - on_host=False, - ) + ttnn.plus_one(current_pos_tensor) + + # Update rot_mats for next iteration + current_pos += 1 + rot_mats = rope_setup.get_rot_mats(current_pos) # Get reference top5 tokens and probabilities for this position ref_top5_tokens = top5_tokens[prefill_len + i] diff --git a/models/demos/llama3/tests/test_llama_attention.py b/models/demos/llama3/tests/test_llama_attention.py index c41ac5644ca..8690b91d3b9 100644 --- a/models/demos/llama3/tests/test_llama_attention.py +++ b/models/demos/llama3/tests/test_llama_attention.py @@ -7,10 +7,11 @@ import os import ttnn from models.demos.llama3.tt.llama_attention import TtLlamaAttention +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_common import ( precompute_freqs, - get_single_rot_mat, + PagedAttentionConfig, ) from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Attention from models.utility_functions import ( @@ 
-31,14 +32,47 @@ ], indirect=True, ) -def test_llama_attention_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "paged_attention", + ( + True, + False, + ), + ids=( + "paged_attention", + "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_llama_attention_inference( + max_seq_len, + batch_size, + paged_attention, + page_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b pcc = 0.99 mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) - model_args.n_layers = 1 + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) + model_args.n_layers = 1 # For the unit test, just run a sigle layer + state_dict = model_args.load_state_dict() first_layer_prefix = model_args.get_state_dict_prefix("TtLlamaAttention", 0) + "." @@ -50,44 +84,79 @@ def test_llama_attention_inference(mesh_device, use_program_cache, reset_seeds, reference_model = Attention(args=model_args) reference_model.load_state_dict(partial_state_dict) - batch = model_args.max_batch_size seq_len = 1 generation_start_pos = 0 generation_length = 10 all_tests_pass = True - # pre-compute the rotational embedding matrix and send to device - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( mesh_device, - model_args.num_devices, - start_pos=0, + batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, ) + transformation_mats = rope_setup.get_trans_mats() + transformation_mats = {"decode": transformation_mats} + + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_model = TtLlamaAttention( mesh_device, state_dict, weight_cache_path=model_args.weight_cache_path(dtype), layer_num=0, dtype=dtype, + transformation_mats=transformation_mats, configuration=model_args, + paged_attention_config=paged_attention_config, ) - cos, sin = precompute_freqs(model_args.head_dim, model_args.max_seq_len * 2) + cos, sin = precompute_freqs( + model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope + ) freqs_cis = torch.complex(cos, sin) + + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + for i in range(generation_length): # 70B attention block 
typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1 - pt_attention_input = torch.randn(batch, seq_len, model_args.dim) * 0.05 + pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim) * 0.05 tt_attention_input = pt_attention_input.clone() - current_pos = generation_start_pos + i - current_pos_tensor = ttnn.from_torch( - torch.tensor([current_pos] * batch), - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) attention_input = model_args.prepare_inputs_ttnn_decode( tt_attention_input, @@ -95,48 +164,84 @@ def test_llama_attention_inference(mesh_device, use_program_cache, reset_seeds, force_replicated=True, ) - tt_out = tt_model(attention_input, current_pos_tensor, rot_mats=current_rot_mat, mode="decode") + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) + + tt_out = tt_model( + attention_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) # multi-device attention module returns replicated output tt_output_torch = ( ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[0, :, :, : model_args.dim] .view(1, -1, model_args.dim) .permute(1, 0, 2)[: model_args.max_batch_size, :, :] - ) # [ batch, seq, hidden_dim] + ) # [ batch_size, seq, hidden_dim] - freqs_cis_i = freqs_cis[current_pos, :].unsqueeze(0) - # positions = torch.tensor([current_pos]) + # In this test all users have the same position + freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0) - reference_output = reference_model(pt_attention_input, current_pos, freqs_cis_i, mask=None) + reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None) passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) logger.info(comp_allclose(reference_output, tt_output_torch)) logger.info(f"PCC: {pcc_message}") if passing: - logger.info(f"[pos={current_pos}] Llama_Attention Passed!") + logger.info(f"[pos={current_pos[0]}] Llama_Attention Passed!") else: - logger.warning(f"[pos={current_pos}] Llama_Attention Failed!") + logger.warning(f"[pos={current_pos[0]}] Llama_Attention Failed!") all_tests_pass = False - # Update rotation matrix for next iteration - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) + # Increment position + current_pos = torch.tensor([generation_start_pos + i for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) check_kv_cache = True if check_kv_cache: # PyTorch output -------------------------------------------------------------------- pytorch_layer_present = [ - reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] - reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] + reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] + reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] ] # TT hardware execution ------------------------------------------------------------- - tt_layer_present = [ - ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) - for cache in tt_model.layer_past - ] + if paged_attention: + tt_layer_present = [ + ( + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + 
model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch_size, ... + ] + ) + for cache in tt_model.layer_past + ] + else: + tt_layer_present = [ + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + for cache in tt_model.layer_past + ] for i, (cache_pt, cache_tt) in enumerate(zip(pytorch_layer_present, tt_layer_present)): - cache_length_to_check = min(model_args.sliding_window, generation_start_pos + generation_length + 1) + cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + generation_length + 1) cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc) diff --git a/models/demos/llama3/tests/test_llama_attention_prefill.py b/models/demos/llama3/tests/test_llama_attention_prefill.py index fe3f1834eae..ef33adc4481 100644 --- a/models/demos/llama3/tests/test_llama_attention_prefill.py +++ b/models/demos/llama3/tests/test_llama_attention_prefill.py @@ -11,6 +11,7 @@ from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, get_rot_transformation_mat, + PagedAttentionConfig, ) from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Attention, precompute_freqs_cis from models.utility_functions import ( @@ -22,10 +23,6 @@ @torch.no_grad() @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "seq_len", - (2048,), -) @pytest.mark.parametrize( "mesh_device", [ @@ -35,13 +32,50 @@ ], indirect=True, ) -def test_llama_attention_inference(seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc): +# Model and attention prefill tests should run both with and without paged attention to debug any issues that may occur with default attention +@pytest.mark.parametrize( + "paged_attention", + ( + True, + False, + ), + ids=( + "paged_attention", + "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + ( + 2048, + # 1024 * 32, + # 1024 * 64, + ), +) +def test_llama_attention_inference( + max_seq_len, + batch_size, + paged_attention, + page_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b pcc = 0.99 mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) model_args.n_layers = 1 state_dict = model_args.load_state_dict() @@ -53,51 +87,84 @@ def test_llama_attention_inference(seq_len, mesh_device, use_program_cache, rese reference_model = Attention(args=model_args) reference_model.load_state_dict(partial_state_dict) - batch = 1 - # pre-compute the rotational embedding matrix and send to device - rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len) + rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=max_seq_len) transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.as_tensor( + transformation_mats_prefill = ttnn.as_tensor( 
transformation_mat_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) + transformation_mats = {"prefill": transformation_mats_prefill} + generation_start_pos = 0 generation_length = 3 all_tests_pass = True + # Setup page table + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + tt_model = TtLlamaAttention( mesh_device, state_dict, weight_cache_path=model_args.weight_cache_path(dtype), layer_num=0, dtype=dtype, + transformation_mats=transformation_mats, configuration=model_args, + paged_attention_config=paged_attention_config, ) - pt_attention_input = (torch.rand(batch, seq_len, model_args.dim) * 2) - 1 + pt_attention_input = (torch.rand(batch_size, max_seq_len, model_args.dim) * 2) - 1 tt_attention_input = pt_attention_input.clone() attention_input = model_args.prepare_inputs_ttnn_prefill( tt_attention_input, force_replicated=True, ) - tt_out = tt_model(attention_input, 0, rot_mats, transformation_mats, user_id=0, mode="prefill") + tt_out = tt_model( + attention_input, + current_pos=None, + rot_mats=rot_mats, + user_id=0, + mode="prefill", + page_table=page_table_tt, + ) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ 0, :, :, : model_args.dim ].view( - batch, seq_len, -1 - ) # [ batch, seq, dim] + batch_size, max_seq_len, -1 + ) # [ batch_size, seq, dim] - positions = torch.LongTensor(range(seq_len)) + positions = torch.LongTensor(range(max_seq_len)) freqs_cis_i = precompute_freqs_cis( model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope )[positions] - attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min) + attn_mask = torch.full((max_seq_len, max_seq_len), torch.finfo(torch.float32).min) attn_mask_torch = torch.triu(attn_mask, diagonal=1) reference_output = reference_model(pt_attention_input, positions[0], freqs_cis_i, mask=attn_mask_torch) @@ -115,17 +182,36 @@ def test_llama_attention_inference(seq_len, mesh_device, use_program_cache, rese if check_kv_cache: # PyTorch output -------------------------------------------------------------------- pytorch_layer_present = [ - reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] - reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] + reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] + reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] ] # TT hardware execution ------------------------------------------------------------- - tt_layer_present = [ - ttnn.to_torch(cache, 
mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) - for cache in tt_model.layer_past - ] + if paged_attention: + tt_layer_present = [ + ( + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[reverse_permutation] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch_size, ... + ] + ) + for cache in tt_model.layer_past + ] + else: + tt_layer_present = [ + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + for cache in tt_model.layer_past + ] for i, (cache_pt, cache_tt) in enumerate(zip(pytorch_layer_present, tt_layer_present)): - cache_length_to_check = min(model_args.sliding_window, generation_start_pos + generation_length + 1) + cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + generation_length + 1) cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc) diff --git a/models/demos/llama3/tests/test_llama_decoder.py b/models/demos/llama3/tests/test_llama_decoder.py index 1fad070640b..5d24d3b4298 100644 --- a/models/demos/llama3/tests/test_llama_decoder.py +++ b/models/demos/llama3/tests/test_llama_decoder.py @@ -8,10 +8,11 @@ import ttnn from models.demos.llama3.tt.llama_common import ( precompute_freqs, - get_single_rot_mat, + PagedAttentionConfig, ) -from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs +from models.demos.llama3.tt.llama_decoder import TtTransformerBlock +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import TransformerBlock from models.utility_functions import ( comp_pcc, @@ -31,13 +32,45 @@ ], indirect=True, ) -def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_llama_decoder_inference( + max_seq_len, + batch_size, + paged_attention, + page_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b - mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) model_args.n_layers = 1 + state_dict = model_args.load_state_dict() # Ref model needs partial state dict, but our models use full state dict keys as cached weight names @@ -52,13 +85,41 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en generation_length = 10 all_tests_pass = True - # pre-compute the rotational embedding matrix and send to device - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( mesh_device, - 
model_args.num_devices, - start_pos=0, + model_args.max_batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, ) + transformation_mats = rope_setup.get_trans_mats() + transformation_mats = {"decode": transformation_mats} + + # Prepare page table for paged attention + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) # Initialize TT model tt_model = TtTransformerBlock( @@ -68,27 +129,31 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en state_dict=state_dict, layer_num=0, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) seqlen = 1 - batch = model_args.max_batch_size - cos, sin = precompute_freqs(model_args.head_dim, model_args.max_seq_len * 2) + cos, sin = precompute_freqs( + model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope + ) freqs_cis = torch.complex(cos, sin) + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) for i in range(generation_length): logger.info(f"[Decoder] Generating token {i}") # input = torch.randn(1, 32, 4096) - pt_decode_input = (torch.rand(batch, seqlen, model_args.dim) * 2) - 1 + pt_decode_input = (torch.rand(batch_size, seqlen, model_args.dim) * 2) - 1 tt_decode_input = pt_decode_input.clone() - current_pos = generation_start_pos + i - current_pos_tensor = ttnn.from_torch( - torch.tensor([current_pos] * batch), - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) decode_input = model_args.prepare_inputs_ttnn_decode( tt_decode_input, @@ -96,20 +161,31 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) + # Run TT model - tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) + tt_output_torch = ( ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ :1, :, :, : model_args.dim ] .permute(2, 1, 0, 3) .squeeze(1)[: model_args.max_batch_size, :, :] - ) # [seq, batch, dim] + ) # [seq, batch_size, dim] - freqs_cis_i = freqs_cis[current_pos, :].unsqueeze(0) + # In this test all users have the same position + freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0) # Reference model - ref_output = 
reference_model(pt_decode_input, current_pos, freqs_cis_i, mask=None) + ref_output = reference_model(pt_decode_input, current_pos[0], freqs_cis_i, mask=None) passing, pcc_message = comp_pcc(ref_output, tt_output_torch) @@ -122,8 +198,14 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en logger.warning("Llama Decoder Block Failed!") all_tests_pass = False - # Update rotation matrix for next iteration - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) + # Increment position + current_pos = torch.tensor([generation_start_pos + i for _ in range(batch_size)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) if all_tests_pass: logger.info(f"All {generation_length} Llama decode iterations Passed!") diff --git a/models/demos/llama3/tests/test_llama_decoder_prefill.py b/models/demos/llama3/tests/test_llama_decoder_prefill.py index 998a4ab2f39..0c40e21b773 100644 --- a/models/demos/llama3/tests/test_llama_decoder_prefill.py +++ b/models/demos/llama3/tests/test_llama_decoder_prefill.py @@ -9,6 +9,7 @@ from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, get_rot_transformation_mat, + PagedAttentionConfig, ) from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs @@ -22,13 +23,6 @@ @torch.no_grad() @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "seq_len", - ( - 4096, - 128, - ), -) @pytest.mark.parametrize( "mesh_device", [ @@ -38,13 +32,48 @@ ], indirect=True, ) -def test_llama_decoder_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False + ), + ids=( + "paged_attention", + # "default_attention" + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + ( + 4096, + 128, + ), +) +def test_llama_decoder_inference( + max_seq_len, + batch_size, + paged_attention, + page_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat8_b - mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) model_args.n_layers = 1 + state_dict = model_args.load_state_dict() # Ref model needs partial state dict, but our models use full state dict keys as cached weight names @@ -52,7 +81,7 @@ def test_llama_decoder_inference(mesh_device, seq_len, use_program_cache, reset_ partial_state_dict = { k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - batch = 1 + reference_model = TransformerBlock(layer_id=0, args=model_args) reference_model.load_state_dict(partial_state_dict) @@ -61,51 +90,84 @@ def test_llama_decoder_inference(mesh_device, seq_len, use_program_cache, reset_ all_tests_pass = True # pre-compute the rotational embedding matrix and send to device - rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len) + rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=max_seq_len) transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.as_tensor( + transformation_mats_prefill = 
ttnn.as_tensor( transformation_mat_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) + transformation_mats = {"prefill": transformation_mats_prefill} + + # Setup page table + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) # Initialize TT model tt_model = TtTransformerBlock( - args=model_args, mesh_device=mesh_device, - dtype=dtype, state_dict=state_dict, - layer_num=0, weight_cache_path=model_args.weight_cache_path(dtype), + layer_num=0, + dtype=dtype, + transformation_mats=transformation_mats, + args=model_args, + paged_attention_config=paged_attention_config, ) - # TODO Update start_pos (check llama test for reference) for i in range(generation_length): logger.info(f"[Decoder] Generating token {i}") - pt_decode_input = (torch.rand(batch, seq_len, model_args.dim) * 2) - 1 + pt_decode_input = (torch.rand(batch_size, max_seq_len, model_args.dim) * 2) - 1 tt_decode_input = pt_decode_input.clone() decode_input = model_args.prepare_inputs_ttnn_prefill( tt_decode_input, ) - positions = torch.LongTensor(range(seq_len)) + positions = torch.LongTensor(range(max_seq_len)) freqs_cis_i = precompute_freqs_cis( model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope )[positions] # Reference model - attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min) + attn_mask = torch.full((max_seq_len, max_seq_len), torch.finfo(torch.float32).min) attn_mask_torch = torch.triu(attn_mask, diagonal=1) ref_output = reference_model(pt_decode_input, positions[0], freqs_cis_i, mask=attn_mask_torch) # Run TT model - tt_out = tt_model(decode_input, None, rot_mats, transformation_mats, user_id=0, mode="prefill") + tt_out = tt_model( + decode_input, + current_pos=None, + rot_mats=rot_mats, + user_id=0, + mode="prefill", + page_table=page_table_tt, + ) tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ 0, :, :, : model_args.dim ].view( - batch, seq_len, -1 - ) # [ batch, seq, hidden_dim] + batch_size, max_seq_len, -1 + ) # [ batch_size, seq, hidden_dim] passing, pcc_message = comp_pcc(ref_output, tt_output_torch) logger.info(comp_allclose(ref_output, tt_output_torch)) diff --git a/models/demos/llama3/tests/test_llama_embedding.py b/models/demos/llama3/tests/test_llama_embedding.py index e8178f7e2e1..d5223b64254 100644 --- a/models/demos/llama3/tests/test_llama_embedding.py +++ b/models/demos/llama3/tests/test_llama_embedding.py @@ -28,15 +28,22 @@ ], indirect=True, ) -def test_llama_embedding(mesh_device, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", 
+ (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat16 - mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) model_args.n_layers = 1 - state_dict = model_args.load_state_dict() + state_dict = model_args.load_state_dict() tokenizer = Tokenizer(model_args.tokenizer_path) reference_emb = HostEmbedding(model_args) diff --git a/models/demos/llama3/tests/test_llama_mlp.py b/models/demos/llama3/tests/test_llama_mlp.py index fa7655dd6ff..b810cb357bd 100644 --- a/models/demos/llama3/tests/test_llama_mlp.py +++ b/models/demos/llama3/tests/test_llama_mlp.py @@ -19,15 +19,6 @@ @torch.no_grad() @skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "seq_len", - ( - 64 * 1024, - 32 * 1024, - # 1024, - 32, - ), -) @pytest.mark.parametrize( "mesh_device", [ @@ -37,13 +28,21 @@ ], indirect=True, ) -def test_llama_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc): +@pytest.mark.parametrize( + "seq_len", + ( + 64 * 1024, + 32 * 1024, + 32, + ), +) +def test_llama_mlp_inference(seq_len, mesh_device, use_program_cache, reset_seeds, ensure_gc): dtype = ttnn.bfloat8_b mode = "decode" if seq_len <= 32 else "prefill" mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=1, max_seq_len=128) model_args.n_layers = 1 state_dict = model_args.load_state_dict() diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py index efe43cf1a91..cd425579a23 100644 --- a/models/demos/llama3/tests/test_llama_model.py +++ b/models/demos/llama3/tests/test_llama_model.py @@ -8,13 +8,14 @@ import ttnn from models.demos.llama3.tt.llama_common import ( precompute_freqs, - get_single_rot_mat, - sample, + sample_host, encode_prompt_llama_instruct, HostEmbedding, + PagedAttentionConfig, ) -from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations +from models.demos.llama3.tt.llama_model import TtTransformer +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.utility_functions import ( @@ -36,6 +37,29 @@ ], ids=["quick", "full"], ) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) @pytest.mark.parametrize( "optimizations", [ @@ -52,16 +76,34 @@ ], indirect=True, ) -def test_llama_model_inference(mesh_device, weights, layers, optimizations, use_program_cache, reset_seeds, ensure_gc): - mesh_device.enable_async(True) - +def test_llama_model_inference( + weights, + layers, + max_seq_len, + batch_size, + paged_attention, + page_params, + optimizations, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): run_ref_pt = True # Flag 
to run reference PyTorch model and compare PCC cache_pcc = layers == 1 # Flag to measure KV cache PCC. Avoid running for all layers to speed up test time. dtype = ttnn.bfloat8_b + mesh_device.enable_async(True) mode_accuracy = optimizations == LlamaOptimizations.accuracy instruct = True if weights == "instruct" else False dummy_weights = True if weights == "random" else False - model_args = TtModelArgs(mesh_device, instruct=instruct, dummy_weights=dummy_weights, optimizations=optimizations) + model_args = TtModelArgs( + mesh_device, + instruct=instruct, + dummy_weights=dummy_weights, + optimizations=optimizations, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + ) model_name = { (16, False): "llama32_1b", @@ -127,7 +169,9 @@ def test_llama_model_inference(mesh_device, weights, layers, optimizations, use_ prompts = ["This is a test"] * model_args.max_batch_size if dummy_weights: - encoded_prompts = [[128000, 2028, 374, 264, 1296]] # "This is a test" encoded prompt + encoded_prompts = [ + [128000, 2028, 374, 264, 1296] + ] * model_args.max_batch_size # "This is a test" encoded prompt assert not instruct, "Instruct prompt not implemented with dummy weights" else: tokenizer = Tokenizer(model_args.tokenizer_path) @@ -147,13 +191,41 @@ def test_llama_model_inference(mesh_device, weights, layers, optimizations, use_ generation_start_pos = 0 generation_length = iterations - # pre-compute the rotational embedding matrix and send to device - current_rot_mat, rot_matrix = get_single_rot_mat( - model_args.head_dim, + # Setup RoPE transformation matrices + rope_setup = TtLlamaRotarySetup( mesh_device, - model_args.num_devices, - start_pos=0, + model_args.max_batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_theta, + model_args.use_scaled_rope, ) + transformation_mats = rope_setup.get_trans_mats() + transformation_mats = {"decode": transformation_mats} + + page_table_tt = None + paged_attention_config = None + + # Prepare page table for paged attention + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) # Load TTNN model tt_model = TtTransformer( @@ -162,6 +234,8 @@ def test_llama_model_inference(mesh_device, weights, layers, optimizations, use_ dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) logger.info("Model and caches loaded.") @@ -176,7 +250,6 @@ def test_llama_model_inference(mesh_device, weights, layers, optimizations, use_ # Select the first token from the prompts for initial decoding encoded_prompts_tensor = torch.tensor(encoded_prompts) # [:,0] pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1) - tt_decode_input = pt_decode_input # Keep track of generated outputs to print out later @@ -184,42 +257,59 @@ def test_llama_model_inference(mesh_device, weights, 
layers, optimizations, use_ if run_ref_pt: all_outputs_ref = [] + # Initial positions + current_pos = torch.tensor([generation_start_pos for _ in range(batch)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + for i in range(generation_length): - current_pos = generation_start_pos + i + logger.info(f"[Llama3 Model] Generating token {i}") decode_input = model_args.prepare_inputs_ttnn_decode( tt_decode_input, model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) - current_pos_tensor = ttnn.from_torch( - torch.tensor([current_pos] * batch), - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) + + # Get cos/sin matrices for the current position of each user + rot_mats = rope_setup.get_rot_mats(current_pos) # Run TT model - tt_out = tt_model(decode_input, current_pos_tensor, rot_mat=current_rot_mat) + tt_out = tt_model( + decode_input, + current_pos_tensor, + rot_mats=rot_mats, + mode="decode", + page_table=page_table_tt, + ) + # Convert ttnn tensor to torch tensor tt_output_torch = ( ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)) .permute(2, 1, 0, 3) .squeeze(1)[: model_args.max_batch_size, :, :] ) # [seq, batch, hidden_dim] - ttnn.deallocate(tt_out) - # Update rotation matrix for next iteration - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - if run_ref_pt: # Run reference model - # freqs_cis_i = freqs_cis[current_pos, :].unsqueeze(0) - # positions = torch.tensor([current_pos]) - # mask = ttnn.to_torch(attn_mask[0]) - ref_output = reference_model(pt_decode_input, current_pos) + # In this test all users have the same position + ref_output = reference_model(pt_decode_input, current_pos[0]) - # While in "prefill" mode, use the prompt tokens as the output + # Increment position + current_pos = torch.tensor([generation_start_pos + i for _ in range(batch)]) + current_pos_tensor = ttnn.from_torch( + current_pos, + device=mesh_device, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + # Append the generated token to the list of outputs if i in range(len(encoded_prompts[0])): + # While in "prefill" mode, use the prompt tokens as the output all_outputs.append(encoded_prompts[0][i]) # Update list of TT outputs if run_ref_pt: all_outputs_ref.append(encoded_prompts[0][i]) # Update list of ref outputs @@ -229,16 +319,15 @@ def test_llama_model_inference(mesh_device, weights, layers, optimizations, use_ pt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) else: # Greedy decode (temperature = 0) the generated token and save it to print out later - tt_out_tok = sample(tt_output_torch, temperature=0, top_p=0.8) + tt_out_tok = sample_host(tt_output_torch, None, temperature=0, top_p=0.8) tt_decode_input = embd(tt_out_tok) all_outputs.append(tt_out_tok.squeeze(1).tolist()[0]) # Update generated token to list of TT outputs if run_ref_pt: - pt_out_tok = sample(ref_output, temperature=0, top_p=0.8) + pt_out_tok = sample_host(ref_output, None, temperature=0, top_p=0.8) pt_decode_input = embd(pt_out_tok) all_outputs_ref.append( pt_out_tok.squeeze(1).tolist()[0] ) # Update generated token to list of ref outputs - # Measure PCC if also running reference model if run_ref_pt: if layers == 1 and i == iterations - 1: # On last iteration in the quick test, set a tighter PCC @@ -271,14 +360,52 @@ def test_llama_model_inference(mesh_device, weights, layers, 
optimizations, use_ ] tt_layer_present = [] - for layer_past in tt_model.layers[l].attention.layer_past: - tt_layer_present.append( - ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) - ) + if paged_attention: + for layer_past in tt_model.layers[l].attention.layer_past: + tt_layer_present.append( + ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch, ... + ] + ) + tt_layer_present = [ + ( + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch, ... + ] + ) + for cache in tt_model.layers[l].attention.layer_past + ] + else: + for layer_past in tt_model.layers[l].attention.layer_past: + tt_layer_present.append( + ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + ) for kv_cache, (cache_pt, cache_tt) in enumerate(zip(pytorch_layer_present, tt_layer_present)): cache_length_to_check = min( - model_args.sliding_window, generation_start_pos + generation_length + 1 + model_args.max_seq_len, generation_start_pos + generation_length + 1 ) cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py index bb1859eddd9..934c91d5746 100644 --- a/models/demos/llama3/tests/test_llama_model_prefill.py +++ b/models/demos/llama3/tests/test_llama_model_prefill.py @@ -10,9 +10,9 @@ from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, get_rot_transformation_mat, - sample, HostEmbedding, encode_prompt_llama_instruct, + PagedAttentionConfig, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations @@ -29,14 +29,6 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.timeout(900) @pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "seq_len", - ( - # 128, - # 1024, - 4096, - ), -) @pytest.mark.parametrize( "mesh_device", [ @@ -46,6 +38,30 @@ ], indirect=True, ) +# Model and attention prefill tests should run both with and without paged attention to debug any issues that may occur with default attention +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "seq_len", + (4096,), +) @pytest.mark.parametrize( "optimizations", [ @@ -54,7 +70,16 @@ ], ) def test_llama_model_inference( - mesh_device, seq_len, optimizations, use_program_cache, reset_seeds, ensure_gc, is_ci_env + seq_len, + batch_size, + paged_attention, 
+ page_params, + optimizations, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, + is_ci_env, ): if is_ci_env and optimizations == LlamaOptimizations.accuracy: pytest.skip("CI test only runs performance mode to reduce CI pipeline load") @@ -68,14 +93,15 @@ def test_llama_model_inference( pcc = 0.91 # TODO Look on improving PCC else: # performance mode assert optimizations == LlamaOptimizations.performance - pcc = 0.91 + pcc = 0.87 # TODO Look on improving PCC mesh_device.enable_async(True) # Use instruct weights instead of general weights instruct = True - model_args = TtModelArgs(mesh_device, instruct=instruct, max_batch_size=1, optimizations=optimizations) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, optimizations=optimizations, max_seq_len=seq_len) + tokenizer = Tokenizer(model_args.tokenizer_path) logger.info("Loading weights...") @@ -118,14 +144,39 @@ def test_llama_model_inference( # pre-compute the rotational embedding matrix and send to device rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len) transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.as_tensor( + transformation_mats_prefill = ttnn.as_tensor( transformation_mat_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) + transformation_mats = {"prefill": transformation_mats_prefill} + + # Setup page table + page_table_tt = None + paged_attention_config = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) # Load TTNN model tt_model = TtTransformer( @@ -134,6 +185,8 @@ def test_llama_model_inference( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) logger.info("Model and caches loaded.") @@ -141,30 +194,35 @@ def test_llama_model_inference( if run_ref_pt: all_tests_pass = True - batch = 1 - # Select the first token from the prompt for initial decoding encoded_prompt_tensor = torch.tensor(encoded_prompt) # [:,0] - pt_decode_input = embd(encoded_prompt_tensor).view(batch, seq_len, -1) + pt_prefill_input = embd(encoded_prompt_tensor).view(batch_size, seq_len, -1) - tt_decode_input = pt_decode_input + tt_prefill_input = pt_prefill_input - decode_input = model_args.prepare_inputs_ttnn_prefill( - tt_decode_input, + tt_prefill_input = model_args.prepare_inputs_ttnn_prefill( + pt_prefill_input, ) for i in range(1): start_pos = 0 # Run TT model - tt_out = tt_model(decode_input, None, rot_mats, transformation_mats, user_id=i, mode="prefill") + tt_out = tt_model( + tt_prefill_input, + current_pos=None, + rot_mats=rot_mats, + user_id=i, + 
mode="prefill", + page_table=page_table_tt, + ) # Convert ttnn tensor to torch tensor tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ :, 0, :, : ].view( - batch, seq_len, -1 - ) # [ batch, seq, hidden_dim] + batch_size, seq_len, -1 + ) # [ batch_size, seq, hidden_dim] if run_ref_pt: # Run reference model - ref_output = reference_model(pt_decode_input, start_pos, mode="prefill") + ref_output = reference_model(pt_prefill_input, start_pos, mode="prefill") # Measure PCC if also running reference model if run_ref_pt: @@ -186,20 +244,58 @@ def test_llama_model_inference( pytorch_layer_present = [ reference_model.layers[i] .attention.cache_k.clone() - .permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] + .permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] reference_model.layers[i] .attention.cache_v.clone() - .permute(0, 2, 1, 3), # [batch, n_kv_heads, seq, head_dim] + .permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] ] tt_layer_present = [] - for layer_past in tt_model.layers[i].attention.layer_past_list[0]: - tt_layer_present.append( - ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) - ) + if paged_attention: + for layer_past in tt_model.layers[l].attention.layer_past: + tt_layer_present.append( + ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch_size, ... + ] + ) + tt_layer_present = [ + ( + ttnn.to_torch(cache, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[ + reverse_permutation + ] + .reshape( + model_args.max_batch_size, + paged_attention_config.max_num_blocks // model_args.max_batch_size, + model_args.n_kv_heads, + paged_attention_config.block_size, + model_args.head_dim, + ) + .transpose(1, 2) + .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ + :batch_size, ... + ] + ) + for cache in tt_model.layers[l].attention.layer_past + ] + else: + for layer_past in tt_model.layers[i].attention.layer_past_list[0]: + tt_layer_present.append( + ttnn.to_torch(layer_past, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + ) for i, (cache_pt, cache_tt) in enumerate(zip(pytorch_layer_present, tt_layer_present)): - cache_length_to_check = model_args.sliding_window + cache_length_to_check = model_args.max_seq_len cache_pt = cache_pt[:, :, 0:cache_length_to_check, :] cache_tt = cache_tt[:, :, 0:cache_length_to_check, :] does_pass, output_pcc = comp_pcc(cache_pt, cache_tt) @@ -217,7 +313,7 @@ def test_llama_model_inference( if run_ref_pt: if all_tests_pass: - logger.info(f"All Llama decode iterations Passed!") + logger.info(f"All Llama prefill iterations Passed!") else: - logger.warning("One or more iterations of Llama decode had bad PCC") + logger.warning("One or more iterations of Llama prefill had bad PCC") assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" 
diff --git a/models/demos/llama3/tests/test_llama_perf.py b/models/demos/llama3/tests/test_llama_perf.py deleted file mode 100644 index a873449fe2c..00000000000 --- a/models/demos/llama3/tests/test_llama_perf.py +++ /dev/null @@ -1,203 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 -import os -import torch -import pytest -import re -from loguru import logger -import os -import ttnn -from models.demos.llama3.tt.llama_common import ( - sample, - HostEmbedding, - get_single_rot_mat, -) -from models.demos.llama3.tt.llama_model import TtTransformer -from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding -from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer - -from models.perf.perf_utils import prep_perf_report -from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report -from models.utility_functions import profiler, skip_for_grayskull - -if not os.getenv("CI") == "true": # Enable tracy signpost support in local runs only - from tracy import signpost - - -@skip_for_grayskull("Requires eth connected devices to run") -@pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "kv_cache_len, expected_compile_time", - ( - (32, 30), - (128, 30), - (1024, 30), - ), -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_program_cache, reset_seeds, ensure_gc): - dtype = ttnn.bfloat8_b - - mesh_device.enable_async(True) - - model_args = TtModelArgs(mesh_device, optimizations=LlamaOptimizations.performance) - tokenizer = Tokenizer(model_args.tokenizer_path) - - if "3.2-1B" in model_args.DEFAULT_CACHE_PATH: - expected_inference_time = 0.045 - elif "3.2-3B" in model_args.DEFAULT_CACHE_PATH: - expected_inference_time = 0.065 - elif "3.1-8B" in model_args.DEFAULT_CACHE_PATH: - expected_inference_time = 0.08 - elif "3.2-11B" in model_args.DEFAULT_CACHE_PATH: - expected_inference_time = 0.085 - elif "3.1-70B" in model_args.DEFAULT_CACHE_PATH: - expected_inference_time = 0.15 - else: - assert False, f"Llama model not found. 
Supported Llama models: [3.2-1B, 3.2-3B, 3.1-8B, 3.2-11B, 3.1-70B]" - - # model_args.n_layers = 1 - # Clear global profiler state before starting measurements - profiler.clear() - - profiler.start("weight_loading") - state_dict = model_args.load_state_dict() - - profiler.end("weight_loading") - - prompts = ["This is a test"] * model_args.max_batch_size - encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts] - - # Embedding on host - embd = HostEmbedding(model_args) - state_dict_prefix = model_args.get_state_dict_prefix("", None) - embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]}) - - generation_start_pos = kv_cache_len - generation_length = 1 - - profiler.start("TtLlama_model_setup") - - # Load TTNN model - tt_model = TtTransformer( - args=model_args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - weight_cache_path=model_args.weight_cache_path(dtype), - ) - # Load TTNN embedding module - tt_embd = TtLlamaEmbedding( - mesh_device=mesh_device, - args=model_args, - weight_cache_path=model_args.weight_cache_path(dtype), - state_dict=state_dict, - dtype=ttnn.bfloat16, # Row major layout requires bfloat16 - ) - profiler.end("TtLlama_model_setup") - - # Call the function - profiler.start(f"end_to_end_inference_with_compile") - run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos, generation_length) - profiler.end(f"end_to_end_inference_with_compile") - profiler.print() - compile_and_iter_time = profiler.get("model_run_for_inference_0") - - ttnn.DumpDeviceProfiler(mesh_device.get_devices()[0]) - - if not os.getenv("CI") == "true": # Enable tracy signpost support in local runs only - signpost("Model perf run") - - profiler.start(f"end_to_end_inference") - run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos, generation_length) - profiler.end(f"end_to_end_inference") - profiler.print() - iter_time = profiler.get("end_to_end_inference") - - comment = f"kv_cache_len={kv_cache_len}_num_layers={model_args.n_layers}" - - # Extract the version, number of weights and device name from the cache folder - if "3.1" in model_args.DEFAULT_CACHE_PATH: - llama_version = "3.1" - else: - llama_version = "3.2" - llama_weight = re.search(r"(\d+)B", model_args.DEFAULT_CACHE_PATH).group(1) - llama_device = model_args.device_name - - prep_perf_report( - model_name=f"Llama{llama_version}_{llama_weight}B_{llama_device}_{comment}", - batch_size=model_args.max_batch_size, - inference_and_compile_time=compile_and_iter_time, - inference_time=iter_time, - expected_compile_time=expected_compile_time, - expected_inference_time=expected_inference_time, - comments=comment, - ) - - -def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos, generation_length): - seqlen = 1 # Generating one token per user at a time - batch = tt_model.args.max_batch_size - mesh_device = tt_model.mesh_device - - # pre-compute the rotational embedding matrix and send to device - current_rot_mat, rot_matrix = get_single_rot_mat( - tt_model.args.head_dim, - tt_model.mesh_device, - tt_model.args.num_devices, - start_pos=0, - ) - - # Select the first token from the prompts for initial decoding - encoded_prompts_tensor = torch.tensor(encoded_prompts) # [:,0] - - # Initialize tt_out_tok with the first token - tt_out_tok = ttnn.from_torch( - torch.nn.functional.pad( - encoded_prompts_tensor[:, 0].unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0 - ), - device=mesh_device, - 
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - dtype=ttnn.uint32, - ) - - # Send first input to device - current_pos = ttnn.from_torch( - torch.tensor([generation_start_pos] * batch), - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - dtype=ttnn.int32, - ) - - for i in range(generation_length): - # Run TT model - profiler.start(f"model_run_for_inference_{i}") - - decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) - decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) - tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat) - tt_out_rm = ttnn.untilize(tt_out, use_multicore=True) - ttnn.deallocate(tt_out) - tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok) - ttnn.deallocate(tt_out_rm) - - # Update the rotation matrix for the next iteration - new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) - current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat) - ttnn.plus_one(current_pos) - - profiler.end(f"model_run_for_inference_{i}") - - # Synchronize devices to ensure all profiling data is captured accurately - for i in range(tt_model.args.num_devices): - ttnn.synchronize_device(mesh_device.get_devices()[i]) diff --git a/models/demos/llama3/tests/test_llama_rms_norm.py b/models/demos/llama3/tests/test_llama_rms_norm.py index bf0ce828900..cca0f113b55 100644 --- a/models/demos/llama3/tests/test_llama_rms_norm.py +++ b/models/demos/llama3/tests/test_llama_rms_norm.py @@ -28,13 +28,30 @@ ], indirect=True, ) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) @pytest.mark.parametrize("mode", ["prefill", "decode"]) -def test_llama_rms_norm_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc, mode): +def test_llama_rms_norm_inference( + max_seq_len, + batch_size, + mode, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): dtype = ttnn.bfloat16 mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) + model_args.n_layers = 1 state_dict = model_args.load_state_dict() state_dict_prefix = model_args.get_state_dict_prefix("", 0) diff --git a/models/demos/llama3/tests/test_lm_head.py b/models/demos/llama3/tests/test_lm_head.py index a626910c729..4a5570f5cc0 100644 --- a/models/demos/llama3/tests/test_lm_head.py +++ b/models/demos/llama3/tests/test_lm_head.py @@ -23,6 +23,10 @@ "seq_len", (32,), ) +@pytest.mark.parametrize( + "batch_size", + (1,), +) @pytest.mark.parametrize( "mesh_device", [ @@ -32,12 +36,12 @@ ], indirect=True, ) -def test_llama_lm_head_inference(mesh_device, seq_len, use_program_cache, reset_seeds): +def test_llama_lm_head_inference(seq_len, batch_size, mesh_device, use_program_cache, reset_seeds): dtype = ttnn.bfloat8_b mesh_device.enable_async(True) - model_args = TtModelArgs(mesh_device) + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=seq_len) model_args.n_layers = 1 state_dict = model_args.load_state_dict() diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index d630e91a3bd..a925044554a 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -20,7 +20,9 @@ def __init__( weight_cache_path, layer_num, dtype, + transformation_mats, 
configuration, + paged_attention_config=None, ): super().__init__() @@ -34,7 +36,7 @@ def __init__( self.max_seq_len = configuration.max_seq_len self.max_batch_size = configuration.max_batch_size self.n_kv_heads = configuration.n_kv_heads - self.paged_attention_config = configuration.paged_attention_config + self.paged_attention_config = paged_attention_config self.min_kv_prefill_shard_seqlen = configuration.min_kv_prefill_shard_seqlen self.n_local_heads = self.n_heads // configuration.num_devices @@ -42,13 +44,14 @@ def __init__( self.dtype = dtype - self.kv_seq_len = configuration.kv_seq_len - self.sliding_window = configuration.sliding_window + self.max_seq_len = configuration.max_seq_len self.grid_size = configuration.max_grid_size self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + self.transformation_mats = transformation_mats + self.model_config = configuration.get_model_config() self.ccl_topology = configuration.ccl_topology() self.is_multichip = configuration.is_multichip @@ -113,7 +116,7 @@ def __init__( self.use_fused_all_gather_matmul = self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] if self.is_multichip and self.use_fused_all_gather_matmul: pt_wo = self.state_dict[wo_str].transpose(-1, -2).unsqueeze(0).unsqueeze(0) - wo_ttnn = ttnn.as_tensor( + self.wo = ttnn.as_tensor( pt_wo, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, @@ -122,7 +125,6 @@ def __init__( mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), cache_file_name=cache_name("wo_width_sharded"), ) - self.wo = ttnn.to_device(wo_ttnn, self.mesh_device) else: # For line topology we can't do all gather matmul for now, but we can height shard and reduce scatter # wo: 2048 (2devices) x 4096: width-sharded on 12 banks, 4224 over 12 banks. wo_mem_config = configuration.create_dram_sharded_mem_config( @@ -163,16 +165,16 @@ def __init__( cache_k = torch.zeros( ( self.max_batch_size, - self.n_kv_heads, - self.sliding_window, + self.n_kv_heads // configuration.num_devices, + self.max_seq_len, self.head_dim, ) ) cache_v = torch.zeros( ( self.max_batch_size, - self.n_kv_heads, - self.sliding_window, + self.n_kv_heads // configuration.num_devices, + self.max_seq_len, self.head_dim, ) ) @@ -180,14 +182,14 @@ def __init__( self.layer_past = [ ttnn.as_tensor( k_or_v, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), - layout=self.model_config["ATTN_W_LAYOUT_TILE"], dtype=self.dtype, + layout=self.model_config["ATTN_W_LAYOUT_TILE"], + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), cache_file_name=f"{weight_cache_path}/kvcache_{k_or_v.shape}" if weight_cache_path and not configuration.dummy_weights else None, - memory_config=ttnn.DRAM_MEMORY_CONFIG, ) for k_or_v in [cache_k, cache_v] ] @@ -198,14 +200,13 @@ def forward_decode( self, x: ttnn.Tensor, current_pos, - rot_mat=None, + rot_mats=None, page_table=None, ) -> ttnn.Tensor: """ x: (seq_len, 1, batch, dim) current_pos: (batch_size), current token position in the sequence for each user """ - assert self.max_batch_size * self.n_kv_heads < 64 ### # QKV matmuls # Use HiFi2 for DRAM-sharded matmuls as they are otherwise flop-bound. Loses 1 bit of activation precision. 
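# A minimal, hedged sketch (not part of the diff): the per-device KV-cache shape allocated in the
# __init__ hunk above. Each device now keeps n_kv_heads // num_devices heads for the full
# max_seq_len, and the host-side zero tensors are replicated to the mesh with ReplicateTensorToMesh.
# The concrete numbers below are illustrative assumptions only; real runs take them from TtModelArgs.
import torch

max_batch_size, n_kv_heads, num_devices = 1, 8, 8   # assumed values for illustration
max_seq_len, head_dim = 4096, 128                   # assumed; the demo supports up to 128k context

cache_shape = (max_batch_size, n_kv_heads // num_devices, max_seq_len, head_dim)
cache_k = torch.zeros(cache_shape)   # mirrors the cache_k allocation in the hunk above
cache_v = torch.zeros(cache_shape)   # mirrors the cache_v allocation in the hunk above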
@@ -245,26 +246,14 @@ def forward_decode( ttnn.deallocate(xqkv_fused) - q_heads_1BQD = ttnn.linear( - q_heads_pre_rot_1BQD, - rot_mat, - program_config=self.model_config["ROT_MAT_BMM_PROGCFG"]( - q_heads_pre_rot_1BQD.shape[-2], q_heads_pre_rot_1BQD.shape[-1], rot_mat.shape[-1] - ), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config_hifi2, - dtype=ttnn.bfloat16, + # Q Rotary Embeddings + q_heads_1BQD = ttnn.experimental.rotary_embedding_llama( + q_heads_pre_rot_1BQD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True ) - k_heads_1BKD = ttnn.linear( - k_heads_pre_rot_1BKD, - rot_mat, - program_config=self.model_config["ROT_MAT_BMM_PROGCFG"]( - k_heads_pre_rot_1BKD.shape[-2], k_heads_pre_rot_1BKD.shape[-1], rot_mat.shape[-1] - ), - memory_config=k_heads_pre_rot_1BKD.memory_config(), - compute_kernel_config=self.compute_kernel_config_hifi2, - dtype=ttnn.bfloat16, + # K Rotary Embeddings + k_heads_1BKD = ttnn.experimental.rotary_embedding_llama( + k_heads_pre_rot_1BKD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True ) ttnn.deallocate(q_heads_pre_rot_1BQD) @@ -275,20 +264,24 @@ def forward_decode( ### keys = self.layer_past[0] values = self.layer_past[1] - # k_heads, [seqlen, n_kv_heads, bsz, head_dim] # v_heads [seqlen, n_kv_heads, bsz, head_dim] - # keys, [max_batch_size, n_kv_heads // configuration.num_devices, sliding_window, head_dim] + # keys, [max_batch_size, n_kv_heads // configuration.num_devices, max_seq_len, head_dim] ttnn.experimental.paged_update_cache(keys, k_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table) ttnn.experimental.paged_update_cache( values, v_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table ) + self.layer_past[0] = keys self.layer_past[1] = values ttnn.deallocate(k_heads_1BKD) ttnn.deallocate(v_heads_1BKD) + # NOTE: Varying the batch size will result in slightly different outputs. 
+ # For example, a prompt w/ 1 user vs, the same prompt repeated N times for N users, will produce different outputs + # This is because the SDPA op in decode mode has different number of reductions depending on batch size + # Which leads to slightly different outputs from attention (due to accumulated errors) if page_table: attn_output_1G4D = ttnn.transformer.paged_scaled_dot_product_attention_decode( q_heads_1BQD, @@ -369,7 +362,7 @@ def forward_decode( dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"]) return dense_out_sharded - def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = 0, page_table=None): + def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None): seq_len = x_11SH.shape[-2] assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" ### @@ -414,12 +407,20 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = ### q_heads_1QSD = ttnn.experimental.rotary_embedding_llama( - q_heads_1QSD_pre_rot, rot_mats[0], rot_mats[1], transformation_mats + q_heads_1QSD_pre_rot, + rot_mats[0], + rot_mats[1], + self.transformation_mats["prefill"], + is_decode_mode=False, ) ttnn.deallocate(q_heads_1QSD_pre_rot) k_heads_1KSD = ttnn.experimental.rotary_embedding_llama( - k_heads_1KSD_pre_rot, rot_mats[0], rot_mats[1], transformation_mats + k_heads_1KSD_pre_rot, + rot_mats[0], + rot_mats[1], + self.transformation_mats["prefill"], + is_decode_mode=False, ) ttnn.deallocate(k_heads_1KSD_pre_rot) @@ -427,22 +428,28 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1] k_heads_1KSD_8b = ttnn.typecast(k_heads_1KSD, dtype=ttnn.bfloat8_b) + v_heads_1VSD_8b = ttnn.typecast(v_heads_1VSD, dtype=ttnn.bfloat8_b) + ttnn.deallocate(k_heads_1KSD) + # sharding k_fill to deal with update_cache memory limitation - if seq_len >= self.min_kv_prefill_shard_seqlen: + if ( + seq_len >= self.min_kv_prefill_shard_seqlen and not page_table + ): # ttnn.experimental.paged_fill_cache only supports interleaved inputs k_fill = ttnn.interleaved_to_sharded(k_heads_1KSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) else: k_fill = k_heads_1KSD_8b - v_heads_1VSD_8b = ttnn.typecast(v_heads_1VSD, dtype=ttnn.bfloat8_b) - - ttnn.deallocate(v_heads_1VSD) # sharding v_fill to deal with update_cache memory limitation - if seq_len >= self.min_kv_prefill_shard_seqlen: + if ( + seq_len >= self.min_kv_prefill_shard_seqlen and not page_table + ): # ttnn.experimental.paged_fill_cache only supports interleaved inputs v_fill = ttnn.interleaved_to_sharded(v_heads_1VSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) else: v_fill = v_heads_1VSD_8b + ttnn.deallocate(v_heads_1VSD) + if page_table: ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill, page_table, batch_idx=user_id) ttnn.experimental.paged_fill_cache(values_BKSD, v_fill, page_table, batch_idx=user_id) @@ -458,7 +465,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = user_id, ) - if seq_len >= self.min_kv_prefill_shard_seqlen: + if seq_len >= self.min_kv_prefill_shard_seqlen and not page_table: ttnn.deallocate(k_fill) ttnn.deallocate(v_fill) @@ -483,6 +490,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = v_heads_V1SD_8b, is_causal=True, scale=self.scale, + compute_kernel_config=self.compute_kernel_config_hifi4, program_config=self.model_config["SDPA_PROGCFG"](seq_len), 
) @@ -542,10 +550,8 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = else: return output_11SH - def forward( - self, x, current_pos, rot_mats=None, transformation_mats=None, user_id=0, mode="decode", page_table=None - ): + def forward(self, x, current_pos, rot_mats=None, user_id=0, mode="decode", page_table=None): if mode == "prefill": - return self.forward_prefill(x, rot_mats, transformation_mats, user_id, page_table) + return self.forward_prefill(x, rot_mats, user_id, page_table) else: return self.forward_decode(x, current_pos, rot_mats, page_table) diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index 6368443df4f..b9b5484cb89 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -16,6 +16,13 @@ def forward(self, x): return self.emb(x) +# Default configuration for Paged Attention +class PagedAttentionConfig: + def __init__(self, block_size=32, max_num_blocks=1024): + self.block_size = block_size + self.max_num_blocks = max_num_blocks + + def encode_prompt_llama_instruct(tokenizer, prompt_text, system_prompt_text=None): """<|begin_of_text|><|start_header_id|>system<|end_header_id|> {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|> @@ -97,49 +104,6 @@ def freqs_to_rotation_matrix(cos_freqs, sin_freqs): return rot_emb_matrix -def gather_rotary_emb(rot_emb_matrix, position_ids): - """ - Gather the rotary embeddings for a given position_ids - """ - batch_size, seqlen = position_ids.shape - emb_size, _, dhead = rot_emb_matrix.shape - position_ids = position_ids.view(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, dhead, dhead) - rot_emb = rot_emb_matrix.gather(0, position_ids).view(batch_size, seqlen, dhead, dhead) - return rot_emb - - -def get_rotation_mat_batched(rot_mat, start_pos, seqlen, batch): - if isinstance(start_pos, int): - start_pos = torch.ones(seqlen, batch, dtype=torch.long) * start_pos - position_ids = start_pos.view(seqlen, batch) - rot_emb = gather_rotary_emb(rot_mat, position_ids) - return rot_emb - - -# Sample logits from a distribution -def sample_top_p(probs: torch.Tensor, p: float): - assert 0 <= p <= 1 - - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - - next_token = torch.multinomial(probs_sort, num_samples=1) - return torch.gather(probs_idx, -1, next_token) - - -def sample(logits: torch.Tensor, temperature: float, top_p: float): - if temperature > 0: - probs = torch.softmax(logits / temperature, dim=-1) - next_token = sample_top_p(probs.squeeze(), top_p) - else: - next_token = torch.argmax(logits, dim=-1) - - return next_token - - def gather_cos_sin(position_ids, cos, sin): position_id_expanded = position_ids.unsqueeze(1).expand(-1, cos.shape[-1]) cos = cos.gather(0, position_id_expanded) @@ -273,3 +237,61 @@ def first_five(tensor, mesh_device): Helper function to return the first 5 elements of a tensor via torch """ return torch.Tensor(ttnn.to_torch(tensor, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)))[0, 0, 0, :5] + + +# Sample logits from a distribution +def sample_top_p(probs: torch.Tensor, p: float): + assert 0 <= p <= 1 + + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + 
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + + next_token = torch.multinomial(probs_sort, num_samples=1) + return torch.gather(probs_idx, -1, next_token) + + +def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True): + vocab_size = tt_input.shape[-1] + if mesh_device: + pt_input = ttnn.to_torch(tt_input, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[..., :vocab_size] + else: # input already on host + pt_input = tt_input[..., :vocab_size] + + if temperature > 0: + probs = torch.softmax(pt_input / temperature, dim=-1) + pt_out = sample_top_p(probs.squeeze(), top_p) + if mesh_device: + pt_out = pt_out.view(1, 1, 1, -1) + else: + if mesh_device: + pt_out = torch.argmax(pt_input, dim=-1, keepdim=True).transpose(-1, -2) + else: + pt_out = torch.argmax(pt_input, dim=-1) + + if mesh_device is None: + return pt_out + if on_host: + return ( + ttnn.as_tensor( + pt_out, + layout=ttnn.ROW_MAJOR_LAYOUT, + dtype=ttnn.uint32, + device=None, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if mesh_device.get_num_devices() > 1 else None, + ), + pt_out, + ) + else: + return ( + ttnn.from_torch( + pt_out, + layout=ttnn.ROW_MAJOR_LAYOUT, + dtype=ttnn.uint32, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ), + pt_out, + ) diff --git a/models/demos/llama3/tt/llama_decoder.py b/models/demos/llama3/tt/llama_decoder.py index 578e0bf81a6..e5edfce889a 100644 --- a/models/demos/llama3/tt/llama_decoder.py +++ b/models/demos/llama3/tt/llama_decoder.py @@ -10,7 +10,17 @@ class TtTransformerBlock(LightweightModule): - def __init__(self, args, mesh_device, dtype, state_dict, layer_num, weight_cache_path): + def __init__( + self, + args, + mesh_device, + dtype, + state_dict, + layer_num, + weight_cache_path, + transformation_mats, + paged_attention_config=None, + ): super().__init__() self.state_dict = state_dict @@ -25,7 +35,6 @@ def __init__(self, args, mesh_device, dtype, state_dict, layer_num, weight_cache self.max_batch_size = args.max_batch_size self.n_kv_heads = args.n_kv_heads self.current = 0 - self.sliding_window = args.sliding_window self.model_config = args.get_model_config() self.layer_num = layer_num @@ -36,7 +45,9 @@ def __init__(self, args, mesh_device, dtype, state_dict, layer_num, weight_cache weight_cache_path=weight_cache_path, layer_num=layer_num, dtype=dtype, + transformation_mats=transformation_mats, configuration=args, + paged_attention_config=paged_attention_config, ) self.feed_forward = TtLlamaMLP( mesh_device=mesh_device, @@ -82,8 +93,7 @@ def forward( self, x: ttnn.Tensor, current_pos, - rot_mat=None, - transformation_mats=None, + rot_mats=None, user_id=0, mode="decode", page_table=None, @@ -99,8 +109,7 @@ def forward( attn_out = self.attention.forward( attn_in, current_pos, - rot_mat, - transformation_mats, + rot_mats, user_id, mode, page_table, diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index 04cf2c8d77b..e04ed2c4cf8 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -24,6 +24,8 @@ def __init__( mesh_device, state_dict, weight_cache_path, + transformation_mats, + paged_attention_config=None, ): super().__init__() self.args = args @@ -44,6 +46,8 @@ def __init__( state_dict=state_dict, weight_cache_path=weight_cache_path, layer_num=i, + transformation_mats=transformation_mats, + paged_attention_config=paged_attention_config, ) for i in range(self.n_layers) ] @@ -76,8 +80,7 @@ def forward( self, x: ttnn.Tensor, 
current_pos, - rot_mat=None, - transformation_mats=None, + rot_mats=None, user_id=0, mode="decode", page_table=None, @@ -88,7 +91,7 @@ def forward( x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"]) for layer in self.layers: - x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table) + x = layer(x, current_pos, rot_mats, user_id, mode, page_table) if mode == "prefill" and get_last_token == -1: return x diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py new file mode 100644 index 00000000000..576ce982e8c --- /dev/null +++ b/models/demos/llama3/tt/llama_rope.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from models.common.lightweightmodule import LightweightModule +from models.demos.llama3.tt.llama_common import precompute_freqs, get_rot_transformation_mat, gather_cos_sin +from models.utility_functions import nearest_32 +from loguru import logger + + +def compute_gather_cos_sin(dhead, end, theta, position_ids, use_scaled_rope): + cos, sin = precompute_freqs(dhead, end, theta, use_scaled_rope) + return gather_cos_sin(position_ids, cos, sin) + + +class TtLlamaRotarySetup(LightweightModule): + def __init__( + self, + device, + batch_size: int, + head_dim: int, + max_seq_len: int, + rope_theta: float = 10000, + use_scaled_rope: bool = False, + datatype=ttnn.bfloat16, + ): + super().__init__() + + self.batch_size = batch_size + self.head_dim = head_dim + self.device = device + self.is_mesh_device = isinstance(device, ttnn._ttnn.multi_device.MeshDevice) + self.num_devices = device.get_num_devices() if self.is_mesh_device else 1 + + self.core_grid = device.compute_with_storage_grid_size() + num_cores = self.core_grid.x * self.core_grid.y + + # Generate the cos/sin matrices needed for ttnn.embedding op + cos_matrix, sin_matrix = compute_gather_cos_sin( + dhead=head_dim, + end=max_seq_len * 2, + theta=rope_theta, + position_ids=torch.arange(max_seq_len), + use_scaled_rope=use_scaled_rope, + ) + + self.cos_matrix = ttnn.from_torch( + cos_matrix, + device=device, + layout=ttnn.ROW_MAJOR_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + self.sin_matrix = ttnn.from_torch( + sin_matrix, + device=device, + layout=ttnn.ROW_MAJOR_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + + batch_grid = ttnn.num_cores_to_corerangeset(batch_size, self.core_grid, row_wise=True) + # Generate the transformation matrix + trans_mat = get_rot_transformation_mat(dhead=ttnn.TILE_SIZE).repeat( + 1, + 1, + batch_size, + 1 + # 1, 1, num_cores, 1 + ) # Repeat across all cores on device + trans_mat_mem_config = ttnn.create_sharded_memory_config( + shape=(ttnn.TILE_SIZE, ttnn.TILE_SIZE), + core_grid=batch_grid, + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + self.transformation_mat = ttnn.from_torch( + trans_mat, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + memory_config=trans_mat_mem_config, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + + def get_trans_mats(self): + assert self.transformation_mat is not None, "Transformation matrix not initialized" + return self.transformation_mat + + def get_rot_idxs(self, position_idxs, 
on_host=False): + assert isinstance(position_idxs, torch.Tensor), "Position ids must be a torch tensor" + assert len(position_idxs.shape) == 1, "position idxs must be a [batch] tensor" + + batch = position_idxs.shape[0] + position_idxs = position_idxs.reshape(1, batch) # [1, 1, 1, batch] + assert position_idxs.shape == (1, batch), "position idxs must be a [1, batch] tensor" + assert torch.min(position_idxs) >= 0, "position idxs must be non-negative" + + # Add padding if needed + pad_size = nearest_32(batch) - batch + position_idxs = torch.nn.functional.pad(position_idxs, (0, pad_size), "constant", 0) + + if on_host: # If tensor is on host, don't pass a mesh mapper if single-device + rot_idxs = ttnn.as_tensor( + position_idxs, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ReplicateTensorToMesh(self.device) if self.num_devices > 1 else None, + ) + else: # On device + rot_idxs = ttnn.as_tensor( + position_idxs, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + device=self.device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None, + ) + + return rot_idxs + + def get_rot_mats(self, position_idxs, return_rot_idxs=False): + device = self.device + + # If position_idxs is a torch tensor, get the TTNN version of it + if isinstance(position_idxs, torch.Tensor): + rot_idxs = self.get_rot_idxs(position_idxs) + else: + rot_idxs = position_idxs + assert len(rot_idxs.shape) == 2 and rot_idxs.shape[0] == 1, "rot_idxs must be a [1, batch] tensor" + + # Send the idxs to device + if rot_idxs.device != device: + rot_idxs = ttnn.to_device(rot_idxs, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + embedding_layout = ttnn.TILE_LAYOUT + cos = ttnn.embedding(rot_idxs, self.cos_matrix, layout=embedding_layout) # [1, batch, head_dim] + sin = ttnn.embedding(rot_idxs, self.sin_matrix, layout=embedding_layout) # [1, batch, head_dim] + + cos = ttnn.unsqueeze_to_4D(cos) # [1, 1, batch, head_dim] + sin = ttnn.unsqueeze_to_4D(sin) # [1, 1, batch, head_dim] + + cos = ttnn.transpose(cos, 1, 2) # [1, batch, 1[32], head_dim] + sin = ttnn.transpose(sin, 1, 2) # [1, batch, 1[32], head_dim] + + if self.batch_size % ttnn.TILE_SIZE != 0: + cos = cos[:, : self.batch_size, :, :] + sin = sin[:, : self.batch_size, :, :] + + grid = ttnn.num_cores_to_corerangeset(self.batch_size, self.core_grid, row_wise=True) + mem_config = ttnn.create_sharded_memory_config( + shape=(ttnn.TILE_SIZE, self.head_dim), + core_grid=grid, + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + cos = ttnn.interleaved_to_sharded(cos, mem_config) # [1, 1 (= batch / shard_num_cores), 1[32], self.head_dim] + sin = ttnn.interleaved_to_sharded(sin, mem_config) # [1, 1 (= batch / shard_num_cores), 1[32], self.head_dim] + + if return_rot_idxs: + return [cos, sin], rot_idxs + return [cos, sin] diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 407eda7668a..c3f9f385e7a 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -48,17 +48,6 @@ def performance(cls, model_name): class TtModelArgs: - paged_attention_config = None - - # TODO Update these params. 
In init we update the max_seq_len to 32k if it's a single device - max_batch_size = 1 - # Context length for Llama models (if single device, reduce to 32k in init) - max_seq_len = 8192 * 16 # 128k - kv_seq_len = 8192 * 16 # 128k - sliding_window = 8192 * 16 # 128k - - tile_size = 32 - OP_KEYS = ( # Embedding "EMB_WEIGHTS", @@ -98,12 +87,16 @@ def __init__( instruct=False, dummy_weights=False, max_batch_size=1, + max_seq_len=1024 * 128, optimizations=LlamaOptimizations.accuracy, ): self.num_devices = mesh_device.get_num_devices() if mesh_device else 0 self.mesh_device = mesh_device self.device_name = {0: "CPU", 1: "N150", 2: "N300", 8: "T3K", 32: "TG"}[self.num_devices] self.model_name = "Unknown" # Llama model name will be dependent on the checkpoint directory + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + self.tile_size = 32 LLAMA_DIR = os.getenv("LLAMA_DIR") if LLAMA_DIR: @@ -169,24 +162,6 @@ def __init__( else: # With Dummy weights, set the params from the local copy inside the model folder. This is required for CI pipeline that doesn't mount the external folders. self._set_llama_params(self.LOCAL_LLAMA_PARAMS[local_params]) - # Reduce full 128k context length for combinations with memory constraints - # Currently: n150 8b and t3k 70b with 8b/8b/8b MLPs - # Default folder location for weights and cached files - # FIXME: Setup the max cache size accordingly depending on the target model, architecture and test type. - if ( - self.num_devices <= 2 - ): # for 1-chip or 2-chip devices limit the seqlen to 4K (to avoid OoO on N150/N300 CI tests) - self.max_seq_len = 1024 * 4 - self.kv_seq_len = 1024 * 4 - self.sliding_window = 1024 * 4 - - if ( - self.n_layers == 1 - ): # When running a single layer just reduce the seq len to 128, since we won't be decoding that many iterations - self.max_seq_len = 128 - self.kv_seq_len = 128 - self.sliding_window = 128 - # Some consumers like SentencePiece only accept str not Path for files self.model_base_path = Path(self.DEFAULT_CKPT_DIR) self.model_cache_path = Path(self.DEFAULT_CACHE_PATH) @@ -200,7 +175,6 @@ def __init__( if "instruct" in self.DEFAULT_CACHE_PATH.lower(): self.instruct = True self.dummy_weights = dummy_weights - self.max_batch_size = max_batch_size self.tile_padded_batch_rows = self.tile_size * int(math.ceil(self.max_batch_size / self.tile_size)) # Enable workarounds by default until di/dt issues are fixed @@ -283,6 +257,7 @@ def __init__( # Chunk values based on what works best empirically self.model_config["SDPA_PROGCFG"] = lambda seqlen: ttnn.SDPAProgramConfig( compute_with_storage_grid_size=(8, 8), + exp_approx_mode=False, q_chunk_size=256 if seqlen >= 2048 else 64, k_chunk_size=256 if seqlen >= 2048 else 64, ) @@ -406,6 +381,7 @@ def find_largest_divisor(n, max_divisor=8): self.model_config["SDPA_DECODE_PROGCFG"] = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=(8, 8), + exp_approx_mode=False, q_chunk_size=32, k_chunk_size=32, ) @@ -449,14 +425,6 @@ def find_largest_divisor(n, max_divisor=8): orientation=ttnn.ShardOrientation.ROW_MAJOR, use_height_and_width_as_shard_shape=True, ) - self.model_config["ROT_MAT_BMM_PROGCFG"] = lambda m, k, n: ttnn.MatmulMultiCoreReuseProgramConfig( - compute_with_storage_grid_size=grid_by_batch, - in0_block_w=math.ceil(k / 32), - out_subblock_h=1, - out_subblock_w=1, # TODO How to choose this subblock size? 
- per_core_M=math.ceil(m / 32), - per_core_N=math.ceil(n / 32), - ) self.model_config["ROT_MAT_MEMCONFIG"] = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index cfeb37aabc5..7956d1c7b03 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -58,21 +58,6 @@ run_perf_models_llm_javelin() { env QWEN_DIR=/mnt/MLPerf/tt_dnn-models/qwen/Qwen2-7B-Instruct FAKE_DEVICE=N150 pytest -n auto models/demos/qwen/tests -m $test_marker - # Llama3.1-8B - llama8b=/mnt/MLPerf/tt_dnn-models/llama/Meta-Llama-3.1-8B-Instruct/ - # Llama3.2-1B - llama1b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-1B-Instruct/ - # Llama3.2-3B - llama3b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-3B-Instruct/ - # Llama3.2-11B (#Skip: Weights too big for single-chip ci VM) - llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ - - # Run all Llama3 tests for 8B, 1B, and 3B weights - for llama_dir in "$llama8b" "$llama1b" "$llama3b"; do - LLAMA_DIR=$llama_dir pytest -n auto models/demos/llama3/tests/test_llama_perf.py -m $test_marker - echo "LOG_METAL: Llama3 tests for $llama_dir completed" - done - if [ "$tt_arch" == "wormhole_b0" ]; then env pytest -n auto models/demos/wormhole/mamba/tests -m $test_marker fi diff --git a/tests/scripts/t3000/run_t3000_model_perf_tests.sh b/tests/scripts/t3000/run_t3000_model_perf_tests.sh index 02ec0d8c541..eff50354e04 100755 --- a/tests/scripts/t3000/run_t3000_model_perf_tests.sh +++ b/tests/scripts/t3000/run_t3000_model_perf_tests.sh @@ -56,56 +56,6 @@ run_t3000_llama2_70b_tests() { fi } -run_t3000_llama3_70b_tests() { - # Record the start time - fail=0 - start_time=$(date +%s) - - echo "LOG_METAL: Running run_t3000_llama3_70b_tests" - - LLAMA_DIR=/mnt/MLPerf/tt_dnn-models/llama/Llama3.1-70B-Instruct/ WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/llama3/tests/test_llama_perf.py ; fail+=$? - - # Record the end time - end_time=$(date +%s) - duration=$((end_time - start_time)) - echo "LOG_METAL: run_t3000_llama3_70b_tests $duration seconds to complete" - if [[ $fail -ne 0 ]]; then - exit 1 - fi -} - -run_t3000_llama3_tests() { - # Record the start time - fail=0 - start_time=$(date +%s) - - echo "LOG_METAL: Running run_t3000_llama3_tests" - - wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml - # Llama3.1-8B - llama8b=/mnt/MLPerf/tt_dnn-models/llama/Meta-Llama-3.1-8B-Instruct/ - # Llama3.2-1B - llama1b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-1B-Instruct/ - # Llama3.2-3B - llama3b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-3B-Instruct/ - # Llama3.2-11B - llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ - - # Run all Llama3 tests for 8B, 1B, and 3B weights - for llama_dir in "$llama1b" "$llama3b" "$llama8b" "$llama11b"; do - LLAMA_DIR=$llama_dir WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/test_llama_perf.py ; fail+=$? 
- echo "LOG_METAL: Llama3 tests for $llama_dir completed" - done - - # Record the end time - end_time=$(date +%s) - duration=$((end_time - start_time)) - echo "LOG_METAL: run_t3000_llama3_tests $duration seconds to complete" - if [[ $fail -ne 0 ]]; then - exit 1 - fi -} - run_t3000_falcon40b_tests() { # Record the start time fail=0 @@ -187,15 +137,9 @@ run_t3000_llm_tests() { # Run mixtral tests run_t3000_mixtral_tests - # Run llama3-small (1B, 3B, 8B, 11B) tests - run_t3000_llama3_tests - # Run llama2-70b tests run_t3000_llama2_70b_tests - # Run llama3-70b tests - run_t3000_llama3_70b_tests - # Run falcon40b tests run_t3000_falcon40b_tests diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py index c9958604dad..e3c172ebb8c 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py @@ -11,21 +11,16 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_pcc, ) -from models.utility_functions import skip_for_grayskull, skip_for_blackhole, nearest_32 -from models.demos.t3000.llama2_70b.tt.llama_common import precompute_freqs, freqs_to_rotation_matrix, gather_rotary_emb -from models.demos.t3000.llama2_70b.tt.llama_rope import TtLlamaRotarySetup +from models.utility_functions import skip_for_grayskull, skip_for_blackhole, nearest_32, skip_for_wormhole_b0 +from models.demos.llama3.tt.llama_common import ( + precompute_freqs, + get_rot_transformation_mat, +) +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup MAX_SEQ_LEN = 128 * 1024 -def get_rotation_mat(dhead, end, start_pos, seqlen, batch): - cos, sin = precompute_freqs(dhead, end) - rot_mat = freqs_to_rotation_matrix(cos, sin) - position_ids = torch.ones(seqlen, batch, dtype=torch.long) * start_pos - rot_emb = gather_rotary_emb(rot_mat, position_ids) - return rot_emb - - class TtLlamaRotary(torch.nn.Module): def __init__( self, @@ -110,15 +105,8 @@ def forward(self, xq, xk, freqs_cis): return xq, xk -def get_rot_transformation_mat(dhead): - rot_emb_matrix = torch.zeros(1, 1, dhead, dhead) - rot_emb_matrix[..., torch.arange(0, dhead, 2), torch.arange(1, dhead, 2)] = 1 - rot_emb_matrix[..., torch.arange(1, dhead, 2), torch.arange(0, dhead, 2)] = -1 - return rot_emb_matrix - - def compute_gather_cos_sin(dhead, end, position_ids): - cos, sin = precompute_freqs(dhead, end) + cos, sin = precompute_freqs(dhead, end, theta=10000.0, use_scaled=False) # Using reference defaults position_id_expanded = position_ids.unsqueeze(1).expand(-1, cos.shape[-1]) cos = cos.gather(0, position_id_expanded) sin = sin.gather(0, position_id_expanded) @@ -185,17 +173,16 @@ def run_test_rotary_embedding_llama( tt_model = TtLlamaRotary(device, head_dim, mode, datatype, fuse_qk) if mode == "decode": - rope_setup_decode = TtLlamaRotarySetup(device, head_dim, max_seq_len) - tt_model.transformation_mat = rope_setup_decode.transformation_mat - # For decode, TTNN expects inputs to be [1, batch, nh, dhead] inp = [x.transpose(1, 2) for x in inp] # inp: [seq_len, batch, n_heads, head_dim] if fuse_qk: - # For fused_qk, repeat the position_ids for q and k - position_ids = torch.concat([position_ids, position_ids]) - cos, sin = rope_setup_decode.get_rot_mats(position_ids) + # Set up rope with 2 * batch size (for fused qk) + rope_setup_decode = TtLlamaRotarySetup(device, batch * 2, head_dim, 
max_seq_len) + tt_model.transformation_mat = rope_setup_decode.transformation_mat + cos, sin = rope_setup_decode.get_rot_mats(position_ids.repeat(2)) + assert ( batch % 8 == 0 or batch == 1 ), "Batch size must be a multiple of 8 or less than 8 for fused_qk rotary embedding" @@ -230,18 +217,19 @@ def run_test_rotary_embedding_llama( input_mem_configs = [q_input_mem_config, k_input_mem_config] else: + # Set up rope with batch size + rope_setup_decode = TtLlamaRotarySetup(device, batch, head_dim, max_seq_len) + tt_model.transformation_mat = rope_setup_decode.transformation_mat cos, sin = rope_setup_decode.get_rot_mats(position_ids) - grid = ( - ttnn.num_cores_to_corerangeset(batch, rope_setup_decode.core_grid, row_wise=True) - .bounding_box() - .grid_size() - ) + + grid = ttnn.num_cores_to_corerangeset(batch, rope_setup_decode.core_grid, row_wise=True) input_mem_configs = [ ttnn.create_sharded_memory_config( - shape=(1, batch, ttnn.TILE_SIZE, head_dim), - core_grid=ttnn.CoreGrid(y=grid.y, x=grid.x), + shape=(ttnn.TILE_SIZE, head_dim), + core_grid=grid, strategy=ttnn.ShardStrategy.HEIGHT, orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, ) for _ in range(len(inp)) ] @@ -313,7 +301,7 @@ def run_test_rotary_embedding_llama( (1, 128 * 1024), (64, 1), (32, 1), - (16, 1), + (15, 1), (8, 1), (1, 1), ), @@ -330,7 +318,7 @@ def run_test_rotary_embedding_llama( "prefill_128k", "decode_64", "decode_32", - "decode_16", + "decode_15", "decode_8", "decode_1", ), @@ -459,12 +447,9 @@ def test_rotary_embedding_llama_with_program_cache( num_ops = 2 # 2 * rope if mode == "decode": - num_ops += 4 # embedding + transpose + pad + interleaved_to_sharded + num_ops += 3 # embedding + transpose + interleaved_to_sharded - # When batch size is 1, transpose is a no-op - if batch == 1: - num_ops -= 1 - elif batch % 32 == 0: - num_ops -= 1 # When batch size is a multiple of 32, no padding + if batch % ttnn.TILE_SIZE != 0: + num_ops += 1 # slice assert device.num_program_cache_entries() == num_ops diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py index 579791f0eab..893fe74baa5 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py @@ -132,9 +132,9 @@ def test_rotary_embedding_llama_fused_qk_with_program_cache( cache_tensors.append(test_tensor) - if batch == 32 or batch == 16: - num_ops = 4 - else: - num_ops = 5 # embedding + fused_qk_rope + transpose + pad + interleaved_to_sharded + num_ops = 4 # embedding + fused_qk_rope + transpose + interleaved_to_sharded + + if (batch * 2) % ttnn.TILE_SIZE != 0: + num_ops += 1 # slice assert device.num_program_cache_entries() == num_ops
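# A hedged usage sketch (not part of the patch): how the new TtLlamaRotarySetup from
# models/demos/llama3/tt/llama_rope.py is meant to drive the decode-mode rotary embedding call
# used by llama_attention.py above. Device creation, the q_heads_pre_rot tensor (and its expected
# sharded memory config), the exact wiring of transformation_mats["decode"] inside the model, and
# the concrete batch/head_dim/max_seq_len values are assumptions for illustration only.
import torch
import ttnn
from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup


def decode_rope_sketch(mesh_device, q_heads_pre_rot, batch_size=1, head_dim=128, max_seq_len=4096):
    # One-time setup: precomputes the cos/sin matrices and the per-user sharded transformation matrix.
    rope_setup = TtLlamaRotarySetup(mesh_device, batch_size, head_dim, max_seq_len)

    # Per decode step: convert the current position of each user into cos/sin lookups.
    position_idxs = torch.zeros(batch_size, dtype=torch.long)  # all users at position 0 (assumption)
    cos_sin = rope_setup.get_rot_mats(position_idxs)           # [cos, sin], height-sharded per user

    # The model presumably forwards rope_setup.get_trans_mats() to attention as
    # transformation_mats["decode"] (the wiring is not shown in this section); here it is used
    # directly with the same op the decode path now calls instead of the old rot_mat matmul.
    trans_mat = rope_setup.get_trans_mats()
    return ttnn.experimental.rotary_embedding_llama(
        q_heads_pre_rot, cos_sin[0], cos_sin[1], trans_mat, is_decode_mode=True
    )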