diff --git a/models/demos/llama3/tt/generator.py b/models/demos/llama3/tt/generator.py index e95841b00ec..b1d08294ceb 100644 --- a/models/demos/llama3/tt/generator.py +++ b/models/demos/llama3/tt/generator.py @@ -360,6 +360,7 @@ def capture_trace( tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, + _, ) = self.model.prepare_decode_inputs_host( tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id ) @@ -425,6 +426,7 @@ def decode_forward_trace( tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, + _, ) = self.model.prepare_decode_inputs_host( tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id )