#12328: Fix demo_trace and add on-device argmax to test_llama_perf
yieldthought committed Sep 27, 2024
1 parent a9e9b51 commit 6ca9a02
Showing 2 changed files with 32 additions and 30 deletions.
8 changes: 4 additions & 4 deletions models/demos/wormhole/llama31_8b/demo/demo_trace.py
@@ -124,10 +124,10 @@ def run_llama_demo(user_input, batch_size, device, instruct_mode, is_ci_env, num
     output_filename = f"{output_directory}/demo_user_output_{timestamp}.txt"
 
     # Set Llama flags for CI
-    # if is_ci_env and instruct_mode: # Update paths for instruct mode, otherwise use default paths for general weights
-    os.environ["LLAMA_CKPT_DIR"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/"
-    os.environ["LLAMA_TOKENIZER_PATH"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/"
-    os.environ["LLAMA_CACHE_PATH"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/"
+    if is_ci_env and instruct_mode: # Update paths for instruct mode, otherwise use default paths for general weights
+        os.environ["LLAMA_CKPT_DIR"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/"
+        os.environ["LLAMA_TOKENIZER_PATH"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/"
+        os.environ["LLAMA_CACHE_PATH"] = "/proj_sw/user_dev/hf_data/llama/Meta-Llama-3.1-8B-Instruct/"
     # This module requires the env paths above for CI runs
     from models.demos.wormhole.llama31_8b.tt.model_config import TtModelArgs
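The fix restores the is_ci_env guard, so these CI-only weight paths no longer clobber a local configuration. For reference, a minimal sketch of a non-CI setup that points the same three variables at a local weights directory before the model_config import (the local path is hypothetical; this only mirrors the pattern shown in the hunk above):

import os

# Hypothetical local directory holding the Meta-Llama-3.1-8B-Instruct weights, tokenizer, and cache
local_llama_dir = "/home/user/models/Meta-Llama-3.1-8B-Instruct/"

# Per the comment in the diff, model_config needs these paths, so set them before importing it
os.environ["LLAMA_CKPT_DIR"] = local_llama_dir
os.environ["LLAMA_TOKENIZER_PATH"] = local_llama_dir
os.environ["LLAMA_CACHE_PATH"] = local_llama_dir

from models.demos.wormhole.llama31_8b.tt.model_config import TtModelArgs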
54 changes: 28 additions & 26 deletions models/demos/wormhole/llama31_8b/tests/test_llama_perf.py
@@ -9,9 +9,8 @@
 import ttnn
 from models.demos.wormhole.llama31_8b.tt.llama_common import (
     prepare_inputs_ttnn,
-    sample,
-    HostEmbedding,
     get_single_rot_mat,
+    HostEmbedding,
 )
 from models.demos.wormhole.llama31_8b.tt.llama_model import TtTransformer
 from models.demos.wormhole.llama31_8b.tt.llama_embedding import TtLlamaEmbedding
@@ -137,35 +136,38 @@ def run_inference(device, tt_model, tt_embd, embd, encoded_prompts, generation_s
     # Select the first token from the prompts for initial decoding
     encoded_prompts_tensor = torch.tensor(encoded_prompts) # [:,0]
 
-    for i in range(generation_length):
-        current_pos = generation_start_pos + i
-        pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1)
-        tt_decode_input = pt_decode_input
-        decode_input = prepare_inputs_ttnn(
-            tt_decode_input,
-            tt_model.args.dim,
-            tt_model.device,
-        )
+    # Initialize tt_out_tok with the first token
+    tt_out_tok = ttnn.from_torch(
+        torch.nn.functional.pad(
+            encoded_prompts_tensor[:, 0].unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0
+        ),
+        device=device,
+        dtype=ttnn.uint32,
+    )
 
-        current_pos_tensor = ttnn.from_torch(torch.tensor([current_pos] * batch), device=device, dtype=ttnn.int32)
-        current_pos_attn_tensor = ttnn.from_torch(
-            torch.tensor([current_pos] * batch * 8), device=device, dtype=ttnn.int32
-        )
+    current_pos = ttnn.from_torch(torch.tensor([generation_start_pos] * batch), device=device, dtype=ttnn.int32)
+    current_pos_attn = ttnn.from_torch(
+        torch.tensor([generation_start_pos] * batch * 8), device=device, dtype=ttnn.int32
+    )
 
+    for i in range(generation_length):
         # Run TT model
         profiler.start(f"model_run_for_inference_{i}")
-        tt_out = tt_model(decode_input, current_pos_tensor, current_pos_attn_tensor, rot_mat=current_rot_mat)
 
-        # Convert ttnn tensor to torch tensor
-        profiler.start(f"result_wait_for_inference_{i}")
-        tt_out = ttnn.untilize(tt_out, use_multicore=True)
-        tt_output_torch = ttnn.to_torch(tt_out).permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim]
+        decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
+        tt_out = tt_model(decode_input, current_pos, current_pos_attn, rot_mat=current_rot_mat)
+        tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
+        ttnn.deallocate(tt_out)
+        tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=True, output_tensor=tt_out_tok)
+        ttnn.deallocate(tt_out_rm)
 
-        profiler.end(f"model_run_for_inference_{i}")
-        profiler.end(f"result_wait_for_inference_{i}")
+        # Update the rotation matrix for the next iteration
+        new_rot_mat = ttnn.linear(rot_matrix, current_rot_mat)
+        current_rot_mat = ttnn.copy(new_rot_mat, current_rot_mat)
+        ttnn.plus_one(current_pos)
+        ttnn.plus_one(current_pos_attn)
 
-        # Greedy decode the generated token and pass it back in, this is just a perf test
-        tt_out_tok = sample(tt_output_torch, temperature=0, top_p=1)
+        profiler.end(f"model_run_for_inference_{i}")
 
-        # Update the rotation matrix for the next iteration
-        current_rot_mat = ttnn.linear(rot_matrix, current_rot_mat)
+    # Synchronize device to ensure all operations are complete
+    ttnn.synchronize_device(device)
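With this change the greedy pick never leaves the device: the logits are untilized to row-major, ttnn.argmax writes the winning token id straight into the preallocated tt_out_tok buffer, and the decode positions are bumped in place with ttnn.plus_one instead of being rebuilt from torch tensors every iteration, removing a device-to-host round trip per decode step. Functionally, the on-device argmax replaces the deleted host-side sample(tt_output_torch, temperature=0, top_p=1) call; a minimal host-side sketch of that equivalence (pure torch, shapes illustrative):

import torch

def greedy_pick(logits: torch.Tensor) -> torch.Tensor:
    # With temperature=0 / top_p=1, sampling degenerates to taking the highest-logit token,
    # which is what ttnn.argmax(..., dim=3) now computes on device.
    return torch.argmax(logits, dim=-1, keepdim=True)

batch, vocab_size = 8, 128256  # illustrative batch and vocabulary size
next_tokens = greedy_pick(torch.randn(batch, vocab_size))  # [batch, 1] token ids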
