diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py index a828830d330..09cff31e7c6 100644 --- a/models/demos/llama3/demo/demo.py +++ b/models/demos/llama3/demo/demo.py @@ -26,7 +26,6 @@ ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding -from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.demos.llama3.tt.model_config import TtModelArgs @@ -265,28 +264,6 @@ def run_llama3_demo( state_dict = model_args.load_state_dict() profiler.end("weight_loading") - # Setup RoPE transformation matrices - rope_setup = TtLlamaRotarySetup( - mesh_device, - batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.use_scaled_rope, - ) - transformation_mats_decode = rope_setup.get_trans_mats() - - transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats_prefill = ttnn.from_torch( - transformation_mats_prefill_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) - transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill} - page_table_tt = None if paged_attention: @@ -314,7 +291,6 @@ def run_llama3_demo( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, paged_attention_config=paged_attention_config, ) tt_embd = TtLlamaEmbedding( @@ -476,7 +452,7 @@ def run_llama3_demo( ) # Get cos/sin matrices for the current position of each user - rot_mats, rot_mat_idxs = rope_setup.get_rot_mats(current_pos, return_rot_idxs=True) + rot_mats, rot_mat_idxs = tt_model.rope_setup.get_rot_mats(current_pos, return_rot_idxs=True) # Compile logger.info(f"Compiling model trace...") decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) @@ -519,7 +495,7 @@ def run_llama3_demo( decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) - rot_mats = rope_setup.get_rot_mats(rot_mat_idxs) + rot_mats = tt_model.rope_setup.get_rot_mats(rot_mat_idxs) tt_out = tt_model( decode_input, current_pos_tensor, @@ -562,7 +538,7 @@ def run_llama3_demo( # Reset the current position and output token tensors for the real decode run ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos_tensor) ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok) - rot_mat_idxs_reset = rope_setup.get_rot_idxs(current_pos, on_host=True) + rot_mat_idxs_reset = tt_model.rope_setup.get_rot_idxs(current_pos, on_host=True) ttnn.copy_host_to_device_tensor(rot_mat_idxs_reset, rot_mat_idxs) profiler.end(f"capture_trace_{batch_idx}") @@ -591,7 +567,7 @@ def run_llama3_demo( # TODO This is required for now since we cannot ttnn.plus_one(rot_mat_idxs) while it being uint32. # If this tensor is int32, it won't be supported by ttnn.embedding current_pos += 1 - rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos, on_host=True) + rot_mat_idxs_updated = tt_model.rope_setup.get_rot_idxs(current_pos, on_host=True) ttnn.copy_host_to_device_tensor(rot_mat_idxs_updated, rot_mat_idxs) # Write to host diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py index 2ae973a907d..0afad6f1754 100644 --- a/models/demos/llama3/tests/test_llama_accuracy.py +++ b/models/demos/llama3/tests/test_llama_accuracy.py @@ -9,13 +9,11 @@ import ttnn from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, - get_rot_transformation_mat, HostEmbedding, PagedAttentionConfig, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations -from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.demos.llama3.demo.demo import preprocess_inputs_prefill from pathlib import Path @@ -141,28 +139,6 @@ def test_tt_model_accuracy( N = prefill_len + decode_len input_ids = reference_tokens[:, : N + 1] # Shape [1, N+1] - # Setup RoPE transformation matrices - rope_setup = TtLlamaRotarySetup( - mesh_device, - model_args.max_batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.use_scaled_rope, - ) - transformation_mats_decode = rope_setup.get_trans_mats() - - transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats_prefill = ttnn.from_torch( - transformation_mats_prefill_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) - transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill} - page_table_tt = None paged_attention_config = None @@ -193,7 +169,6 @@ def test_tt_model_accuracy( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, paged_attention_config=paged_attention_config, ) # Initialize embedding @@ -256,7 +231,7 @@ def test_tt_model_accuracy( ) # Get cos/sin matrices for the current position of each user - rot_mats = rope_setup.get_rot_mats(current_pos) + rot_mats = tt_model.rope_setup.get_rot_mats(current_pos) # Print table header logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Actual':<15}{'Top 5 Predictions':<75}") @@ -309,7 +284,7 @@ def test_tt_model_accuracy( # Update rot_mats for next iteration current_pos += 1 - rot_mats = rope_setup.get_rot_mats(current_pos) + rot_mats = tt_model.rope_setup.get_rot_mats(current_pos) # Get reference top5 tokens and probabilities for this position ref_top5_tokens = top5_tokens[prefill_len + i] diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py index cd425579a23..26d71bddd8a 100644 --- a/models/demos/llama3/tests/test_llama_model.py +++ b/models/demos/llama3/tests/test_llama_model.py @@ -15,7 +15,6 @@ ) from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations from models.demos.llama3.tt.llama_model import TtTransformer -from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.utility_functions import ( @@ -191,18 +190,6 @@ def test_llama_model_inference( generation_start_pos = 0 generation_length = iterations - # Setup RoPE transformation matrices - rope_setup = TtLlamaRotarySetup( - mesh_device, - model_args.max_batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.use_scaled_rope, - ) - transformation_mats = rope_setup.get_trans_mats() - transformation_mats = {"decode": transformation_mats} - page_table_tt = None paged_attention_config = None @@ -234,7 +221,6 @@ def test_llama_model_inference( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, paged_attention_config=paged_attention_config, ) logger.info("Model and caches loaded.") @@ -275,7 +261,7 @@ def test_llama_model_inference( ) # Get cos/sin matrices for the current position of each user - rot_mats = rope_setup.get_rot_mats(current_pos) + rot_mats = tt_model.rope_setup.get_rot_mats(current_pos) # Run TT model tt_out = tt_model( diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py index 934c91d5746..d73ef4f6691 100644 --- a/models/demos/llama3/tests/test_llama_model_prefill.py +++ b/models/demos/llama3/tests/test_llama_model_prefill.py @@ -93,7 +93,7 @@ def test_llama_model_inference( pcc = 0.91 # TODO Look on improving PCC else: # performance mode assert optimizations == LlamaOptimizations.performance - pcc = 0.87 # TODO Look on improving PCC + pcc = 0.869 # TODO Look on improving PCC mesh_device.enable_async(True) @@ -143,17 +143,6 @@ def test_llama_model_inference( # pre-compute the rotational embedding matrix and send to device rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len) - transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats_prefill = ttnn.as_tensor( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) - transformation_mats = {"prefill": transformation_mats_prefill} - # Setup page table page_table_tt = None paged_attention_config = None @@ -185,7 +174,6 @@ def test_llama_model_inference( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, paged_attention_config=paged_attention_config, ) diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index 9c9b537d4f0..665a11193e7 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -471,8 +471,10 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None, k # Assume that the page table does not have padding, so we can use it to get the unpadded page len. block_size = keys_BKSD.shape[2] page_len = page_table.shape[1] * block_size - ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill[:, :, :page_len, :], page_table, batch_idx=user_id) - ttnn.experimental.paged_fill_cache(values_BKSD, v_fill[:, :, :page_len, :], page_table, batch_idx=user_id) + k_fill_sliced = k_fill[:, :, :page_len, :] if page_len < k_fill.shape[2] else k_fill + v_fill_sliced = v_fill[:, :, :page_len, :] if page_len < v_fill.shape[2] else v_fill + ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill_sliced, page_table, batch_idx=user_id) + ttnn.experimental.paged_fill_cache(values_BKSD, v_fill_sliced, page_table, batch_idx=user_id) else: ttnn.fill_cache( keys_BKSD,