From cb3678c90d4c777c30fd55f4e3900c5c0ddc83b8 Mon Sep 17 00:00:00 2001 From: kkeerthana0573 Date: Mon, 27 May 2024 18:58:38 +0000 Subject: [PATCH] #6344: Port RoBERTa model to n300 --- models/demos/bert/tt/ttnn_optimized_bert.py | 15 ++- models/demos/roberta/README.md | 19 +++ .../roberta}/demo/demo.py | 85 +++++++----- .../roberta}/demo/input_data.json | 0 .../roberta/tests/test_perf_device_roberta.py | 37 ++++++ .../demos/roberta/tests}/test_performance.py | 42 +++--- .../experimental/functional_roberta/README.md | 21 --- tests/scripts/run_performance.sh | 4 + .../single_card/run_single_card_demo_tests.sh | 2 + .../roberta/test_ttnn_optimized_roberta.py | 124 ++++++++++++++++++ 10 files changed, 265 insertions(+), 84 deletions(-) create mode 100644 models/demos/roberta/README.md rename models/{experimental/functional_roberta => demos/roberta}/demo/demo.py (79%) rename models/{experimental/functional_roberta => demos/roberta}/demo/input_data.json (100%) create mode 100644 models/demos/roberta/tests/test_perf_device_roberta.py rename {tests/ttnn/integration_tests/roberta => models/demos/roberta/tests}/test_performance.py (78%) delete mode 100644 models/experimental/functional_roberta/README.md create mode 100644 tests/ttnn/integration_tests/roberta/test_ttnn_optimized_roberta.py diff --git a/models/demos/bert/tt/ttnn_optimized_bert.py b/models/demos/bert/tt/ttnn_optimized_bert.py index 9ddfe1a59fd8..97109d873e1a 100644 --- a/models/demos/bert/tt/ttnn_optimized_bert.py +++ b/models/demos/bert/tt/ttnn_optimized_bert.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import ttnn - +from models.utility_functions import is_grayskull from models.experimental.functional_common.attention_mask_functions import get_extended_attention_mask @@ -13,7 +13,7 @@ def bert_attention( attention_mask, *, parameters, - num_cores_x=12, + num_cores_x=12 if is_grayskull() else 8, ): num_heads = config.num_attention_heads batch_size, _, hidden_size = hidden_states.shape @@ -43,7 +43,7 @@ def bert_attention( query, key, memory_config=ttnn.L1_MEMORY_CONFIG, - dtype=ttnn.bfloat16, + dtype=ttnn.bfloat8_b, core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), ) ttnn.deallocate(query) @@ -95,7 +95,7 @@ def bert_intermediate( hidden_states, *, parameters, - num_cores_x=12, + num_cores_x=12 if is_grayskull() else 8, ): batch_size, *_ = hidden_states.shape @@ -107,6 +107,11 @@ def bert_intermediate( dtype=ttnn.bfloat8_b, core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), activation="gelu", + compute_kernel_config=ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + packer_l1_acc=False, + ), ) return output @@ -117,7 +122,7 @@ def bert_output( residual, *, parameters, - num_cores_x=12, + num_cores_x=12 if is_grayskull() else 8, ): batch_size, *_ = hidden_states.shape diff --git a/models/demos/roberta/README.md b/models/demos/roberta/README.md new file mode 100644 index 000000000000..ae3827a12325 --- /dev/null +++ b/models/demos/roberta/README.md @@ -0,0 +1,19 @@ +## functional_roberta Demo +## How to Run + +If you wish to run the demo for ttnn_optimized_functional_roberta, use `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_optimized_bert-8-384-deepset/roberta-large-squad2-models/demos/roberta/demo/input_data.json]` to run the demo. + +If you wish to run the demo with a different input use `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_optimized_bert-8-384-deepset/roberta-large-squad2-]`. This file is expected to have exactly 8 inputs. + +Our second demo is designed to run SQuADV2 dataset, run this with `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo_squadv2[models.demos.bert.tt.ttnn_optimized_bert-8-384-3-deepset/roberta-large-squad2]`. + +If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo_squadv2[models.demos.bert.tt.ttnn_optimized_bert-8-384--deepset/roberta-large-squad2]` + + +# Inputs +Inputs by default are provided from `input_data.json`. If you wish you to change the inputs, provide a different path to test_demo. + +We do not recommend modifying `input_data.json` file. + +# Details +The entry point to functional_roberta model is bert_for_question_answering in `models/demos/bert/tt/ttnn_bert.py` (`models/demos/bert/tt/ttnn_optimized_bert.py` for optimized version). The model picks up certain configs and weights from huggingface pretrained model. We have used `deepset/roberta-large-squad2` version from huggingface as our reference. diff --git a/models/experimental/functional_roberta/demo/demo.py b/models/demos/roberta/demo/demo.py similarity index 79% rename from models/experimental/functional_roberta/demo/demo.py rename to models/demos/roberta/demo/demo.py index 0e83e11064b4..ad5607c0b36a 100644 --- a/models/experimental/functional_roberta/demo/demo.py +++ b/models/demos/roberta/demo/demo.py @@ -14,7 +14,6 @@ disable_persistent_kernel_cache, profiler, ) -from models.demos.bert.tt import ttnn_bert from models.demos.bert.tt import ttnn_optimized_bert from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch @@ -42,6 +41,12 @@ def load_inputs(input_path, batch): return context, question +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + def run_roberta_question_and_answering_inference( device, use_program_cache, @@ -60,13 +65,9 @@ def run_roberta_question_and_answering_inference( tokenizer = RobertaTokenizer.from_pretrained(model_name) config = hugging_face_reference_model.config nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer) + config.use_dram = True - if bert == ttnn_bert: - tt_model_name = f"ttnn_{model_name}" - elif bert == ttnn_optimized_bert: - tt_model_name = f"ttnn_{model_name}_optimized" - else: - raise ValueError(f"Unknown bert: {bert}") + tt_model_name = f"ttnn_{model_name}_optimized" profiler.start(f"preprocessing_parameter") parameters = preprocess_model_parameters( @@ -105,10 +106,14 @@ def run_roberta_question_and_answering_inference( profiler.start(f"preprocessing_input") + position_ids = create_position_ids_from_input_ids( + input_ids=roberta_input.input_ids, padding_idx=config.pad_token_id + ) ttnn_roberta_inputs = bert.preprocess_inputs( roberta_input["input_ids"], roberta_input["token_type_ids"], - torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None, + position_ids, + roberta_input["attention_mask"], device=device, ) profiler.end(f"preprocessing_input") @@ -139,7 +144,8 @@ def run_roberta_question_and_answering_inference( tt_answer = nlp.postprocess([tt_res], **postprocess_params) - logger.info(f"answer: {tt_answer['answer']}\n") + logger.info(f"Question: {question[i]}") + logger.info(f"Answer: {tt_answer['answer']}\n") model_answers[i] = tt_answer["answer"] profiler.end("post_processing_output_to_string") @@ -175,13 +181,9 @@ def run_roberta_question_and_answering_inference_squad_v2( # set up tokenizer tokenizer = RobertaTokenizer.from_pretrained(model_name) config = hugging_face_reference_model.config + config.use_dram = True - if bert == ttnn_bert: - tt_model_name = f"ttnn_{model_name}" - elif bert == ttnn_optimized_bert: - tt_model_name = f"ttnn_{model_name}_optimized" - else: - raise ValueError(f"Unknown bert: {bert}") + tt_model_name = f"ttnn_{model_name}_optimized" parameters = preprocess_model_parameters( model_name=tt_model_name, @@ -208,10 +210,14 @@ def run_roberta_question_and_answering_inference_squad_v2( if i < n_iterations: batch_data = batch[0] curr_batch_size = batch_data["input_ids"].shape[0] + position_ids = create_position_ids_from_input_ids( + input_ids=batch_data.input_ids, padding_idx=config.pad_token_id + ) ttnn_roberta_inputs = bert.preprocess_inputs( batch_data["input_ids"], batch_data["token_type_ids"], - torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None, + position_ids, + batch_data["attention_mask"], device=device, ) @@ -250,18 +256,24 @@ def run_roberta_question_and_answering_inference_squad_v2( i += 1 eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels) cpu_eval_score = squad_metric.compute(predictions=cpu_pred_labels, references=true_labels) - logger.info(f"\tTT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}") + logger.info(f"TT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}") + logger.info(f"CPU_Eval: exact: {cpu_eval_score['exact']} -- F1: {cpu_eval_score['f1']}") + assert eval_score["exact"] >= cpu_eval_score["exact"] and eval_score["f1"] >= cpu_eval_score["f1"], ( + f"Expected Exact Match: {cpu_eval_score['exact']}, Actual Exact Match: {eval_score['exact']}; " + f"Expected F1 Score: {cpu_eval_score['f1']}, Actual F1 Score: {eval_score['f1']}" + ) -@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"]) -@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert]) -def test_demo( - input_path, - model_name, - bert, - device, - use_program_cache, -): + +@pytest.mark.parametrize( + "model_name, input_loc", + ((["deepset/roberta-large-squad2", "models/demos/roberta/demo/input_data.json"]),), +) +@pytest.mark.parametrize( + ("bert", "batch_size", "sequence_size"), + ((ttnn_optimized_bert, 8, 384),), +) +def test_demo(device, use_program_cache, model_name, input_loc, bert, batch_size, sequence_size): disable_persistent_kernel_cache() disable_compilation_reports() @@ -269,25 +281,26 @@ def test_demo( device=device, use_program_cache=use_program_cache, model_name=model_name, - batch_size=8, - sequence_size=384, + batch_size=batch_size, + sequence_size=sequence_size, bert=bert, - input_path=input_path, + input_path=input_loc, ) @pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"]) -@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert]) @pytest.mark.parametrize( - "n_iterations", - ((3),), + ("bert", "batch_size", "sequence_size", "n_iterations"), + ((ttnn_optimized_bert, 8, 384, 3),), ) def test_demo_squadv2( + device, + use_program_cache, model_name, bert, + batch_size, + sequence_size, n_iterations, - device, - use_program_cache, ): disable_persistent_kernel_cache() disable_compilation_reports() @@ -296,8 +309,8 @@ def test_demo_squadv2( device=device, use_program_cache=use_program_cache, model_name=model_name, - batch_size=8, - sequence_size=384, + batch_size=batch_size, + sequence_size=sequence_size, bert=bert, n_iterations=n_iterations, ) diff --git a/models/experimental/functional_roberta/demo/input_data.json b/models/demos/roberta/demo/input_data.json similarity index 100% rename from models/experimental/functional_roberta/demo/input_data.json rename to models/demos/roberta/demo/input_data.json diff --git a/models/demos/roberta/tests/test_perf_device_roberta.py b/models/demos/roberta/tests/test_perf_device_roberta.py new file mode 100644 index 000000000000..ff5eaf3b180e --- /dev/null +++ b/models/demos/roberta/tests/test_perf_device_roberta.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from models.utility_functions import is_grayskull +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report + + +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, test", + [ + [8, "sequence_size=384-batch_size=8-model_name=deepset/roberta-large-squad2"], + ], +) +def test_perf_device_bare_metal(batch_size, test): + subdir = "ttnn_roberta" + num_iterations = 1 + margin = 0.03 + expected_perf = 166.88 if is_grayskull() else 166.66 + + command = f"pytest tests/ttnn/integration_tests/roberta/test_ttnn_optimized_roberta.py::test_roberta_for_question_answering[{test}]" + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + prep_device_perf_report( + model_name=f"ttnn_roberta_{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments=test.replace("/", "_"), + ) diff --git a/tests/ttnn/integration_tests/roberta/test_performance.py b/models/demos/roberta/tests/test_performance.py similarity index 78% rename from tests/ttnn/integration_tests/roberta/test_performance.py rename to models/demos/roberta/tests/test_performance.py index 2579da7f2469..55cef8e7996e 100644 --- a/tests/ttnn/integration_tests/roberta/test_performance.py +++ b/models/demos/roberta/tests/test_performance.py @@ -2,25 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 -import time - -import pytest -from loguru import logger +import ttnn +import time import torch +import pytest import transformers - -import ttnn - -from models.demos.bert.tt import ttnn_bert +from loguru import logger from models.demos.bert.tt import ttnn_optimized_bert - from ttnn.model_preprocessing import preprocess_model_parameters from models.utility_functions import ( - is_wormhole_b0, - is_blackhole, + is_grayskull, enable_persistent_kernel_cache, disable_persistent_kernel_cache, ) @@ -29,34 +23,34 @@ def get_expected_times(bert): return { - ttnn_bert: (13, 32), - ttnn_optimized_bert: (12, 0.092), + ttnn_optimized_bert: (8.7, 0.15) if is_grayskull() else (12.5, 0.14), }[bert] -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + @pytest.mark.models_performance_bare_metal @pytest.mark.models_performance_virtual_machine @pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("sequence_size", [384]) -@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert]) +@pytest.mark.parametrize("bert", [ttnn_optimized_bert]) def test_performance(device, use_program_cache, model_name, batch_size, sequence_size, bert): disable_persistent_kernel_cache() config = transformers.RobertaConfig.from_pretrained(model_name) + config.use_dram = True input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) - torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) torch_attention_mask = torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None + torch_position_ids = create_position_ids_from_input_ids(input_ids=input_ids, padding_idx=config.pad_token_id) - if bert == ttnn_bert: - tt_model_name = f"ttnn_{model_name}" - elif bert == ttnn_optimized_bert: - tt_model_name = f"ttnn_{model_name}_optimized" - else: - raise ValueError(f"Unknown functional_roberta: {bert}") + tt_model_name = f"ttnn_{model_name}_optimized" parameters = preprocess_model_parameters( model_name=tt_model_name, @@ -106,3 +100,7 @@ def test_performance(device, use_program_cache, model_name, batch_size, sequence logger.info(f"Compile time: {inference_and_compile_time - inference_time}") logger.info(f"Inference time: {inference_time}") logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + assert ( + inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}" diff --git a/models/experimental/functional_roberta/README.md b/models/experimental/functional_roberta/README.md deleted file mode 100644 index cef8feb818c0..000000000000 --- a/models/experimental/functional_roberta/README.md +++ /dev/null @@ -1,21 +0,0 @@ -## functional_roberta Demo -## How to Run - -Use `pytest --disable-warnings --input-path="models/experimental/functional_roberta/demo/input_data.json" models/experimental/functional_roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]` to run the demo. - -If you wish to run the demo for ttnn_optimized_functional_roberta, use `pytest --disable-warnings --input-path="models/experimental/functional_roberta/demo/input_data.json" models/experimental/functional_roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_optimized_bert-deepset/roberta-large-squad2]` to run the demo. - -If you wish to run the demo with a different input use `pytest --disable-warnings --input-path="" models/experimental/functional_roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`. This file is expected to have exactly 8 inputs. - -Our second demo is designed to run SQuADV2 dataset, run this with `pytest --disable-warnings models/experimental/functional_roberta/demo/demo.py::test_demo_squadv2[3-models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`. - -If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/experimental/functional_roberta/demo/demo.py::test_demo_squadv2[-models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]` - - -# Inputs -Inputs by default are provided from `input_data.json`. If you wish you to change the inputs, provide a different path to test_demo. - -We do not recommend modifying `input_data.json` file. - -# Details -The entry point to functional_roberta model is bert_for_question_answering in `models/experimental/bert/tt/ttnn_bert.py` (`models/experimental/bert/tt/ttnn_optimized_bert.py` for optimized version). The model picks up certain configs and weights from huggingface pretrained model. We have used `deepset/roberta-large-squad2` version from huggingface as our reference. diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 5acab7c6def7..897ed740df21 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -39,6 +39,8 @@ run_perf_models_other() { env pytest -n auto models/demos/mnist/tests -m $test_marker + env pytest -n auto models/demos/roberta/tests/test_performance.py -m $test_marker + ## Merge all the generated reports env python models/perf/merge_perf_results.py } @@ -106,6 +108,8 @@ run_device_perf_models() { env pytest models/demos/mnist/tests -m $test_marker + env pytest models/demos/roberta/tests/ -m $test_marker + if [ "$tt_arch" == "grayskull" ]; then #TODO(MO): Until #6560 is fixed, GS device profiler test are grouped with #Model Device perf regression tests to make sure thy run on no-soft-reset BMs diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index 5f5642483f63..0c9d9473437d 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -39,6 +39,8 @@ run_common_func_tests() { # Mnist pytest --disable-warnings models/demos/mnist/demo/demo.py --timeout 600; fail+=$? + #RoBERTa + pytest --disable-warnings models/demos/roberta/demo/demo.py --timeout 600; fail+=$? return $fail } diff --git a/tests/ttnn/integration_tests/roberta/test_ttnn_optimized_roberta.py b/tests/ttnn/integration_tests/roberta/test_ttnn_optimized_roberta.py new file mode 100644 index 000000000000..bbbd45fd4af7 --- /dev/null +++ b/tests/ttnn/integration_tests/roberta/test_ttnn_optimized_roberta.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +import pytest +import tt_lib +import transformers + +from tests.ttnn.utils_for_testing import assert_with_pcc +from models.demos.bert.tt import ttnn_optimized_bert, ttnn_bert +from ttnn.model_preprocessing import preprocess_model_parameters +from transformers import RobertaForQuestionAnswering, RobertaConfig +from models.utility_functions import skip_for_wormhole_b0, is_grayskull + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +def test_roberta(device, use_program_cache, reset_seeds, model_name, batch_size, sequence_size): + config = transformers.RobertaConfig.from_pretrained(model_name) + model = transformers.RobertaModel.from_pretrained(model_name) + config.use_dram = True + + input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.ones(batch_size, sequence_size) + torch_position_ids = create_position_ids_from_input_ids(input_ids=input_ids, padding_idx=config.pad_token_id) + torch_output = model( + input_ids=input_ids, + attention_mask=torch_attention_mask, + token_type_ids=torch_token_type_ids, + position_ids=torch_position_ids, + ) + torch_output = torch_output.last_hidden_state + + tt_model_name = f"ttnn_{model_name}_optimized" + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: transformers.RobertaModel.from_pretrained(model_name, torchscript=False).eval(), + custom_preprocessor=ttnn_optimized_bert.custom_preprocessor, + device=device, + ) + + ttnn_roberta_inputs = ttnn_optimized_bert.preprocess_inputs( + input_ids, + torch_token_type_ids, + torch_position_ids, + torch_attention_mask, + device=device, + ) + + tt_output = ttnn_optimized_bert.bert( + config, + *ttnn_roberta_inputs, + parameters=parameters, + ) + + tt_output = ttnn.to_torch(tt_output) + + assert_with_pcc(torch_output, tt_output, 0.89 if is_grayskull() else 0.94) + + +@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +def test_roberta_for_question_answering(device, use_program_cache, reset_seeds, model_name, batch_size, sequence_size): + config = RobertaConfig.from_pretrained(model_name) + model = RobertaForQuestionAnswering.from_pretrained(model_name) + config.use_dram = True + + input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.ones(batch_size, sequence_size) + torch_position_ids = create_position_ids_from_input_ids(input_ids=input_ids, padding_idx=config.pad_token_id) + torch_output = model( + input_ids=input_ids, + attention_mask=torch_attention_mask, + token_type_ids=torch_token_type_ids, + position_ids=torch_position_ids, + ) + torch_output_start_logits = torch_output.start_logits + torch_output_end_logits = torch_output.end_logits + + tt_model_name = f"ttnn_{model_name}_optimized" + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: transformers.RobertaForQuestionAnswering.from_pretrained( + model_name, torchscript=False + ).eval(), + custom_preprocessor=ttnn_optimized_bert.custom_preprocessor, + device=device, + ) + + ttnn_roberta_inputs = ttnn_optimized_bert.preprocess_inputs( + input_ids, + torch_token_type_ids, + torch_position_ids, + torch_attention_mask, + device=device, + ) + + tt_output = ttnn_optimized_bert.bert_for_question_answering( + config, + *ttnn_roberta_inputs, + parameters=parameters, + name="roberta", + ) + tt_output = ttnn.to_torch(tt_output) + + tt_output_start_logits = tt_output[..., :, 0] + tt_output_end_logits = tt_output[..., :, 1] + + assert_with_pcc(torch_output_start_logits, tt_output_start_logits, 0.81 if is_grayskull() else 0.90) + assert_with_pcc(torch_output_end_logits, tt_output_end_logits, 0.79 if is_grayskull() else 0.90)