From 6c95c41415afb378febea83b542e451eb54b5caa Mon Sep 17 00:00:00 2001 From: Stuti Raizada Date: Fri, 19 Jul 2024 05:46:46 -0700 Subject: [PATCH] #5337: ttnn fixes --- .../tests/test_mistral_model_prefill.py | 19 ++++++++++++++----- .../wormhole/mistral7b/tt/model_config.py | 4 ++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/models/demos/wormhole/mistral7b/tests/test_mistral_model_prefill.py b/models/demos/wormhole/mistral7b/tests/test_mistral_model_prefill.py index f5fca338828c..bfd855481924 100644 --- a/models/demos/wormhole/mistral7b/tests/test_mistral_model_prefill.py +++ b/models/demos/wormhole/mistral7b/tests/test_mistral_model_prefill.py @@ -11,6 +11,8 @@ os.environ["MISTRAL_CKPT_DIR"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/" os.environ["MISTRAL_TOKENIZER_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/" os.environ["MISTRAL_CACHE_PATH"] = "/mnt/MLPerf/ttnn/models/demos/mistral7b/" + # Prefill prompt files too large to keep in repo + os.environ["MIXTRAL_REF_OUTPUT_PATH"] = "/mnt/MLPerf/tt_dnn-models/Mistral/Mixtral-8x7B-v0.1/prefill/" import ttnn from models.demos.wormhole.mistral7b.tt.mistral_common import ( @@ -45,8 +47,10 @@ def forward(self, x): "seq_len", ( 128, - # 1024//2, - # 1024 * 2, + 512, + 1024, + 2048, + 4096, ), ) @pytest.mark.parametrize( @@ -68,6 +72,7 @@ def test_mistral_model_inference(device, version, seq_len, use_program_cache, re cache_pcc = False # Flag to measure KV cache PCC for all layers dtype = ttnn.bfloat8_b + pcc = 0.94 model_args = TtModelArgs(device) model_args.n_layers = 32 # Full model @@ -90,7 +95,11 @@ def test_mistral_model_inference(device, version, seq_len, use_program_cache, re # The instruct prompts follow the format: [INST] prompt [/INST]. [INST] are strings. is the correspoding bos_id token prompts = ["[INST] what is the capital of Canada? [/INST]"] * 32 else: - prompt_file = "models/demos/wormhole/mistral7b/tests/tale-of-two-cities.txt" + prompt_file = os.environ["MIXTRAL_REF_OUTPUT_PATH"] + "/tale-of-two-cities.txt" + assert os.path.exists( + prompt_file + ), f"Expected prompt file not found: {prompt_file}. Please set the flag 'MIXTRAL_REF_OUTPUT_PATH' correctly." + with open(prompt_file, "r") as f: prompts = f.read() encoded_prompts = tokenizer.encode(prompts)[:seq_len] @@ -156,7 +165,7 @@ def test_mistral_model_inference(device, version, seq_len, use_program_cache, re # TODO Measure only PCC at the end, instead of at every iteration # Measure PCC if also running reference model if run_ref_pt: - passing, pcc_message = comp_pcc(ref_output, tt_output_torch, 0.94) + passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc) logger.info(comp_allclose(ref_output, tt_output_torch)) logger.info(f"Model output: {pcc_message}") @@ -206,4 +215,4 @@ def test_mistral_model_inference(device, version, seq_len, use_program_cache, re logger.info(f"All Mistral decode iterations Passed!") else: logger.warning("One or more iterations of Mistral decode had bad PCC") - assert all_tests_pass, f"PCC value is lower than {0.99} for some of the outputs. Check Warnings!" + assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" diff --git a/models/demos/wormhole/mistral7b/tt/model_config.py b/models/demos/wormhole/mistral7b/tt/model_config.py index 84b6d71e637e..4286f0ad633a 100644 --- a/models/demos/wormhole/mistral7b/tt/model_config.py +++ b/models/demos/wormhole/mistral7b/tt/model_config.py @@ -141,7 +141,7 @@ def __init__(self, device, instruct=False): per_core_M=4, # 32, #16, # M / TILE_HEIGHT / Grid_Size (dynamic based on seqlen) per_core_N=56, # N / TILE_WIDTH / Grid_Size transpose_mcast=False, - fused_activation=ttnn.experimental.tensor.FusibleActivation.SILU, + fused_activation=ttnn.UnaryOpType.SILU, fuse_batch=False, ) @@ -176,7 +176,7 @@ def __init__(self, device, instruct=False): per_core_M=1, # 32, #16, # M / TILE_HEIGHT / Grid_Size (dynamic based on seqlen) per_core_N=56, # N / TILE_WIDTH / Grid_Size transpose_mcast=False, - fused_activation=ttnn.experimental.tensor.FusibleActivation.SILU, + fused_activation=ttnn.UnaryOpType.SILU, fuse_batch=False, )