diff --git a/models/demos/bert/tt/ttnn_optimized_bert.py b/models/demos/bert/tt/ttnn_optimized_bert.py
index 9ddfe1a59fd..cb9484d7f24 100644
--- a/models/demos/bert/tt/ttnn_optimized_bert.py
+++ b/models/demos/bert/tt/ttnn_optimized_bert.py
@@ -42,7 +42,7 @@ def bert_attention(
     attention_scores = ttnn.matmul(
         query,
         key,
-        memory_config=ttnn.L1_MEMORY_CONFIG,
+        # memory_config=ttnn.L1_MEMORY_CONFIG,
         dtype=ttnn.bfloat16,
         core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
     )
diff --git a/models/demos/roberta/README.md b/models/demos/roberta/README.md
new file mode 100644
index 00000000000..f8fb1a9c29c
--- /dev/null
+++ b/models/demos/roberta/README.md
@@ -0,0 +1,21 @@
+## functional_roberta Demo
+## How to Run
+
+Use `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_optimized_bert-deepset/roberta-large-squad2-models/demos/roberta/demo/input_data.json]` to run the demo.
+
+If you wish to run the demo for ttnn_optimized_functional_roberta, use `pytest --disable-warnings --input-path="models/demos/roberta/demo/input_data.json" models/demos/roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_optimized_bert-deepset/roberta-large-squad2-models/demos/roberta/demo/input_data.json]` to run the demo.
+
+If you wish to run the demo with a different input, use `pytest --disable-warnings --input-path="" models/demos/roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`. This file is expected to have exactly 8 inputs.
+
+Our second demo is designed to run the SQuADv2 dataset; run it with `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo_squadv2[3-models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`.
+
+If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo_squadv2[-models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`.
+
+
+# Inputs
+Inputs are provided from `input_data.json` by default. If you wish to change the inputs, provide a different path to test_demo.
+
+We do not recommend modifying the `input_data.json` file.
+
+# Details
+The entry point to the functional_roberta model is `bert_for_question_answering` in `models/demos/bert/tt/ttnn_bert.py` (`models/demos/bert/tt/ttnn_optimized_bert.py` for the optimized version). The model picks up certain configs and weights from the HuggingFace pretrained model. We have used the `deepset/roberta-large-squad2` version from HuggingFace as our reference.
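For orientation, here is a minimal sketch (not part of the patch) of how the `bert_for_question_answering` entry point described in the README above is driven. It mirrors the calls in `tests/ttnn/integration_tests/roberta/test_ttnn_opitmized_roberta.py` added later in this diff; the `device` argument is assumed to come from the usual ttnn test fixture, and the random inputs stand in for the tokenized demo inputs.

```python
import torch
import transformers
import ttnn
from models.demos.bert.tt import ttnn_optimized_bert
from ttnn.model_preprocessing import preprocess_model_parameters


def run_roberta_qa_reference(device, model_name="deepset/roberta-large-squad2", batch_size=8, sequence_size=384):
    config = transformers.RobertaConfig.from_pretrained(model_name)

    # Convert the HuggingFace weights once; the custom preprocessor lays them out for ttnn.
    parameters = preprocess_model_parameters(
        model_name=f"ttnn_{model_name}_optimized",
        initialize_model=lambda: transformers.RobertaForQuestionAnswering.from_pretrained(model_name).eval(),
        custom_preprocessor=ttnn_optimized_bert.custom_preprocessor,
        device=device,
    )

    # Dummy inputs with the demo's shapes; the demo builds these from the tokenizer instead.
    input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32)
    token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
    position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
    attention_mask = torch.ones(batch_size, sequence_size)

    ttnn_inputs = ttnn_optimized_bert.preprocess_inputs(
        input_ids, token_type_ids, position_ids, attention_mask, device=device
    )
    tt_output = ttnn_optimized_bert.bert_for_question_answering(
        config, *ttnn_inputs, parameters=parameters, name="roberta"
    )
    # Index 0 of the last dimension holds start logits, index 1 holds end logits.
    return ttnn.to_torch(tt_output)
```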
diff --git a/models/experimental/functional_roberta/demo/demo.py b/models/demos/roberta/demo/demo.py
similarity index 89%
rename from models/experimental/functional_roberta/demo/demo.py
rename to models/demos/roberta/demo/demo.py
index 0e83e11064b..d2af25f2cc8 100644
--- a/models/experimental/functional_roberta/demo/demo.py
+++ b/models/demos/roberta/demo/demo.py
@@ -14,8 +14,7 @@
     disable_persistent_kernel_cache,
     profiler,
 )
-from models.demos.bert.tt import ttnn_bert
-from models.demos.bert.tt import ttnn_optimized_bert
+from models.demos.bert.tt import ttnn_bert, ttnn_optimized_bert
 from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch
 
 from ttnn.model_preprocessing import (
@@ -42,6 +41,12 @@ def load_inputs(input_path, batch):
     return context, question
 
 
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+    return incremental_indices.long() + padding_idx
+
+
 def run_roberta_question_and_answering_inference(
     device,
     use_program_cache,
@@ -105,10 +110,14 @@ def run_roberta_question_and_answering_inference(
 
     profiler.start(f"preprocessing_input")
 
+    position_ids = create_position_ids_from_input_ids(
+        input_ids=roberta_input.input_ids, padding_idx=config.pad_token_id
+    )
     ttnn_roberta_inputs = bert.preprocess_inputs(
         roberta_input["input_ids"],
         roberta_input["token_type_ids"],
-        torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
+        position_ids,
+        roberta_input["attention_mask"],
         device=device,
     )
     profiler.end(f"preprocessing_input")
@@ -208,10 +217,14 @@ def run_roberta_question_and_answering_inference_squad_v2(
         if i < n_iterations:
             batch_data = batch[0]
             curr_batch_size = batch_data["input_ids"].shape[0]
+            position_ids = create_position_ids_from_input_ids(
+                input_ids=batch_data.input_ids, padding_idx=config.pad_token_id
+            )
             ttnn_roberta_inputs = bert.preprocess_inputs(
                 batch_data["input_ids"],
                 batch_data["token_type_ids"],
-                torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
+                position_ids,
+                batch_data["attention_mask"],
                 device=device,
             )
 
@@ -253,10 +266,13 @@ def run_roberta_question_and_answering_inference_squad_v2(
     logger.info(f"\tTT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}")
 
 
-@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
-@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
+@pytest.mark.parametrize(
+    "model_name, input_loc",
+    ((["deepset/roberta-large-squad2", "models/demos/roberta/demo/input_data.json"]),),
+)
+@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
 def test_demo(
-    input_path,
+    input_loc,
     model_name,
     bert,
     device,
@@ -272,12 +288,12 @@ def test_demo(
         batch_size=8,
         sequence_size=384,
         bert=bert,
-        input_path=input_path,
+        input_path=input_loc,
     )
 
 
 @pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
-@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
+@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
 @pytest.mark.parametrize(
     "n_iterations",
     ((3),),
diff --git a/models/experimental/functional_roberta/demo/input_data.json b/models/demos/roberta/demo/input_data.json
similarity index 100%
rename from models/experimental/functional_roberta/demo/input_data.json
rename to models/demos/roberta/demo/input_data.json
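The `create_position_ids_from_input_ids` helper added to the demo above reproduces HuggingFace's RoBERTa position-id scheme: non-pad tokens are numbered incrementally starting at `padding_idx + 1`, while pad positions stay pinned at `padding_idx`. A small worked example, using the same function body and hypothetical token ids (RoBERTa's `pad_token_id` is 1):

```python
import torch


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    # Same body as the helper added in models/demos/roberta/demo/demo.py above.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx


input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # illustrative token ids, last two are padding
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 5, 1, 1]]) -- real tokens count up from 2, pad positions stay at 1
```

This is why the demo now passes explicit `position_ids` into `preprocess_inputs` instead of zeros: RoBERTa's learned position embeddings are offset by the padding index, unlike BERT's.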
diff --git a/models/demos/roberta/tests/test_perf_device_roberta.py b/models/demos/roberta/tests/test_perf_device_roberta.py
new file mode 100644
index 00000000000..e4dbceae48e
--- /dev/null
+++ b/models/demos/roberta/tests/test_perf_device_roberta.py
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from models.utility_functions import skip_for_wormhole_b0, is_grayskull
+from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
+
+
+@pytest.mark.models_device_performance_bare_metal
+@pytest.mark.parametrize(
+    "batch_size, test, expected_perf",
+    [
+        [
+            8,
+            "sequence_size=384-batch_size=8-model_name=deepset/roberta-large-squad2",
+            66.7 if is_grayskull() else 129.8,
+        ],
+    ],
+)
+def test_perf_device_bare_metal(batch_size, test, expected_perf):
+    subdir = "ttnn_roberta_"
+    num_iterations = 3
+    margin = 0.03
+    command = f"pytest tests/ttnn/integration_tests/roberta/test_ttnn_opitmized_roberta.py::test_roberta_for_question_answering[{test}]"
+    cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
+
+    inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
+    expected_perf_cols = {inference_time_key: expected_perf}
+
+    post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
+    expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
+    prep_device_perf_report(
+        model_name=f"ttnn_roberta_{batch_size}",
+        batch_size=batch_size,
+        post_processed_results=post_processed_results,
+        expected_results=expected_results,
+        comments=test.replace("/", "_"),
+    )
diff --git a/tests/ttnn/integration_tests/roberta/test_performance.py b/models/demos/roberta/tests/test_performance.py
similarity index 91%
rename from tests/ttnn/integration_tests/roberta/test_performance.py
rename to models/demos/roberta/tests/test_performance.py
index 2579da7f246..fdc823869c0 100644
--- a/tests/ttnn/integration_tests/roberta/test_performance.py
+++ b/models/demos/roberta/tests/test_performance.py
@@ -19,8 +19,7 @@
 from ttnn.model_preprocessing import preprocess_model_parameters
 
 from models.utility_functions import (
-    is_wormhole_b0,
-    is_blackhole,
+    skip_for_wormhole_b0,
     enable_persistent_kernel_cache,
     disable_persistent_kernel_cache,
 )
@@ -29,12 +28,11 @@
 
 def get_expected_times(bert):
     return {
-        ttnn_bert: (13, 32),
-        ttnn_optimized_bert: (12, 0.092),
+        ttnn_bert: (13, 5.5),
+        ttnn_optimized_bert: (12, 0.12),
     }[bert]
 
 
-@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH")
 @pytest.mark.models_performance_bare_metal
 @pytest.mark.models_performance_virtual_machine
 @pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@@ -52,9 +50,9 @@ def test_performance(device, use_program_cache, model_name, batch_size, sequence
     torch_attention_mask = torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None
 
     if bert == ttnn_bert:
-        tt_model_name = f"ttnn_{model_name}"
+        tt_model_name = f"ttnn_roberta_{model_name}"
     elif bert == ttnn_optimized_bert:
-        tt_model_name = f"ttnn_{model_name}_optimized"
+        tt_model_name = f"ttnn_roberta_{model_name}_optimized"
     else:
         raise ValueError(f"Unknown functional_roberta: {bert}")
 
diff --git a/models/experimental/functional_roberta/README.md b/models/experimental/functional_roberta/README.md
deleted file mode 100644
index cef8feb818c..00000000000
--- a/models/experimental/functional_roberta/README.md
+++ /dev/null
@@ -1,21 +0,0 @@
-## functional_roberta Demo
-## How to Run
-
-Use `pytest --disable-warnings --input-path="models/experimental/functional_roberta/demo/input_data.json" models/experimental/functional_roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]` to run the demo.
-
-If you wish to run the demo for ttnn_optimized_functional_roberta, use `pytest --disable-warnings --input-path="models/experimental/functional_roberta/demo/input_data.json" models/experimental/functional_roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_optimized_bert-deepset/roberta-large-squad2]` to run the demo.
-
-If you wish to run the demo with a different input use `pytest --disable-warnings --input-path="" models/experimental/functional_roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`. This file is expected to have exactly 8 inputs.
-
-Our second demo is designed to run SQuADV2 dataset, run this with `pytest --disable-warnings models/experimental/functional_roberta/demo/demo.py::test_demo_squadv2[3-models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`.
-
-If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/experimental/functional_roberta/demo/demo.py::test_demo_squadv2[-models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`
-
-
-# Inputs
-Inputs by default are provided from `input_data.json`. If you wish you to change the inputs, provide a different path to test_demo.
-
-We do not recommend modifying `input_data.json` file.
-
-# Details
-The entry point to functional_roberta model is bert_for_question_answering in `models/experimental/bert/tt/ttnn_bert.py` (`models/experimental/bert/tt/ttnn_optimized_bert.py` for optimized version). The model picks up certain configs and weights from huggingface pretrained model. We have used `deepset/roberta-large-squad2` version from huggingface as our reference.
diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh
index 0527bc69e6b..7997e2f509f 100755
--- a/tests/scripts/run_performance.sh
+++ b/tests/scripts/run_performance.sh
@@ -27,6 +27,8 @@ run_perf_models_other() {
 
     env pytest -n auto models/demos/metal_BERT_large_11/tests -m $test_marker
 
+    env pytest -n auto models/demos/roberta/tests/test_performance.py -m $test_marker
+
     ## Merge all the generated reports
     env python models/perf/merge_perf_results.py
 }
@@ -82,6 +84,9 @@ run_device_perf_models() {
         env pytest models/demos/ttnn_falcon7b/tests -m $test_marker --timeout=360
 
         env pytest models/demos/bert/tests -m $test_marker
+
+        env pytest models/demos/roberta/tests/test_perf_device_roberta.py -m $test_marker
+
     fi
 
     if [ "$tt_arch" == "wormhole_b0" ]; then
@@ -94,6 +99,9 @@ run_device_perf_models() {
         env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/metal_BERT_large_11/tests -m $test_marker
 
         env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b_common/tests -m $test_marker
+
+        env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/roberta/tests/test_perf_device_roberta.py -m $test_marker
+
     fi
 
     ## Merge all the generated reports
diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh
index 96a9bd2b6f5..3f7b03fd214 100755
--- a/tests/scripts/single_card/run_single_card_demo_tests.sh
+++ b/tests/scripts/single_card/run_single_card_demo_tests.sh
@@ -19,6 +19,9 @@ run_common_func_tests() {
     # Resnet
    WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto --disable-warnings models/demos/wormhole/resnet50/demo/demo.py; fail+=$?
 
+    # RoBERTa
+    WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto --disable-warnings -q -s --input-method=json --input-path='models/demos/roberta/demo/input_data.json' models/demos/roberta/demo/demo.py --timeout 420; fail+=$?
+
     return $fail
 }
 
diff --git a/tests/ttnn/integration_tests/roberta/test_ttnn_opitmized_roberta.py b/tests/ttnn/integration_tests/roberta/test_ttnn_opitmized_roberta.py
new file mode 100644
index 00000000000..6eadd489102
--- /dev/null
+++ b/tests/ttnn/integration_tests/roberta/test_ttnn_opitmized_roberta.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+import torch
+import pytest
+import tt_lib
+import transformers
+
+from models.demos.bert.tt import ttnn_optimized_bert
+from ttnn.model_preprocessing import preprocess_model_parameters
+from tests.ttnn.utils_for_testing import assert_with_pcc
+
+
+@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("sequence_size", [384])
+def test_roberta(device, use_program_cache, reset_seeds, model_name, batch_size, sequence_size):
+    config = transformers.RobertaConfig.from_pretrained(model_name)
+    model = transformers.RobertaModel.from_pretrained(model_name)
+
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32)
+    torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
+    torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
+    torch_attention_mask = torch.ones(batch_size, sequence_size)
+
+    torch_output = model(
+        input_ids=input_ids,
+        attention_mask=torch_attention_mask,
+        token_type_ids=torch_token_type_ids,
+        position_ids=torch_position_ids,
+    )
+    torch_output = torch_output.last_hidden_state
+
+    tt_model_name = f"ttnn_{model_name}_optimized"
+
+    parameters = preprocess_model_parameters(
+        model_name=tt_model_name,
+        initialize_model=lambda: transformers.RobertaModel.from_pretrained(model_name, torchscript=False).eval(),
+        custom_preprocessor=ttnn_optimized_bert.custom_preprocessor,
+        device=device,
+    )
+
+    ttnn_roberta_inputs = ttnn_optimized_bert.preprocess_inputs(
+        input_ids,
+        torch_token_type_ids,
+        torch_position_ids,
+        torch_attention_mask,
+        device=device,
+    )
+
+    tt_output = ttnn_optimized_bert.bert(
+        config,
+        *ttnn_roberta_inputs,
+        parameters=parameters,
+    )
+
+    tt_output = ttnn.to_torch(tt_output)
+
+    assert_with_pcc(torch_output, tt_output, 0.94)
+
+
+@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("sequence_size", [384])
+def test_roberta_for_question_answering(device, use_program_cache, reset_seeds, model_name, batch_size, sequence_size):
+    config = transformers.RobertaConfig.from_pretrained(model_name)
+    model = transformers.RobertaForQuestionAnswering.from_pretrained(model_name)
+
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32)
+    torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
+    torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
+    torch_attention_mask = torch.ones(batch_size, sequence_size)
+
+    torch_output = model(
+        input_ids=input_ids,
+        attention_mask=torch_attention_mask,
+        token_type_ids=torch_token_type_ids,
+        position_ids=torch_position_ids,
+    )
+    torch_output_start_logits = torch_output.start_logits
+    torch_output_end_logits = torch_output.end_logits
+
+    tt_model_name = f"ttnn_{model_name}_optimized"
+
+    parameters = preprocess_model_parameters(
+        model_name=tt_model_name,
+        initialize_model=lambda: transformers.RobertaForQuestionAnswering.from_pretrained(
+            model_name, torchscript=False
+        ).eval(),
+        custom_preprocessor=ttnn_optimized_bert.custom_preprocessor,
+        device=device,
+    )
+
+    ttnn_roberta_inputs = ttnn_optimized_bert.preprocess_inputs(
+        input_ids,
+        torch_token_type_ids,
+        torch_position_ids,
+        torch_attention_mask,
+        device=device,
+    )
+
+    tt_output = ttnn_optimized_bert.bert_for_question_answering(
+        config,
+        *ttnn_roberta_inputs,
+        parameters=parameters,
+        name="roberta",
+    )
+    tt_output = ttnn.to_torch(tt_output)
+
+    tt_output_start_logits = tt_output[..., :, 0]
+    tt_output_end_logits = tt_output[..., :, 1]
+
+    assert_with_pcc(torch_output_start_logits, tt_output_start_logits, 0.88)
+    assert_with_pcc(torch_output_end_logits, tt_output_end_logits, 0.89)
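As a usage note for the question-answering test above: the last dimension of `tt_output` packs start logits at index 0 and end logits at index 1. The demo decodes answers through `squadv2_answer_decode_batch`; the greedy argmax sketch below only illustrates what those logits represent, and the `tokenizer` argument is assumed to be a HuggingFace tokenizer for `deepset/roberta-large-squad2`.

```python
import torch


def decode_answer_span(start_logits, end_logits, input_ids, tokenizer, batch_index=0):
    # Pick the most likely start and end token positions for one sequence in the batch.
    # Production decoders (e.g. the SQuADv2 evaluation path in the demo) also handle the
    # no-answer case and enforce start <= end; this sketch does not.
    start = int(torch.argmax(start_logits[batch_index]))
    end = int(torch.argmax(end_logits[batch_index]))
    answer_token_ids = input_ids[batch_index, start : end + 1]
    return tokenizer.decode(answer_token_ids, skip_special_tokens=True)
```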