diff --git a/models/demos/wormhole/bert_tiny/README.md b/models/demos/wormhole/bert_tiny/README.md new file mode 100644 index 000000000000..a793bcb4dba3 --- /dev/null +++ b/models/demos/wormhole/bert_tiny/README.md @@ -0,0 +1,27 @@ +## Bert-Tiny Demo + +## Introduction +BERT stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation models, BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context in all layers. As a result, the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications. + +# Platforms: + WH N300, N150 + +## How to Run + +Use `pytest --disable-warnings models/demos/wormhole/bert_tiny/demo/demo.py::test_demo[wormhole_b0-True-models/demos/wormhole/bert_tiny/demo/input_data.json-mrm8488/bert-tiny-finetuned-squadv2-128-device_params0]` to run the demo. + +If you wish to run the demo with a different input use `pytest --disable-warnings models/demos/wormhole/bert_tiny/demo/demo.py::test_demo[wormhole_b0-True--mrm8488/bert-tiny-finetuned-squadv2-128-device_params0]`. This file is expected to have exactly 16 inputs. + + +Our second demo is designed to run SQuADV2 dataset, run this with `pytest --disable-warnings models/demos/wormhole/bert_tiny/demo/demo.py::test_demo_squadv2[wormhole_b0-True-1-mrm8488/bert-tiny-finetuned-squadv2-384-device_params0]`. + +If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/wormhole/bert_tiny/demo/demo.py::test_demo_squadv2[wormhole_b0-True--mrm8488/bert-tiny-finetuned-squadv2-384-device_params0]` + + +# Inputs +Inputs by default are provided from `input_data.json`. If you wish you to change the inputs, provide a different path to test_demo. + +We do not recommend modifying `input_data.json` file. + +# Details +The entry point to bert model is bert_for_question_answering in `models/demos/wormhole/bert_tiny/tt/bert_tiny.py`. The model picks up certain configs and weights from huggingface pretrained model. We have used `mrm8488/bert-tiny-finetuned-squadv2` version from huggingface as our reference. diff --git a/models/demos/wormhole/bert_tiny/demo/demo.py b/models/demos/wormhole/bert_tiny/demo/demo.py new file mode 100644 index 000000000000..fe403df23384 --- /dev/null +++ b/models/demos/wormhole/bert_tiny/demo/demo.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import json +import pytest +import torch +from loguru import logger + +import ttnn +from models.utility_functions import ( + disable_compilation_reports, + disable_persistent_kernel_cache, + profiler, +) + +from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch +from ttnn.model_preprocessing import ( + preprocess_model_parameters, +) + +from ttnn.model_preprocessing import * +from transformers import BertForQuestionAnswering, BertTokenizer, pipeline +from models.demos.wormhole.bert_tiny.tt.bert_tiny import bert_for_question_answering, preprocess_inputs +import evaluate +from models.utility_functions import skip_for_grayskull, is_wormhole_b0 + + +def load_inputs(input_path, batch): + with open(input_path) as f: + input_data = json.load(f) + assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries." + + context = [] + question = [] + for i in range(batch): + context.append(input_data[i]["context"]) + question.append(input_data[i]["question"]) + + return context, question + + +def positional_ids(config, input_ids, past_key_values_length=0): + seq_length = input_ids.size(1) + position_ids = torch.arange(config.max_position_embeddings, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0)[:, past_key_values_length : seq_length + past_key_values_length] + position_ids = position_ids.expand_as(input_ids) + + return position_ids + + +def run_bert_question_and_answering_inference( + mesh_device, + use_program_cache, + model_name, + sequence_size, + model_location_generator, + input_path, +): + disable_persistent_kernel_cache() + model = str(model_location_generator(model_name, model_subdir="Bert")) + hugging_face_reference_model = BertForQuestionAnswering.from_pretrained(model, torchscript=False) + pytorch_model = hugging_face_reference_model.eval() + + tokenizer_name = str(model_location_generator(model_name, model_subdir="Bert")) + tokenizer = BertTokenizer.from_pretrained(tokenizer_name) + config = hugging_face_reference_model.config + nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer) + + profiler.start(f"preprocessing_parameter") + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = 16 if mesh_device_flag else 8 + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: pytorch_model, + device=mesh_device, + convert_to_ttnn=lambda *_: True, + ) + + profiler.end(f"preprocessing_parameter") + + context, question = load_inputs(input_path, batch_size) + + preprocess_params, _, postprocess_params = nlp._sanitize_parameters() + preprocess_params["max_seq_len"] = sequence_size + inputs = nlp._args_parser({"context": context, "question": question}) + preprocessed_inputs = [] + for i in range(batch_size): + model_input = next(nlp.preprocess(inputs[0][i], **preprocess_params)) + single_input = { + "example": model_input["example"], + "inputs": model_input, + } + preprocessed_inputs.append(single_input) + + bert_input = tokenizer.batch_encode_plus( + zip(question, context), + max_length=sequence_size, + padding="max_length", + truncation=True, + return_attention_mask=True, + return_token_type_ids=True, + return_tensors="pt", + ) + + position_ids = positional_ids(config, bert_input.input_ids) + profiler.start(f"preprocessing_input") + ttnn_bert_inputs = preprocess_inputs( + bert_input["input_ids"], + bert_input["token_type_ids"], + position_ids, + bert_input["attention_mask"], + mesh_device=mesh_device, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + profiler.end(f"preprocessing_input") + + profiler.start(f"inference_time") + ttnn_output = bert_for_question_answering( + config, + *ttnn_bert_inputs, + parameters=parameters, + device=mesh_device, + ) + profiler.end(f"inference_time") + + ttnn_output = ( + ttnn.to_torch(ttnn.from_device(ttnn_output), mesh_composer=output_mesh_composer) + .reshape(batch_size, 1, sequence_size, -1) + .to(torch.float32) + ) + + ttnn_start_logits = ttnn_output[..., :, 0].squeeze(1) + ttnn_end_logits = ttnn_output[..., :, 1].squeeze(1) + + model_answers = {} + profiler.start("post_processing_output_to_string") + for i in range(batch_size): + tt_res = { + "start": ttnn_start_logits[i], + "end": ttnn_end_logits[i], + "example": preprocessed_inputs[i]["example"], + **preprocessed_inputs[i]["inputs"], + } + + tt_answer = nlp.postprocess([tt_res], **postprocess_params) + + logger.info(f"answer: {tt_answer['answer']}\n") + model_answers[i] = tt_answer["answer"] + + profiler.end("post_processing_output_to_string") + + measurements = { + "preprocessing_parameter": profiler.get("preprocessing_parameter"), + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s") + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + + return measurements + + +def run_bert_question_and_answering_inference_squad_v2( + mesh_device, + use_program_cache, + model_name, + sequence_size, + model_location_generator, + n_iterations, +): + disable_persistent_kernel_cache() + + model = str(model_location_generator(model_name, model_subdir="Bert")) + hugging_face_reference_model = BertForQuestionAnswering.from_pretrained(model, torchscript=False) + pytorch_model = hugging_face_reference_model.eval() + + # set up tokenizer + tokenizer_name = str(model_location_generator(model_name, model_subdir="Bert")) + tokenizer = BertTokenizer.from_pretrained(tokenizer_name) + config = hugging_face_reference_model.config + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = 16 if mesh_device_flag else 8 + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: pytorch_model, + device=mesh_device, + convert_to_ttnn=lambda *_: True, + ) + + nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer) + + attention_mask = True + token_type_ids = True + inputs_squadv2 = squadv2_1K_samples_input(tokenizer, sequence_size, attention_mask, token_type_ids, batch_size) + squad_metric = evaluate.load("squad_v2") + + with torch.no_grad(): + pred_labels = [] + cpu_pred_labels = [] + true_labels = [] + i = 0 + for batch in inputs_squadv2: + if i < n_iterations: + batch_data = batch[0] + curr_batch_size = batch_data["input_ids"].shape[0] + position_ids = positional_ids(config, batch_data.input_ids) + ttnn_bert_inputs = preprocess_inputs( + batch_data["input_ids"], + batch_data["token_type_ids"], + position_ids, + batch_data["attention_mask"], + mesh_device=mesh_device, + inputs_mesh_mapper=inputs_mesh_mapper, + ) + tt_output = bert_for_question_answering( + config, + *ttnn_bert_inputs, + parameters=parameters, + device=mesh_device, + ) + tt_output = ( + ttnn.to_torch(ttnn.from_device(tt_output), mesh_composer=output_mesh_composer) + .reshape(batch_size, 1, sequence_size, -1) + .to(torch.float32) + ) + cpu_output = hugging_face_reference_model(**batch_data) + references = batch[1] + question = batch[2] + context = batch[3] + + cpu_predictions, tt_predictions = squadv2_answer_decode_batch( + hugging_face_reference_model, + tokenizer, + nlp, + references, + cpu_output, + tt_output, + curr_batch_size, + question, + context, + ) + pred_labels.extend(tt_predictions) + cpu_pred_labels.extend(cpu_predictions) + true_labels.extend(references) + + del tt_output + i += 1 + eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels) + cpu_eval_score = squad_metric.compute(predictions=cpu_pred_labels, references=true_labels) + logger.info(f"TT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}") + logger.info(f"CPU_Eval: exact: {cpu_eval_score['exact']} -- F1: {cpu_eval_score['f1']}") + + +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize("sequence_size", [128]) +@pytest.mark.parametrize("model_name", ["mrm8488/bert-tiny-finetuned-squadv2"]) +@pytest.mark.parametrize("input_loc", ["models/demos/wormhole/bert_tiny/demo/input_data.json"]) +def test_demo( + input_loc, + sequence_size, + model_name, + model_location_generator, + mesh_device, + use_program_cache, +): + disable_persistent_kernel_cache() + disable_compilation_reports() + + return run_bert_question_and_answering_inference( + mesh_device=mesh_device, + use_program_cache=use_program_cache, + model_name=model_name, + sequence_size=sequence_size, + model_location_generator=model_location_generator, + input_path=input_loc, + ) + + +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("model_name", ["mrm8488/bert-tiny-finetuned-squadv2"]) +@pytest.mark.parametrize( + "n_iterations", + ((1),), +) +def test_demo_squadv2( + model_name, + sequence_size, + n_iterations, + model_location_generator, + mesh_device, + use_program_cache, +): + disable_persistent_kernel_cache() + disable_compilation_reports() + + return run_bert_question_and_answering_inference_squad_v2( + mesh_device=mesh_device, + use_program_cache=use_program_cache, + model_name=model_name, + sequence_size=sequence_size, + model_location_generator=model_location_generator, + n_iterations=n_iterations, + ) diff --git a/models/demos/wormhole/bert_tiny/demo/input_data.json b/models/demos/wormhole/bert_tiny/demo/input_data.json new file mode 100644 index 000000000000..f182d147a451 --- /dev/null +++ b/models/demos/wormhole/bert_tiny/demo/input_data.json @@ -0,0 +1,66 @@ +[ + { + "context" : "Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. The prophet and founding hero of modern archaeology, Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art.", + "question" : "What discipline did Winkelmann create?" + }, + { + "context" : "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands.", + "question" : "Who ruled the duchy of Normandy" + }, + { + "context" : "In many countries, there is a Gender pay gap in favor of males in the labor market. Several factors other than discrimination may contribute to this gap. On average, women are more likely than men to consider factors other than pay when looking for work, and may be less willing to travel or relocate. Thomas Sowell, in his book Knowledge and Decisions, claims that this difference is due to women not taking jobs due to marriage or pregnancy, but income studies show that that does not explain the entire difference. A U.S. Census's report stated that in US once other factors are accounted for there is still a difference in earnings between women and men. The income gap in other countries ranges from 53% in Botswana to -40% in Bahrain.", + "question" : "Who does a gender pay gap tend to favor?" + }, + { + "context" : "Most of the Huguenot congregations (or individuals) in North America eventually affiliated with other Protestant denominations with more numerous members. The Huguenots adapted quickly and often married outside their immediate French communities, which led to their assimilation. Their descendants in many families continued to use French first names and surnames for their children well into the nineteenth century. Assimilated, the French made numerous contributions to United States economic life, especially as merchants and artisans in the late Colonial and early Federal periods. For example, E.I. du Pont, a former student of Lavoisier, established the Eleutherian gunpowder mills.", + "question" : "How were Huguenot settlers assimilated into North American society at large?" + }, + { + "context" : "In the laboratory, biostratigraphers analyze rock samples from outcrop and drill cores for the fossils found in them. These fossils help scientists to date the core and to understand the depositional environment in which the rock units formed. Geochronologists precisely date rocks within the stratigraphic section in order to provide better absolute bounds on the timing and rates of deposition. Magnetic stratigraphers look for signs of magnetic reversals in igneous rock units within the drill cores. Other scientists perform stable isotope studies on the rocks to gain information about past climate.", + "question" : "Who analyzes rock samples from drill cores in the lab?" + }, + { + "context" : "Neutrophils and macrophages are phagocytes that travel throughout the body in pursuit of invading pathogens. Neutrophils are normally found in the bloodstream and are the most abundant type of phagocyte, normally representing 50% to 60% of the total circulating leukocytes. During the acute phase of inflammation, particularly as a result of bacterial infection, neutrophils migrate toward the site of inflammation in a process called chemotaxis, and are usually the first cells to arrive at the scene of infection. Macrophages are versatile cells that reside within tissues and produce a wide array of chemicals including enzymes, complement proteins, and regulatory factors such as interleukin 1. Macrophages also act as scavengers, ridding the body of worn-out cells and other debris, and as antigen-presenting cells that activate the adaptive immune system.", + "question" : "What is the process in which neutrophils move towards the site of inflammation called?" + }, + { + "context" : "In Afghanistan, the mujahideen's victory against the Soviet Union in the 1980s did not lead to justice and prosperity, due to a vicious and destructive civil war between political and tribal warlords, making Afghanistan one of the poorest countries on earth. In 1992, the Democratic Republic of Afghanistan ruled by communist forces collapsed, and democratic Islamist elements of mujahdeen founded the Islamic State of Afghanistan. In 1996, a more conservative and anti-democratic Islamist movement known as the Taliban rose to power, defeated most of the warlords and took over roughly 80% of Afghanistan.", + "question" : "When did the Democratic Republic of Afghanistan collapse?" + }, + { + "context" : "The largest single sensory feature is the aboral organ (at the opposite end from the mouth). Its main component is a statocyst, a balance sensor consisting of a statolith, a solid particle supported on four bundles of cilia, called \"balancers\", that sense its orientation. The statocyst is protected by a transparent dome made of long, immobile cilia. A ctenophore does not automatically try to keep the statolith resting equally on all the balancers. Instead its response is determined by the animal's \"mood\", in other words the overall state of the nervous system. For example, if a ctenophore with trailing tentacles captures prey, it will often put some comb rows into reverse, spinning the mouth towards the prey.", + "question" : "What is the main component of the aboral organ?" + }, + { + "context": "Mark Rothko was a Latvian-born American abstract painter. He is best known for his color field paintings that depicted irregular and painterly rectangular regions of color, which he produced from 1949 to 1970. Although Rothko did not personally subscribe to any one school, he is associated with the American Abstract Expressionist movement of modern art. Originally emigrating to Portland, Oregon, from Russian Empire (Latvia) with his family, Rothko later moved to New York City where his youthful period of artistic production dealt primarily with urban scenery.", + "question": "what is Rothko best known for?" + }, + { + "context": "Malignant narcissism is a psychological syndrome that could include aspects of narcissistic personality disorder (NPD) alongside a mix of antisocial, paranoid and sadistic personality disorder traits. The importance of malignant narcissism and of projection as a defense mechanism has been confirmed in paranoia, as well as the patient's vulnerability to malignant narcissistic regression. A person with malignant narcissism exhibits paranoia in addition to the symptoms of a Narcissistic Personality Disorder. Because a malignant narcissist's personality cannot tolerate any criticism, being mocked typically causes paranoia.", + "question": "What symptoms a malignant narcissist might exhibit in addition to the symptoms of a NPD patient?" + }, + { + "context": "The 14 July Revolution, also known as the 1958 Iraqi military coup, was a coup d'état that took place on 14 July 1958 in Iraq which resulted in the toppling of King Faisal II and the overthrow of the Hashemite-led Kingdom of Iraq. The Iraqi Republic established in its wake ended the Hashemite Arab Federation between Iraq and Jordan that had been established just six months earlier. In July 1958, units of the Royal Iraqi Army were dispatched to Jordan in support of King Hussein. A group of Iraqi Free Officers, led by Brigadier Abd al-Karim Qasim and Colonel Abdul Salam Arif, took advantage of the opportunity and instead marched on Baghdad. On 14 July, revolutionary forces seized control of the capital and proclaimed a new republic, headed by a Revolutionary Council.", + "question": "When was the Hashemite Arab Federation formed?" + }, + { + "context": "The Tasmanian devil is a carnivorous marsupial of the family Dasyuridae. It was formerly present across mainland Australia, but became extinct there around 3,500 years ago. The size of a small dog, the Tasmanian devil became the largest carnivorous marsupial in the world following the extinction of the thylacine in 1936. It is related to quolls, and distantly related to the thylacine. It is characterised by its stocky and muscular build, black fur, pungent odour, extremely loud and disturbing screech, keen sense of smell, and ferocity when feeding. The Tasmanian devil's large head and neck allow it to generate among the strongest bites per unit body mass of any extant predatory land mammal. It hunts prey and scavenges on carrion.", + "question": "What allows Tasmanian devil to generate strong bites?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who was the Normans' main enemy in Italy, the Byzantine Empire and Armenia?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who entered Italy soon after the Byzantine Empire?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who did the Normans fight in Italy?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who did the Normans encourage to come to the south?" + } +] diff --git a/models/demos/wormhole/bert_tiny/tests/test_performance.py b/models/demos/wormhole/bert_tiny/tests/test_performance.py new file mode 100644 index 000000000000..97ecd7105594 --- /dev/null +++ b/models/demos/wormhole/bert_tiny/tests/test_performance.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import pytest +import ttnn +import time + +from loguru import logger +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import ( + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, +) +from models.perf.perf_utils import prep_perf_report +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report +from transformers import BertForQuestionAnswering +from models.demos.wormhole.bert_tiny.tt.bert_tiny import bert_for_question_answering +from models.utility_functions import is_wormhole_b0, skip_for_grayskull + + +def get_expected_times(bert_tiny): + return (38.5, 1.6) + + +@skip_for_grayskull() +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize("sequence_size", [128]) +@pytest.mark.parametrize("model_name", ["mrm8488/bert-tiny-finetuned-squadv2"]) +def test_perf_bert_tiny( + mesh_device, + sequence_size, + model_name, + model_location_generator, + reset_seeds, +): + disable_persistent_kernel_cache() + model_name = str(model_location_generator(model_name, model_subdir="Bert")) + hugging_face_reference_model = BertForQuestionAnswering.from_pretrained(model_name, torchscript=False) + config = hugging_face_reference_model.config + pytorch_model = hugging_face_reference_model + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = 16 if mesh_device_flag else 8 + + torch_bert_input = torch.randint(0, 100, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.zeros(1, sequence_size) + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: pytorch_model, + device=mesh_device, + convert_to_ttnn=lambda *_: True, + ) + + ttnn_bert_inputs = ttnn.from_torch( + torch_bert_input, dtype=ttnn.uint32, mesh_mapper=inputs_mesh_mapper, device=mesh_device + ) + ttnn_token_type_ids = ttnn.from_torch( + torch_token_type_ids, dtype=ttnn.uint32, mesh_mapper=inputs_mesh_mapper, device=mesh_device + ) + ttnn_position_ids = ttnn.from_torch( + torch_position_ids, dtype=ttnn.uint32, mesh_mapper=inputs_mesh_mapper, device=mesh_device + ) + ttnn_attention_mask = ttnn.from_torch( + torch_attention_mask, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + device=mesh_device, + ) + durations = [] + for i in range(2): + start = time.time() + ttnn_output = bert_for_question_answering( + config, + input_ids=ttnn_bert_inputs, + token_type_ids=ttnn_token_type_ids, + position_ids=ttnn_position_ids, + attention_mask=ttnn_attention_mask, + parameters=parameters, + device=mesh_device, + ) + output = ttnn.from_device(ttnn_output) + + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times("bert_tiny") + prep_perf_report( + model_name="bert_tiny", + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + assert ( + inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}" + logger.info("Exit Bert-Tiny perf test") + + +@skip_for_grayskull() +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, expected_perf", + [ + (16, 6537.95), + ], +) +def test_perf_device_bare_metal(batch_size, expected_perf): + subdir = "ttnn_bert_tiny" + num_iterations = 1 + margin = 0.03 + + command = f"pytest tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py::test_bert_for_question_answering" + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True) + prep_device_perf_report( + model_name=f"ttnn_bert_tiny{batch_size}", + batch_size=batch_size if mesh_device_flag else 8, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments="", + ) diff --git a/models/demos/wormhole/bert_tiny/tt/bert_tiny.py b/models/demos/wormhole/bert_tiny/tt/bert_tiny.py new file mode 100644 index 000000000000..98c437c98e8c --- /dev/null +++ b/models/demos/wormhole/bert_tiny/tt/bert_tiny.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +from models.experimental.functional_common.attention_mask_functions import get_extended_attention_mask + + +def bert_attention( + config, + hidden_states, + attention_mask, + device=None, + *, + parameters, +): + num_heads = config.num_attention_heads + batch_size, sequence_size, hidden_size = hidden_states.shape + head_size = hidden_size // num_heads + + query = ttnn.linear( + hidden_states, + parameters.self.query.weight, + bias=parameters.self.query.bias, + core_grid=device.core_grid, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + ) + query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT) + query = ttnn.from_device(query) + query = ttnn.reshape(query, (batch_size, sequence_size, num_heads, head_size)) + query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT) + query = ttnn.to_device(query, device) + query = ttnn.permute(query, (0, 2, 1, 3)) + + key = ttnn.linear( + hidden_states, + parameters.self.key.weight, + bias=parameters.self.key.bias, + core_grid=device.core_grid, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + ) + key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT) + key = ttnn.from_device(key) + key = ttnn.reshape(key, (batch_size, sequence_size, num_heads, head_size)) + key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT) + key = ttnn.to_device(key, device) + key = ttnn.permute(key, (0, 2, 3, 1)) + + value = ttnn.linear( + hidden_states, + parameters.self.value.weight, + bias=parameters.self.value.bias, + core_grid=device.core_grid, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + ) + value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT) + value = ttnn.from_device(value) + value = ttnn.reshape(value, (batch_size, sequence_size, num_heads, head_size)) + value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT) + value = ttnn.to_device(value, device) + value = ttnn.permute(value, (0, 2, 1, 3)) + + attention_scores = ttnn.matmul(query, key) + attention_scores = attention_scores * (1 / (head_size**0.5)) + if attention_mask is not None: + attention_scores = ttnn.to_layout(attention_scores, ttnn.TILE_LAYOUT) + attention_scores = ttnn.to_device(attention_scores, device) + attention_mask = ttnn.to_layout(attention_mask, ttnn.TILE_LAYOUT) + attention_scores = attention_scores + attention_mask + value = ttnn.to_device(value, device) + + attention_probs = ttnn.softmax(attention_scores, dim=-1) + + context_layer = attention_probs @ value + context_layer = ttnn.permute(context_layer, (0, 2, 1, 3)) + context_layer = ttnn.to_layout(context_layer, ttnn.ROW_MAJOR_LAYOUT) + context_layer = ttnn.from_device(context_layer) + context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size)) + context_layer = ttnn.to_device(context_layer, device) + context_layer = ttnn.to_layout(context_layer, ttnn.TILE_LAYOUT) + + self_output = context_layer + self_output = ttnn.linear( + self_output, + parameters.output.dense.weight, + bias=parameters.output.dense.bias, + core_grid=device.core_grid, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + ) + + attention_output = ttnn.layer_norm( + hidden_states + self_output, + weight=parameters.output.LayerNorm.weight, + bias=parameters.output.LayerNorm.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return attention_output + + +def bert_intermediate( + hidden_states, + device=None, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + activation="gelu", + core_grid=device.core_grid, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + ) + return output + + +def bert_output( + config, + hidden_states, + residual, + device=None, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + core_grid=device.core_grid, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + ) + + output = ttnn.layer_norm( + output + residual, + weight=parameters.LayerNorm.weight, + bias=parameters.LayerNorm.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return output + + +def bert_feedforward( + config, + hidden_states, + device=None, + *, + parameters, +): + intermediate = bert_intermediate(hidden_states, parameters=parameters.intermediate, device=device) + hidden_states = bert_output(config, intermediate, hidden_states, parameters=parameters.output, device=device) + return hidden_states + + +def bert_layer( + config, + hidden_states, + attention_mask, + device=None, + *, + parameters, +): + attention_output = bert_attention( + config, + hidden_states, + attention_mask, + parameters=parameters.attention, + device=device, + ) + + feedforward_output = bert_feedforward( + config, + attention_output, + parameters=parameters, + device=device, + ) + + return feedforward_output + + +def bert_encoder( + config, + hidden_states, + attention_mask, + device=None, + *, + parameters, +): + encoder_input = hidden_states + encoder_output = None + for encoder_parameters in parameters.layer: + encoder_output = bert_layer( + config, + encoder_input, + attention_mask, + parameters=encoder_parameters, + device=device, + ) + encoder_input = encoder_output + return encoder_output + + +def bert( + config, + input_ids, + token_type_ids, + position_ids, + attention_mask, + device=None, + *, + parameters, +): + word_embeddings = ttnn.embedding(input_ids, parameters.embeddings.word_embeddings.weight) + token_type_embeddings = ttnn.embedding(token_type_ids, parameters.embeddings.token_type_embeddings.weight) + position_embeddings = ttnn.embedding(position_ids, parameters.embeddings.position_embeddings.weight) + word_embeddings = ttnn.to_layout(word_embeddings, ttnn.TILE_LAYOUT) + token_type_embeddings = ttnn.to_layout(token_type_embeddings, ttnn.TILE_LAYOUT) + position_embeddings = ttnn.to_layout(position_embeddings, ttnn.TILE_LAYOUT) + + embeddings = word_embeddings + token_type_embeddings + position_embeddings + + hidden_states = ttnn.layer_norm( + embeddings, + weight=parameters.embeddings.LayerNorm.weight, + bias=parameters.embeddings.LayerNorm.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + hidden_states = bert_encoder( + config, + hidden_states, + attention_mask, + parameters=parameters.encoder, + device=device, + ) + + return hidden_states + + +def bert_for_question_answering( + config, + input_ids, + token_type_ids, + position_ids, + attention_mask, + device=None, + *, + parameters, + name="bert", +): + bert_output = bert( + config, + input_ids, + token_type_ids, + position_ids, + attention_mask, + device=device, + parameters=parameters[name], + ) + + qa_outputs = bert_output + qa_outputs = ttnn.linear( + qa_outputs, + parameters.qa_outputs.weight, + bias=parameters.qa_outputs.bias, + core_grid=device.core_grid, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + ) + return qa_outputs + + +def preprocess_inputs( + input_ids, + token_type_ids, + position_ids, + attention_mask, + mesh_device, + inputs_mesh_mapper, +): + batch_size, _ = input_ids.shape + + input_ids = ttnn.from_torch( + input_ids, + dtype=ttnn.uint32, + memory_config=ttnn.L1_MEMORY_CONFIG, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + token_type_ids = ttnn.from_torch( + token_type_ids, + dtype=ttnn.uint32, + memory_config=ttnn.L1_MEMORY_CONFIG, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + position_ids = ttnn.from_torch( + position_ids, + dtype=ttnn.uint32, + memory_config=ttnn.L1_MEMORY_CONFIG, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + + if attention_mask is not None: + attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32) + attention_mask = attention_mask.expand((batch_size, -1, -1, -1)) + attention_mask = torch.clamp(attention_mask, min=-100000) + attention_mask = ttnn.from_torch( + attention_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.L1_MEMORY_CONFIG, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + + return input_ids, token_type_ids, position_ids, attention_mask + + +def custom_preprocessor(torch_model, name): + return {} diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 35d1fde36724..1fe98b4b0d22 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -17,6 +17,8 @@ run_perf_models_other() { if [ "$tt_arch" == "wormhole_b0" ]; then env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/resnet50/tests/test_perf_e2e_resnet50.py -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/bert_tiny/tests/test_performance.py -m $test_marker fi env pytest -n auto tests/ttnn/integration_tests/bert/test_performance.py -m $test_marker @@ -124,6 +126,8 @@ run_device_perf_models() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/metal_BERT_large_11/tests -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b_common/tests -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/bert_tiny/tests -m $test_marker fi ## Merge all the generated reports diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index c13299895dd7..ff7bf88321f3 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -21,6 +21,8 @@ run_common_func_tests() { #Bert-Tiny WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/bert_tiny/demo/demo.py --timeout 600; fail+=$? + WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/wormhole/bert_tiny/demo/demo.py --timeout 600; fail+=$? + # Bert pytest -n auto --disable-warnings models/demos/metal_BERT_large_11/demo/demo.py -k batch_7; fail+=$? WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto --disable-warnings models/demos/metal_BERT_large_11/demo/demo.py -k batch_8; fail+=$? diff --git a/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py b/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py new file mode 100644 index 000000000000..d309befa0b28 --- /dev/null +++ b/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 +import torch +import pytest +import ttnn + +from transformers import BertForQuestionAnswering, BertConfig +from models.demos.wormhole.bert_tiny.tt.bert_tiny import ( + bert_for_question_answering, + bert_attention, + bert_intermediate, + bert_output, + bert_layer, +) +from ttnn.model_preprocessing import preprocess_model_parameters +from tests.ttnn.utils_for_testing import assert_with_pcc +from models.utility_functions import skip_for_grayskull, is_wormhole_b0 + + +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +def test_bert_attention_inference( + model_location_generator, + mesh_device, + reset_seeds, +): + model_name = str(model_location_generator("mrm8488/bert-tiny-finetuned-squadv2", model_subdir="Bert")) + hugging_face_reference_model = BertForQuestionAnswering.from_pretrained(model_name, torchscript=False) + + encoder_idx = 0 + pytorch_attention_model = hugging_face_reference_model.bert.encoder.layer[encoder_idx].attention + config = hugging_face_reference_model.config + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = 16 if mesh_device_flag else 8 + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: pytorch_attention_model, + device=mesh_device, + convert_to_ttnn=lambda *_: True, + ) + + input = (torch.rand(batch_size, 1, 128, hugging_face_reference_model.config.hidden_size) * 2) - 1 + torch_attention_mask = torch.zeros(1, 128) + pytorch_out = pytorch_attention_model(input.squeeze(1), attention_mask=torch_attention_mask)[0] + + tt_input = ttnn.from_torch( + input.squeeze(1), + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + ) + tt_attention_mask = ttnn.from_torch( + torch_attention_mask, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + ) + + tt_output = bert_attention( + config=config, + hidden_states=tt_input, + attention_mask=tt_attention_mask, + parameters=parameters, + device=mesh_device, + ) + tt_output = ttnn.from_device(tt_output) + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + + assert_with_pcc(pytorch_out, tt_output, 0.99) + + +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +def test_bert_intermediate_inference( + model_location_generator, + mesh_device, + reset_seeds, +): + model_name = str(model_location_generator("mrm8488/bert-tiny-finetuned-squadv2", model_subdir="Bert")) + hugging_face_reference_model = BertForQuestionAnswering.from_pretrained(model_name, torchscript=False) + + encoder_idx = 0 + pytorch_intermediate_model = hugging_face_reference_model.bert.encoder.layer[encoder_idx].intermediate + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = 16 if mesh_device_flag else 8 + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: pytorch_intermediate_model, + device=mesh_device, + convert_to_ttnn=lambda *_: True, + ) + + input = (torch.rand(batch_size, 1, 128, hugging_face_reference_model.config.hidden_size) * 2) - 1 + pytorch_out = pytorch_intermediate_model(input).squeeze(0) + + tt_input = ttnn.from_torch( + input.squeeze(1), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + + tt_output = bert_intermediate( + hidden_states=tt_input, + parameters=parameters, + device=mesh_device, + ) + tt_output = ttnn.from_device(tt_output) + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + assert_with_pcc(pytorch_out.squeeze(1), tt_output, 0.99) + + +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +def test_bert_output_inference( + model_location_generator, + mesh_device, + reset_seeds, +): + model_name = str(model_location_generator("mrm8488/bert-tiny-finetuned-squadv2", model_subdir="Bert")) + hugging_face_reference_model = BertForQuestionAnswering.from_pretrained(model_name, torchscript=False) + + encoder_idx = 0 + config = hugging_face_reference_model.config + pytorch_output_model = hugging_face_reference_model.bert.encoder.layer[encoder_idx].attention.output + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = 16 if mesh_device_flag else 8 + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: pytorch_output_model, + device=mesh_device, + convert_to_ttnn=lambda *_: True, + ) + + hidden_state = (torch.rand(batch_size, 1, 128, hugging_face_reference_model.config.hidden_size) * 2) - 1 + input = (torch.rand(batch_size, 1, 128, hugging_face_reference_model.config.hidden_size) * 2) - 1 + pytorch_out = pytorch_output_model(hidden_state, input).squeeze(0) + + ttnn_input = ttnn.from_torch( + input.squeeze(1), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + ttnn_hidden_state = ttnn.from_torch( + hidden_state.squeeze(1), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + + tt_output = bert_output( + config=config, + hidden_states=ttnn_hidden_state, + residual=ttnn_input, + parameters=parameters, + device=mesh_device, + ) + tt_output = ttnn.from_device(tt_output) + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + assert_with_pcc(pytorch_out.squeeze(1), tt_output, 0.99) + + +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +def test_bert_layer_inference( + model_location_generator, + mesh_device, + reset_seeds, +): + model_name = str(model_location_generator("mrm8488/bert-tiny-finetuned-squadv2", model_subdir="Bert")) + hugging_face_reference_model = BertForQuestionAnswering.from_pretrained(model_name, torchscript=False) + + encoder_idx = 0 + config = hugging_face_reference_model.config + pytorch_layer_model = hugging_face_reference_model.bert.encoder.layer[encoder_idx] + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = 16 if mesh_device_flag else 8 + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: pytorch_layer_model, + device=mesh_device, + convert_to_ttnn=lambda *_: True, + ) + + input = (torch.rand(batch_size, 1, 128, hugging_face_reference_model.config.hidden_size) * 2) - 1 + pytorch_out = pytorch_layer_model(input.squeeze(1))[0] + + ttnn_input = ttnn.from_torch( + input.squeeze(1), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + + tt_output = bert_layer( + config=config, + hidden_states=ttnn_input, + attention_mask=None, + parameters=parameters, + device=mesh_device, + ) + tt_output = ttnn.from_device(tt_output) + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + assert_with_pcc(pytorch_out.squeeze(1), tt_output, 0.99) + + +@skip_for_grayskull() +@pytest.mark.parametrize("model_name", ["mrm8488/bert-tiny-finetuned-squadv2"]) +@pytest.mark.parametrize("sequence_size", [128]) +@pytest.mark.parametrize("num_hidden_layers", [1]) +def test_bert_for_question_answering(mesh_device, model_name, sequence_size, num_hidden_layers, reset_seeds): + inputs_mesh_mapper = None + output_mesh_composer = None + parameters = None + + config = BertConfig.from_pretrained(model_name) + model = BertForQuestionAnswering.from_pretrained(model_name, config=config).eval() + + if num_hidden_layers is not None: + config.num_hidden_layers = num_hidden_layers + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = 16 if mesh_device_flag else 8 + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=mesh_device, + convert_to_ttnn=lambda *_: True, + ) + + torch_input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.zeros(1, sequence_size) + torch_output = model( + torch_input_ids, + token_type_ids=torch_token_type_ids, + position_ids=torch_position_ids, + attention_mask=torch_attention_mask, + ) + + ttnn_bert_inputs_ids = ttnn.from_torch( + torch_input_ids, dtype=ttnn.uint32, mesh_mapper=inputs_mesh_mapper, device=mesh_device + ) + ttnn_token_type_ids = ttnn.from_torch( + torch_token_type_ids, dtype=ttnn.uint32, mesh_mapper=inputs_mesh_mapper, device=mesh_device + ) + ttnn_position_ids = ttnn.from_torch( + torch_position_ids, dtype=ttnn.uint32, mesh_mapper=inputs_mesh_mapper, device=mesh_device + ) + ttnn_attention_mask = ttnn.from_torch( + torch_attention_mask, + dtype=ttnn.bfloat16, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + device=mesh_device, + ) + + output = bert_for_question_answering( + config, + ttnn_bert_inputs_ids, + ttnn_token_type_ids, + ttnn_position_ids, + ttnn_attention_mask, + parameters=parameters, + device=mesh_device, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + start_logits = output[..., 0] + end_logits = output[..., 1] + + assert_with_pcc(torch_output.start_logits, start_logits, 0.94) + assert_with_pcc(torch_output.end_logits, end_logits, 0.95)