From 6ca9a028eb6a8a94abac02526a937d3cc5cb16f9 Mon Sep 17 00:00:00 2001 From: Mark O'Connor Date: Fri, 27 Sep 2024 13:16:36 +0000 Subject: [PATCH] #12328: Fix demo_trace and add on-device argmax to test_llama_perf --- .../wormhole/llama31_8b/demo/demo_trace.py | 8 +-- .../llama31_8b/tests/test_llama_perf.py | 54 ++++++++++--------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/models/demos/wormhole/llama31_8b/demo/demo_trace.py b/models/demos/wormhole/llama31_8b/demo/demo_trace.py index cc42ebab512e..c789cd354c62 100644 --- a/models/demos/wormhole/llama31_8b/demo/demo_trace.py +++ b/models/demos/wormhole/llama31_8b/demo/demo_trace.py @@ -124,10 +124,10 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num output_filename = f"{output_directory}/demo_user_output_{timestamp}.txt" # Set Llama flags for CI - # if is_ci_env and instruct_mode: # Update paths for instruct mode, otherwise use default paths for general weights - os.environ["LLAMA_CKPT_DIR"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/" - os.environ["LLAMA_TOKENIZER_PATH"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/" - os.environ["LLAMA_CACHE_PATH"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/" + if is_ci_env and instruct_mode: # Update paths for instruct mode, otherwise use default paths for general weights + os.environ["LLAMA_CKPT_DIR"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/" + os.environ["LLAMA_TOKENIZER_PATH"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/" + os.environ["LLAMA_CACHE_PATH"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/" # This module requires the env paths above for CI runs from models.demos.wormhole.llama31_8b.tt.model_config import TtModelArgs diff --git a/models/demos/wormhole/llama31_8b/tests/test_llama_perf.py b/models/demos/wormhole/llama31_8b/tests/test_llama_perf.py index dce87212d110..3afd3b088aed 100644 --- a/models/demos/wormhole/llama31_8b/tests/test_llama_perf.py +++ b/models/demos/wormhole/llama31_8b/tests/test_llama_perf.py @@ -9,9 +9,8 @@ import ttnn from models.demos.wormhole.llama31_8b.tt.llama_common import ( prepare_inputs_ttnn, - sample, - HostEmbedding, get_single_rot_mat, + HostEmbedding, ) from models.demos.wormhole.llama31_8b.tt.llama_model import TtTransformer from models.demos.wormhole.llama31_8b.tt.llama_embedding import TtLlamaEmbedding @@ -137,35 +136,38 @@ def run_inference(device, tt_model, tt_embd, embd, encoded_prompts, generation_s # Select the first token from the prompts for initial decoding encoded_prompts_tensor = torch.tensor(encoded_prompts) # [:,0] - for i in range(generation_length): - current_pos = generation_start_pos + i - pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1) - tt_decode_input = pt_decode_input - decode_input = prepare_inputs_ttnn( - tt_decode_input, - tt_model.args.dim, - tt_model.device, - ) + # Initialize tt_out_tok with the first token + tt_out_tok = ttnn.from_torch( + torch.nn.functional.pad( + encoded_prompts_tensor[:, 0].unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0 + ), + device=device, + dtype=ttnn.uint32, + ) - current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32) - current_pos_attn_tensor = ttnn.from_torch( - torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32 - ) + current_pos = ttnn.from_torch(torch.tensor([generation_start_pos] * batch), device=device, dtype=ttnn.int32) + current_pos_attn = ttnn.from_torch( + torch.tensor([generation_start_pos] * batch * 8), device=device, dtype=ttnn.int32 + ) + for i in range(generation_length): # Run TT model profiler.start(f"model_run_for_inference_{i}") - tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat) - # Convert ttnn tensor to torch tensor - profiler.start(f"result_wait_for_inference_{i}") - tt_out = ttnn.untilize(tt_out, use_multicore=True) - tt_output_torch = ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim] + decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) + tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat) + tt_out_rm = ttnn.untilize(tt_out, use_multicore=True) + ttnn.deallocate(tt_out) + tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok) + ttnn.deallocate(tt_out_rm) - profiler.end(f"model_run_for_inference_{i}") - profiler.end(f"result_wait_for_inference_{i}") + # Update the rotation matrix for the next iteration + new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) + current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat) + ttnn.plus_one(current_pos) + ttnn.plus_one(current_pos_attn) - # Greedy decode the generated token and pass it back in, this is just a perf test - tt_out_tok = sample(tt_output_torch, temperature=0, top_p=1) + profiler.end(f"model_run_for_inference_{i}") - # Update the rotation matrix for the next iteration - current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat) + # Synchronize device to ensure all operations are complete + ttnn.synchronize_device(device)