diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index fc75adad00d2..96149d5a0f9d 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -388,14 +388,13 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas rot_mats, ) = self.copy_host_to_device((tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats)) - tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device( + tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device( + tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B=tokens.shape[0], ) - tt_h = ttnn.to_memory_config(tt_h, self.configuration.model_config["DECODE_RESIDUAL_MEMCFG"]) - return ( tt_h, tt_xattn_mask, @@ -415,7 +414,7 @@ def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_ro h = self.prepare_inputs_common(position_ids, tokens) tt_h = self.configuration.prepare_inputs_ttnn_decode( h, - ttnn.DRAM_MEMORY_CONFIG, # L1 memory_configs are not respected for on_host tensors + None, # on_host tensors have no memory_config on_host=True, ) @@ -489,7 +488,7 @@ def copy_host_to_device(self, host_tensors, device_tensors=None): ttnn.copy_host_to_device_tensor(host_tensors[i], device_tensors[i]) return device_tensors - def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B): + def transform_decode_inputs_device(self, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B): """ Does any transformations on device tensors which are necessary before ttnn_decode_forward """ @@ -498,6 +497,8 @@ def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand ), f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}" S = 1 + tt_h = ttnn.to_memory_config(tt_h, self.configuration.model_config["DECODE_RESIDUAL_MEMCFG"]) + tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) tt_xattn_mask = ttnn.reshape( tt_xattn_mask, @@ -530,7 +531,7 @@ def transform_decode_inputs_device(self, tt_xattn_mask, tt_full_text_mask_expand ), ) - return (tt_xattn_mask, tt_full_text_mask_expand_1NSH) + return (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH) def process_output_prefill(self, tt_out, B, S): padded_seq_len = _get_padded_prefill_seqlen(S) diff --git a/models/demos/llama3/tt/multimodal/vision_generator.py b/models/demos/llama3/tt/multimodal/vision_generator.py index f8e3681216fd..b00fbf3ff739 100644 --- a/models/demos/llama3/tt/multimodal/vision_generator.py +++ b/models/demos/llama3/tt/multimodal/vision_generator.py @@ -185,9 +185,15 @@ def capture_trace( ) trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) + tt_h_trace_input = tt_h B = tokens.shape[0] # Do on-device transformations of inputs before forward - tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform = self.model.transform_decode_inputs_device( + ( + tt_h, + tt_xattn_mask_transform, + tt_full_text_mask_expand_1NSH_transform, + ) = self.model.transform_decode_inputs_device( + tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B=B, @@ -204,7 +210,15 @@ def capture_trace( ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) - return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats + return ( + trace_id, + tt_logits_rm, + tt_h_trace_input, + tt_xattn_mask, + tt_full_text_mask_expand_1NSH, + tt_position_id, + rot_mats, + ) def decode_forward_trace( self,