From 519788f935e4f3937229a9ac3439b44f24ea46ba Mon Sep 17 00:00:00 2001 From: kkeerthana0573 Date: Mon, 27 May 2024 18:58:38 +0000 Subject: [PATCH] #6344: Update RoBERTa QA demo --- models/demos/bert/tt/ttnn_bert.py | 2 +- models/demos/bert/tt/ttnn_optimized_bert.py | 2 +- .../functional_roberta/demo/demo.py | 26 ++++++++++++++----- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/models/demos/bert/tt/ttnn_bert.py b/models/demos/bert/tt/ttnn_bert.py index 23c480cd3d43..10b0c7958f4c 100644 --- a/models/demos/bert/tt/ttnn_bert.py +++ b/models/demos/bert/tt/ttnn_bert.py @@ -232,7 +232,7 @@ def preprocess_inputs( position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) if attention_mask is not None: - attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape) + attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32) attention_mask = attention_mask.expand((batch_size, -1, -1, -1)) attention_mask = torch.clamp(attention_mask, min=-100000) attention_mask = ttnn.from_torch( diff --git a/models/demos/bert/tt/ttnn_optimized_bert.py b/models/demos/bert/tt/ttnn_optimized_bert.py index af21f25e7e27..089e755538da 100644 --- a/models/demos/bert/tt/ttnn_optimized_bert.py +++ b/models/demos/bert/tt/ttnn_optimized_bert.py @@ -286,7 +286,7 @@ def preprocess_inputs( position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) if attention_mask is not None: - attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape) + attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32) attention_mask = attention_mask.expand((batch_size, -1, -1, -1)) attention_mask = torch.clamp(attention_mask, min=-100000) attention_mask = ttnn.from_torch( diff --git a/models/experimental/functional_roberta/demo/demo.py b/models/experimental/functional_roberta/demo/demo.py index 7e4dc9a18908..d56337c7513b 100644 --- a/models/experimental/functional_roberta/demo/demo.py +++ b/models/experimental/functional_roberta/demo/demo.py @@ -6,7 +6,6 @@ import pytest import torch from loguru import logger -import tt_lib import transformers import ttnn import evaluate @@ -15,8 +14,7 @@ disable_persistent_kernel_cache, profiler, ) -from models.demos.bert.tt import ttnn_bert -from models.demos.bert.tt import ttnn_optimized_bert +from models.demos.bert.tt import ttnn_bert, ttnn_optimized_bert from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch from ttnn.model_preprocessing import ( @@ -43,6 +41,12 @@ def load_inputs(input_path, batch): return context, question +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + def run_roberta_question_and_answering_inference( device, use_program_cache, @@ -106,10 +110,14 @@ def run_roberta_question_and_answering_inference( profiler.start(f"preprocessing_input") + position_ids = create_position_ids_from_input_ids( + input_ids=roberta_input.input_ids, padding_idx=config.pad_token_id + ) ttnn_roberta_inputs = bert.preprocess_inputs( roberta_input["input_ids"], roberta_input["token_type_ids"], - torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None, + position_ids, + roberta_input["attention_mask"], device=device, ) profiler.end(f"preprocessing_input") @@ -209,10 +217,14 @@ def run_roberta_question_and_answering_inference_squad_v2( if i < n_iterations: batch_data = batch[0] curr_batch_size = batch_data["input_ids"].shape[0] + position_ids = create_position_ids_from_input_ids( + input_ids=batch_data.input_ids, padding_idx=config.pad_token_id + ) ttnn_roberta_inputs = bert.preprocess_inputs( batch_data["input_ids"], batch_data["token_type_ids"], - torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None, + position_ids, + batch_data["attention_mask"], device=device, ) @@ -255,7 +267,7 @@ def run_roberta_question_and_answering_inference_squad_v2( @pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"]) -@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert]) +@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert]) def test_demo( input_path, model_name, @@ -278,7 +290,7 @@ def test_demo( @pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"]) -@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert]) +@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert]) @pytest.mark.parametrize( "n_iterations", ((3),),