diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
index 172531645c92..d72990b45b8e 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
@@ -58,7 +58,7 @@ def test_llama_cross_attention_transformer_text_inference(
     prefill_pcc_required = 0.98
     decode_pcc_required = 0.73
 
-    mesh_device.enable_async(True)
+    mesh_device.enable_async(False)
 
     model_args = TtModelArgs(mesh_device, max_batch_size=batch)
     # Limit the max seqlen to 4k to avoid OOM on host
diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py
index 45fe46561c2e..a79c2649e5df 100644
--- a/models/demos/llama3/tt/model_config.py
+++ b/models/demos/llama3/tt/model_config.py
@@ -92,15 +92,15 @@ def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_s
             # Assert if all folders and files exist
             assert os.path.exists(
                 self.DEFAULT_CKPT_DIR
-            ), f"Checkpoint directory {self.DEFAULT_CKPT_DIR} does not exist, please use export LLAMA_CKPT_DIR=..."
+            ), f"Checkpoint directory {self.DEFAULT_CKPT_DIR} does not exist, please set LLAMA_DIR=... or LLAMA_CKPT_DIR=..."
             assert os.path.isfile(
                 self.DEFAULT_TOKENIZER_PATH + "/tokenizer.model"
-            ), f"Tokenizer file {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'} does not exist, please use export LLAMA_TOKENIZER_PATH=..."
+            ), f"Tokenizer file {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'} does not exist, please set LLAMA_TOKENIZER_PATH=..."
             if not os.path.exists(self.DEFAULT_CACHE_PATH):
                 os.makedirs(self.DEFAULT_CACHE_PATH)
             assert os.path.exists(
                 self.DEFAULT_CACHE_PATH
-            ), f"Cache directory {self.DEFAULT_CACHE_PATH} does not exist, please use export LLAMA_CACHE_PATH=..."
+            ), f"Cache directory {self.DEFAULT_CACHE_PATH} does not exist, please set LLAMA_CACHE_PATH=..."
             # Check if weights exist in the specified folder. If not warn the user to run the download and untar script.
             # assert os.path.isfile(
             #     self.DEFAULT_CKPT_DIR + "/consolidated.00.pth"
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
index 0c50b2128b0a..8abf0588cee4 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
@@ -2,20 +2,15 @@
 
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 import math
 import ttnn
 import torch
-import torch.nn as nn
 from models.demos.llama3.tt.llama_decoder import TtTransformerBlock
 from models.demos.llama3.tt.multimodal.llama_cross_block import TtLlamaCrossAttentionTransformerBlock
-from models.demos.llama3.tt.llama_model import LMHead
 from models.demos.llama3.tt.distributed_norm import DistributedNorm
 from models.common.rmsnorm import RMSNorm
 import ttnn
-from typing import Optional
 from models.common.lightweightmodule import LightweightModule
-from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding
 
 from models.utility_functions import (
     nearest_32,
@@ -288,8 +283,9 @@ def forward(
 
         h = self.norm(h, mode=mode)
 
-        if mode == "decode":  # h is expected to be interleaved for the lm head
-            h = ttnn.sharded_to_interleaved(h)
+        # TODO: Switch to using dram-sharded LM head and remove this
+        # Note: workaround for sharded_to_interleaved memory corruption (#15113)
+        h = ttnn.to_memory_config(h, ttnn.DRAM_MEMORY_CONFIG)
 
         seq_len = h.shape[2]
         MAX_MM_SEQ_LEN = 1024
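
The reworded assertions in model_config.py point at environment variables (LLAMA_DIR or LLAMA_CKPT_DIR, LLAMA_TOKENIZER_PATH, LLAMA_CACHE_PATH). A minimal sketch of how a caller might satisfy them is below; the variable names come from the error strings in the diff, but the example paths and the assumption that TtModelArgs reads them from the process environment are illustrative, not taken from the PR.

```python
import os

# Hypothetical paths: replace with your own checkpoint / cache locations.
# The new message suggests LLAMA_DIR as one option, with LLAMA_CKPT_DIR,
# LLAMA_TOKENIZER_PATH and LLAMA_CACHE_PATH as per-item alternatives.
os.environ["LLAMA_DIR"] = "/data/llama-3.2-11b-vision"
# os.environ["LLAMA_CKPT_DIR"] = "/data/llama-3.2-11b-vision"
# os.environ["LLAMA_TOKENIZER_PATH"] = "/data/llama-3.2-11b-vision"
# os.environ["LLAMA_CACHE_PATH"] = "/data/llama-cache"
```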
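The change in llama_cross_attention_transformer_text.py replaces the decode-only ttnn.sharded_to_interleaved call with an unconditional move to DRAM interleaved memory before the LM head, as a workaround for #15113. A minimal sketch of the before/after pattern is below; it assumes `h` is a ttnn tensor already on device, the helper name is made up for illustration, and only the calls that appear in the diff are used.

```python
import ttnn


def prepare_lm_head_input(h):
    """Illustrative helper (not from the PR): stage `h` for the LM head matmul."""
    # Before: converted only in decode mode, via sharded_to_interleaved:
    #     if mode == "decode":
    #         h = ttnn.sharded_to_interleaved(h)
    # After (workaround for #15113): always move the activation into DRAM
    # interleaved memory; this avoids sharded_to_interleaved until the LM head
    # is switched to a dram-sharded matmul (see the TODO in the diff).
    return ttnn.to_memory_config(h, ttnn.DRAM_MEMORY_CONFIG)
```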