Skip to content

Commit

Permalink
#5337: ttnn fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sraizada-tt committed Jul 19, 2024
1 parent 47dde14 commit 6c95c41
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/"
# Prefill prompt files too large to keep in repo
os.environ["MIXTRAL_REF_OUTPUT_PATH"] = "/mnt/MLPerf/tt_dnn-models/Mistral/Mixtral-8x7B-v0.1/prefill/"

import ttnn
from models.demos.wormhole.mistral7b.tt.mistral_common import (
Expand Down Expand Up @@ -45,8 +47,10 @@ def forward(self, x):
"seq_len",
(
128,
# 1024//2,
# 1024 * 2,
512,
1024,
2048,
4096,
),
)
@pytest.mark.parametrize(
Expand All @@ -68,6 +72,7 @@ def test_mistral_model_inference(device, version, seq_len, use_program_cache, re
cache_pcc = False # Flag to measure KV cache PCC for all layers

dtype = ttnn.bfloat8_b
pcc = 0.94

model_args = TtModelArgs(device)
model_args.n_layers = 32 # Full model
Expand All @@ -90,7 +95,11 @@ def test_mistral_model_inference(device, version, seq_len, use_program_cache, re
# The instruct prompts follow the format: <bos> [INST] prompt [/INST]. [INST] are strings. <bos> is the correspoding bos_id token
prompts = ["[INST] what is the capital of Canada? [/INST]"] * 32
else:
prompt_file = "models/demos/wormhole/mistral7b/tests/tale-of-two-cities.txt"
prompt_file = os.environ["MIXTRAL_REF_OUTPUT_PATH"] + "/tale-of-two-cities.txt"
assert os.path.exists(
prompt_file
), f"Expected prompt file not found: {prompt_file}. Please set the flag 'MIXTRAL_REF_OUTPUT_PATH' correctly."

with open(prompt_file, "r") as f:
prompts = f.read()
encoded_prompts = tokenizer.encode(prompts)[:seq_len]
Expand Down Expand Up @@ -156,7 +165,7 @@ def test_mistral_model_inference(device, version, seq_len, use_program_cache, re
# TODO Measure only PCC at the end, instead of at every iteration
# Measure PCC if also running reference model
if run_ref_pt:
passing, pcc_message = comp_pcc(ref_output, tt_output_torch, 0.94)
passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc)

logger.info(comp_allclose(ref_output, tt_output_torch))
logger.info(f"Model output: {pcc_message}")
Expand Down Expand Up @@ -206,4 +215,4 @@ def test_mistral_model_inference(device, version, seq_len, use_program_cache, re
logger.info(f"All Mistral decode iterations Passed!")
else:
logger.warning("One or more iterations of Mistral decode had bad PCC")
assert all_tests_pass, f"PCC value is lower than {0.99} for some of the outputs. Check Warnings!"
assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!"
4 changes: 2 additions & 2 deletions models/demos/wormhole/mistral7b/tt/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(self, device, instruct=False):
per_core_M=4, # 32, #16, # M / TILE_HEIGHT / Grid_Size (dynamic based on seqlen)
per_core_N=56, # N / TILE_WIDTH / Grid_Size
transpose_mcast=False,
fused_activation=ttnn.experimental.tensor.FusibleActivation.SILU,
fused_activation=ttnn.UnaryOpType.SILU,
fuse_batch=False,
)

Expand Down Expand Up @@ -176,7 +176,7 @@ def __init__(self, device, instruct=False):
per_core_M=1, # 32, #16, # M / TILE_HEIGHT / Grid_Size (dynamic based on seqlen)
per_core_N=56, # N / TILE_WIDTH / Grid_Size
transpose_mcast=False,
fused_activation=ttnn.experimental.tensor.FusibleActivation.SILU,
fused_activation=ttnn.UnaryOpType.SILU,
fuse_batch=False,
)

Expand Down

0 comments on commit 6c95c41

Please sign in to comment.