Skip to content

Commit

Permalink
#6344: Update RoBERTa QA demo
Browse files Browse the repository at this point in the history
  • Loading branch information
kkeerthana0573 committed Jun 11, 2024
1 parent 7431130 commit 519788f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion models/demos/bert/tt/ttnn_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def preprocess_inputs(
position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)

if attention_mask is not None:
attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape)
attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32)
attention_mask = attention_mask.expand((batch_size, -1, -1, -1))
attention_mask = torch.clamp(attention_mask, min=-100000)
attention_mask = ttnn.from_torch(
Expand Down
2 changes: 1 addition & 1 deletion models/demos/bert/tt/ttnn_optimized_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def preprocess_inputs(
position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)

if attention_mask is not None:
attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape)
attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32)
attention_mask = attention_mask.expand((batch_size, -1, -1, -1))
attention_mask = torch.clamp(attention_mask, min=-100000)
attention_mask = ttnn.from_torch(
Expand Down
26 changes: 19 additions & 7 deletions models/experimental/functional_roberta/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
import torch
from loguru import logger
import tt_lib
import transformers
import ttnn
import evaluate
Expand All @@ -15,8 +14,7 @@
disable_persistent_kernel_cache,
profiler,
)
from models.demos.bert.tt import ttnn_bert
from models.demos.bert.tt import ttnn_optimized_bert
from models.demos.bert.tt import ttnn_bert, ttnn_optimized_bert

from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch
from ttnn.model_preprocessing import (
Expand All @@ -43,6 +41,12 @@ def load_inputs(input_path, batch):
return context, question


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx


def run_roberta_question_and_answering_inference(
device,
use_program_cache,
Expand Down Expand Up @@ -106,10 +110,14 @@ def run_roberta_question_and_answering_inference(

profiler.start(f"preprocessing_input")

position_ids = create_position_ids_from_input_ids(
input_ids=roberta_input.input_ids, padding_idx=config.pad_token_id
)
ttnn_roberta_inputs = bert.preprocess_inputs(
roberta_input["input_ids"],
roberta_input["token_type_ids"],
torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
position_ids,
roberta_input["attention_mask"],
device=device,
)
profiler.end(f"preprocessing_input")
Expand Down Expand Up @@ -209,10 +217,14 @@ def run_roberta_question_and_answering_inference_squad_v2(
if i < n_iterations:
batch_data = batch[0]
curr_batch_size = batch_data["input_ids"].shape[0]
position_ids = create_position_ids_from_input_ids(
input_ids=batch_data.input_ids, padding_idx=config.pad_token_id
)
ttnn_roberta_inputs = bert.preprocess_inputs(
batch_data["input_ids"],
batch_data["token_type_ids"],
torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
position_ids,
batch_data["attention_mask"],
device=device,
)

Expand Down Expand Up @@ -255,7 +267,7 @@ def run_roberta_question_and_answering_inference_squad_v2(


@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
def test_demo(
input_path,
model_name,
Expand All @@ -278,7 +290,7 @@ def test_demo(


@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
@pytest.mark.parametrize(
"n_iterations",
((3),),
Expand Down

0 comments on commit 519788f

Please sign in to comment.