diff --git a/models/demos/wormhole/squeezebert/README.md b/models/demos/wormhole/squeezebert/README.md new file mode 100644 index 000000000000..7c4274956d40 --- /dev/null +++ b/models/demos/wormhole/squeezebert/README.md @@ -0,0 +1,33 @@ +# SqueezeBERT demo + +Demo showcasing SqueezeBERT running on Grayskull - e150 and Wormhole - n150, n300 using ttnn. + +## Introduction +SqueezeBERT is a bidirectional transformer similar to the BERT model. The key difference between the BERT architecture and the SqueezeBERT architecture is that SqueezeBERT uses grouped convolutions instead of fully-connected layers for the Q, K, V and FFN layers. + + +## Details +The entry point to functional_squeezebert model is squeezebert_for_question_answering in `models/demos/squeezebert/tt/ttnn_functional_squeezebert.py`. The model picks up certain configs and weights from huggingface pretrained model. We have used `squeezebert/squeezebert-uncased` version from huggingface as our reference. + +### Sequence Size: 384 +Sequence size determines the maximum length of input sequences processed by the model, optimizing performance and compatibility. It's recommended to set the sequence_size to 384 + +### Batch size: 16 +Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. On each device, the batch size will be 8, as the operations run in parallel. It's recommended to set the batch_size to 16 + +## How to Run + +Use `pytest --disable-warnings models/demos/wormhole/squeezebert/demo/demo.py::test_demo[wormhole_b0-True-models.demos.wormhole.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased-models/demos/wormhole/squeezebert/demo/input_data.json-8-384-device_params0]` to run the demo. + +If you wish to run the demo with a different input use `pytest --disable-warnings models/demos/wormhole/squeezebert/demo/demo.py::test_demo[wormhole_b0-True-models.demos.wormhole.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased--8-384-device_params0] `. This file is expected to have exactly 16 inputs. + +Our second demo is designed to run SQuADV2 dataset, run this with `pytest --disable-warnings models/demos/wormhole/squeezebert/demo/demo.py::test_demo_squadv2[wormhole_b0-True-3-models.demos.wormhole.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased-8-384-device_params0]`. + +If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/wormhole/squeezebert/demo/demo.py::test_demo_squadv2[wormhole_b0-True--models.demos.wormhole.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased-8-384-device_params0]` + + +## Inputs +The demo receives inputs from respective `input_data.json` by default. To modify the inputs or specify a different path, adjust the input_path parameter in the command accordingly. It's recommended to avoid direct modifications to the input_data.json file. + + +#### Owner: [kkeerthana0573](https://github.com/kkeerthana0573) diff --git a/models/demos/wormhole/squeezebert/demo/demo.py b/models/demos/wormhole/squeezebert/demo/demo.py new file mode 100644 index 000000000000..dae2d219f0f6 --- /dev/null +++ b/models/demos/wormhole/squeezebert/demo/demo.py @@ -0,0 +1,343 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import json +import torch +import pytest +import evaluate +import transformers +from loguru import logger +from models.utility_functions import ( + profiler, + is_wormhole_b0, + run_for_wormhole_b0, + disable_compilation_reports, + disable_persistent_kernel_cache, +) +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.wormhole.squeezebert.tt import ttnn_functional_squeezebert +from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch + + +def load_inputs(input_path, batch): + with open(input_path) as f: + input_data = json.load(f) + assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries." + + context = [] + question = [] + for i in range(batch): + context.append(input_data[i]["context"]) + question.append(input_data[i]["question"]) + + return context, question + + +def positional_ids(config, input_ids, past_key_values_length=0): + seq_length = input_ids.size(1) + position_ids = torch.arange(config.max_position_embeddings, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0)[:, past_key_values_length : seq_length + past_key_values_length] + position_ids = position_ids.expand_as(input_ids) + + return position_ids + + +def run_squeezebert_question_and_answering_inference( + mesh_device, + use_program_cache, + model_name, + batch_size, + sequence_size, + squeezebert, + input_path, +): + disable_persistent_kernel_cache() + + hugging_face_reference_model = transformers.SqueezeBertForQuestionAnswering.from_pretrained( + model_name, torchscript=False + ) + + state_dict = hugging_face_reference_model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + profiler.start(f"preprocessing_parameter") + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: hugging_face_reference_model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + profiler.end(f"preprocessing_parameter") + + tokenizer = transformers.SqueezeBertTokenizer.from_pretrained(model_name) + config = hugging_face_reference_model.config + nlp = transformers.pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer) + + context, question = load_inputs(input_path, batch_size) + + preprocess_params, _, postprocess_params = nlp._sanitize_parameters() + preprocess_params["max_seq_len"] = sequence_size + inputs = nlp._args_parser({"context": context, "question": question}) + + preprocessed_inputs = [] + for i in range(batch_size): + model_input = next(nlp.preprocess(inputs[0][i], **preprocess_params)) + + single_input = { + "example": model_input["example"], + "inputs": model_input, + } + preprocessed_inputs.append(single_input) + + squeezebert_input = tokenizer.batch_encode_plus( + zip(question, context), + max_length=sequence_size, + padding="max_length", + truncation=True, + return_attention_mask=True, + return_token_type_ids=True, + return_tensors="pt", + ) + + profiler.start(f"preprocessing_input") + position_ids = positional_ids(config, squeezebert_input.input_ids) + ttnn_squeezebert_inputs = squeezebert.preprocess_inputs( + squeezebert_input["input_ids"], + squeezebert_input["token_type_ids"], + position_ids, + squeezebert_input["attention_mask"], + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + profiler.end(f"preprocessing_input") + + profiler.start(f"inference_time") + tt_output = squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=mesh_device, + reader_patterns_cache={}, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + profiler.end(f"inference_time") + + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer).to(torch.float32) + + tt_start_logits = tt_output[..., :, 0].squeeze(1) + tt_end_logits = tt_output[..., :, 1].squeeze(1) + + model_answers = {} + profiler.start("post_processing_output_to_string") + for i in range(batch_size): + tt_res = { + "start": tt_start_logits[i], + "end": tt_end_logits[i], + "example": preprocessed_inputs[i]["example"], + **preprocessed_inputs[i]["inputs"], + } + tt_answer = nlp.postprocess([tt_res], **postprocess_params) + + logger.info(f"answer: {tt_answer['answer']}\n") + model_answers[i] = tt_answer["answer"] + + profiler.end("post_processing_output_to_string") + + measurements = { + "preprocessing_parameter": profiler.get("preprocessing_parameter"), + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s") + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + + return measurements + + +def run_squeezebert_question_and_answering_inference_squad_v2( + mesh_device, + use_program_cache, + model_name, + batch_size, + sequence_size, + squeezebert, + n_iterations, +): + disable_persistent_kernel_cache() + hugging_face_reference_model = transformers.SqueezeBertForQuestionAnswering.from_pretrained( + model_name, torchscript=False + ) + + state_dict = hugging_face_reference_model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: hugging_face_reference_model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + + tokenizer = transformers.SqueezeBertTokenizer.from_pretrained(model_name) + config = hugging_face_reference_model.config + + nlp = transformers.pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer) + + attention_mask = True + token_type_ids = True + inputs_squadv2 = squadv2_1K_samples_input(tokenizer, sequence_size, attention_mask, token_type_ids, batch_size) + squad_metric = evaluate.load("squad_v2") + + with torch.no_grad(): + pred_labels = [] + cpu_pred_labels = [] + true_labels = [] + i = 0 + for batch in inputs_squadv2: + if i < n_iterations: + batch_data = batch[0] + curr_batch_size = batch_data["input_ids"].shape[0] + position_ids = positional_ids(config, batch_data.input_ids) + + ttnn_squeezebert_inputs = squeezebert.preprocess_inputs( + batch_data["input_ids"], + batch_data["token_type_ids"], + position_ids, + batch_data["attention_mask"], + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + + tt_output = squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=mesh_device, + reader_patterns_cache={}, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer).to(torch.float32) + + cpu_output = hugging_face_reference_model(**batch_data) + references = batch[1] + question = batch[2] + context = batch[3] + + cpu_predictions, tt_predictions = squadv2_answer_decode_batch( + hugging_face_reference_model, + tokenizer, + nlp, + references, + cpu_output, + tt_output, + curr_batch_size, + question, + context, + ) + pred_labels.extend(tt_predictions) + cpu_pred_labels.extend(cpu_predictions) + true_labels.extend(references) + + del tt_output + i += 1 + eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels) + cpu_eval_score = squad_metric.compute(predictions=cpu_pred_labels, references=true_labels) + logger.info(f"\tTT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}") + logger.info(f"\tCPU_Eval: exact: {cpu_eval_score['exact']} -- F1: {cpu_eval_score['f1']}") + + tolerance = 0.03 + assert ( + abs(eval_score["exact"] - cpu_eval_score["exact"]) <= tolerance + and abs(eval_score["f1"] - cpu_eval_score["f1"]) <= tolerance + ), ( + f"Expected Exact Match : {cpu_eval_score['exact']}, Actual Exact Match: {eval_score['exact']}; " + f"Expected F1 Score : {cpu_eval_score['f1']}, Actual F1 Score: {eval_score['f1']}" + ) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "batch_size, sequence_size", + [ + (8, 384), + ], +) +@pytest.mark.parametrize( + "model_name, input_loc", + ((["squeezebert/squeezebert-uncased", "models/demos/wormhole/squeezebert/demo/input_data.json"]),), +) +@pytest.mark.parametrize("squeezebert", [ttnn_functional_squeezebert]) +def test_demo( + input_loc, batch_size, sequence_size, model_name, squeezebert, mesh_device, use_program_cache, reset_seeds +): + disable_persistent_kernel_cache() + disable_compilation_reports() + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + print(f" batch_size: {batch_size}") + return run_squeezebert_question_and_answering_inference( + mesh_device=mesh_device, + use_program_cache=use_program_cache, + model_name=model_name, + batch_size=batch_size, + sequence_size=sequence_size, + squeezebert=squeezebert, + input_path=input_loc, + ) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "batch_size, sequence_size", + [ + (8, 384), + ], +) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("squeezebert", [ttnn_functional_squeezebert]) +@pytest.mark.parametrize( + "n_iterations", + ((3),), +) +def test_demo_squadv2( + batch_size, sequence_size, model_name, squeezebert, n_iterations, mesh_device, use_program_cache, reset_seeds +): + disable_persistent_kernel_cache() + disable_compilation_reports() + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + print(f" batch_size: {batch_size}") + + return run_squeezebert_question_and_answering_inference_squad_v2( + mesh_device=mesh_device, + use_program_cache=use_program_cache, + model_name=model_name, + batch_size=batch_size, + sequence_size=sequence_size, + squeezebert=squeezebert, + n_iterations=n_iterations, + ) diff --git a/models/demos/wormhole/squeezebert/demo/input_data.json b/models/demos/wormhole/squeezebert/demo/input_data.json new file mode 100644 index 000000000000..f182d147a451 --- /dev/null +++ b/models/demos/wormhole/squeezebert/demo/input_data.json @@ -0,0 +1,66 @@ +[ + { + "context" : "Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. The prophet and founding hero of modern archaeology, Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art.", + "question" : "What discipline did Winkelmann create?" + }, + { + "context" : "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands.", + "question" : "Who ruled the duchy of Normandy" + }, + { + "context" : "In many countries, there is a Gender pay gap in favor of males in the labor market. Several factors other than discrimination may contribute to this gap. On average, women are more likely than men to consider factors other than pay when looking for work, and may be less willing to travel or relocate. Thomas Sowell, in his book Knowledge and Decisions, claims that this difference is due to women not taking jobs due to marriage or pregnancy, but income studies show that that does not explain the entire difference. A U.S. Census's report stated that in US once other factors are accounted for there is still a difference in earnings between women and men. The income gap in other countries ranges from 53% in Botswana to -40% in Bahrain.", + "question" : "Who does a gender pay gap tend to favor?" + }, + { + "context" : "Most of the Huguenot congregations (or individuals) in North America eventually affiliated with other Protestant denominations with more numerous members. The Huguenots adapted quickly and often married outside their immediate French communities, which led to their assimilation. Their descendants in many families continued to use French first names and surnames for their children well into the nineteenth century. Assimilated, the French made numerous contributions to United States economic life, especially as merchants and artisans in the late Colonial and early Federal periods. For example, E.I. du Pont, a former student of Lavoisier, established the Eleutherian gunpowder mills.", + "question" : "How were Huguenot settlers assimilated into North American society at large?" + }, + { + "context" : "In the laboratory, biostratigraphers analyze rock samples from outcrop and drill cores for the fossils found in them. These fossils help scientists to date the core and to understand the depositional environment in which the rock units formed. Geochronologists precisely date rocks within the stratigraphic section in order to provide better absolute bounds on the timing and rates of deposition. Magnetic stratigraphers look for signs of magnetic reversals in igneous rock units within the drill cores. Other scientists perform stable isotope studies on the rocks to gain information about past climate.", + "question" : "Who analyzes rock samples from drill cores in the lab?" + }, + { + "context" : "Neutrophils and macrophages are phagocytes that travel throughout the body in pursuit of invading pathogens. Neutrophils are normally found in the bloodstream and are the most abundant type of phagocyte, normally representing 50% to 60% of the total circulating leukocytes. During the acute phase of inflammation, particularly as a result of bacterial infection, neutrophils migrate toward the site of inflammation in a process called chemotaxis, and are usually the first cells to arrive at the scene of infection. Macrophages are versatile cells that reside within tissues and produce a wide array of chemicals including enzymes, complement proteins, and regulatory factors such as interleukin 1. Macrophages also act as scavengers, ridding the body of worn-out cells and other debris, and as antigen-presenting cells that activate the adaptive immune system.", + "question" : "What is the process in which neutrophils move towards the site of inflammation called?" + }, + { + "context" : "In Afghanistan, the mujahideen's victory against the Soviet Union in the 1980s did not lead to justice and prosperity, due to a vicious and destructive civil war between political and tribal warlords, making Afghanistan one of the poorest countries on earth. In 1992, the Democratic Republic of Afghanistan ruled by communist forces collapsed, and democratic Islamist elements of mujahdeen founded the Islamic State of Afghanistan. In 1996, a more conservative and anti-democratic Islamist movement known as the Taliban rose to power, defeated most of the warlords and took over roughly 80% of Afghanistan.", + "question" : "When did the Democratic Republic of Afghanistan collapse?" + }, + { + "context" : "The largest single sensory feature is the aboral organ (at the opposite end from the mouth). Its main component is a statocyst, a balance sensor consisting of a statolith, a solid particle supported on four bundles of cilia, called \"balancers\", that sense its orientation. The statocyst is protected by a transparent dome made of long, immobile cilia. A ctenophore does not automatically try to keep the statolith resting equally on all the balancers. Instead its response is determined by the animal's \"mood\", in other words the overall state of the nervous system. For example, if a ctenophore with trailing tentacles captures prey, it will often put some comb rows into reverse, spinning the mouth towards the prey.", + "question" : "What is the main component of the aboral organ?" + }, + { + "context": "Mark Rothko was a Latvian-born American abstract painter. He is best known for his color field paintings that depicted irregular and painterly rectangular regions of color, which he produced from 1949 to 1970. Although Rothko did not personally subscribe to any one school, he is associated with the American Abstract Expressionist movement of modern art. Originally emigrating to Portland, Oregon, from Russian Empire (Latvia) with his family, Rothko later moved to New York City where his youthful period of artistic production dealt primarily with urban scenery.", + "question": "what is Rothko best known for?" + }, + { + "context": "Malignant narcissism is a psychological syndrome that could include aspects of narcissistic personality disorder (NPD) alongside a mix of antisocial, paranoid and sadistic personality disorder traits. The importance of malignant narcissism and of projection as a defense mechanism has been confirmed in paranoia, as well as the patient's vulnerability to malignant narcissistic regression. A person with malignant narcissism exhibits paranoia in addition to the symptoms of a Narcissistic Personality Disorder. Because a malignant narcissist's personality cannot tolerate any criticism, being mocked typically causes paranoia.", + "question": "What symptoms a malignant narcissist might exhibit in addition to the symptoms of a NPD patient?" + }, + { + "context": "The 14 July Revolution, also known as the 1958 Iraqi military coup, was a coup d'état that took place on 14 July 1958 in Iraq which resulted in the toppling of King Faisal II and the overthrow of the Hashemite-led Kingdom of Iraq. The Iraqi Republic established in its wake ended the Hashemite Arab Federation between Iraq and Jordan that had been established just six months earlier. In July 1958, units of the Royal Iraqi Army were dispatched to Jordan in support of King Hussein. A group of Iraqi Free Officers, led by Brigadier Abd al-Karim Qasim and Colonel Abdul Salam Arif, took advantage of the opportunity and instead marched on Baghdad. On 14 July, revolutionary forces seized control of the capital and proclaimed a new republic, headed by a Revolutionary Council.", + "question": "When was the Hashemite Arab Federation formed?" + }, + { + "context": "The Tasmanian devil is a carnivorous marsupial of the family Dasyuridae. It was formerly present across mainland Australia, but became extinct there around 3,500 years ago. The size of a small dog, the Tasmanian devil became the largest carnivorous marsupial in the world following the extinction of the thylacine in 1936. It is related to quolls, and distantly related to the thylacine. It is characterised by its stocky and muscular build, black fur, pungent odour, extremely loud and disturbing screech, keen sense of smell, and ferocity when feeding. The Tasmanian devil's large head and neck allow it to generate among the strongest bites per unit body mass of any extant predatory land mammal. It hunts prey and scavenges on carrion.", + "question": "What allows Tasmanian devil to generate strong bites?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who was the Normans' main enemy in Italy, the Byzantine Empire and Armenia?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who entered Italy soon after the Byzantine Empire?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who did the Normans fight in Italy?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who did the Normans encourage to come to the south?" + } +] diff --git a/models/demos/wormhole/squeezebert/tests/test_perf_device_squeezebert.py b/models/demos/wormhole/squeezebert/tests/test_perf_device_squeezebert.py new file mode 100644 index 000000000000..bbd257830f24 --- /dev/null +++ b/models/demos/wormhole/squeezebert/tests/test_perf_device_squeezebert.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import ttnn +from models.utility_functions import run_for_wormhole_b0, is_wormhole_b0 +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report + + +@run_for_wormhole_b0() +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, test", + [ + [ + 8, + "silicon_arch_name=wormhole_b0-silicon_arch_wormhole_b0=True-sequence_size=384-batch_size=8-model_name=squeezebert/squeezebert-uncased-device_params={'l1_small_size': 16384}", + ], + ], +) +def test_perf_device_bare_metal(batch_size, test): + subdir = "ttnn_squeezebert" + num_iterations = 1 + margin = 0.03 + expected_perf = 290.35 + + command = f"pytest tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert_wh.py::test_squeezebert_for_question_answering" + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=False) + prep_device_perf_report( + model_name=f"ttnn_squeezebert_{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments=test.replace("/", "_"), + ) diff --git a/models/demos/wormhole/squeezebert/tests/test_performance.py b/models/demos/wormhole/squeezebert/tests/test_performance.py new file mode 100644 index 000000000000..485ddc2a9c7c --- /dev/null +++ b/models/demos/wormhole/squeezebert/tests/test_performance.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + + +import time +import ttnn +import torch +import pytest +import transformers +from loguru import logger +from models.perf.perf_utils import prep_perf_report +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.wormhole.squeezebert.tt import ttnn_functional_squeezebert +from models.experimental.functional_common.attention_mask_functions import get_extended_attention_mask +from models.utility_functions import ( + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + is_wormhole_b0, + run_for_wormhole_b0, +) + + +def synchronize_devices(device): + devices = device.get_devices() + for device in devices: + ttnn.synchronize_device(device) + + +def get_expected_times(squeezebert): + return {ttnn_functional_squeezebert: (29.29, 15.5)}[squeezebert] + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("squeezebert", [ttnn_functional_squeezebert]) +def test_performance(mesh_device, use_program_cache, model_name, batch_size, sequence_size, squeezebert): + disable_persistent_kernel_cache() + num_iterations = 2 + rf_model = transformers.SqueezeBertForQuestionAnswering.from_pretrained(model_name) + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + state_dict = rf_model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: rf_model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + + input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.ones(batch_size, sequence_size) + + ttnn_squeezebert_inputs_on_cpu = ttnn_functional_squeezebert.preprocess_inputs( + input_ids, + torch_token_type_ids, + position_ids, + torch_attention_mask, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + + start = time.time() + ttnn_squeezebert_inputs = [ + ( + ttnn.to_device(tensor, device=mesh_device, memory_config=ttnn.L1_MEMORY_CONFIG) + if tensor is not None + else tensor + ) + for tensor in ttnn_squeezebert_inputs_on_cpu + ] + tt_output = squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=mesh_device, + reader_patterns_cache={}, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + + tt_output = ttnn.from_device(tt_output, blocking=False) + synchronize_devices(mesh_device) + end = time.time() + inference_and_compile_time = end - start + enable_persistent_kernel_cache() + + start = time.time() + for _ in range(num_iterations): + ttnn_squeezebert_inputs = [ + ( + ttnn.to_device(tensor, device=mesh_device, memory_config=ttnn.L1_MEMORY_CONFIG) + if tensor is not None + else tensor + ) + for tensor in ttnn_squeezebert_inputs_on_cpu + ] + tt_output = squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=mesh_device, + reader_patterns_cache={}, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + tt_output = ttnn.from_device(tt_output, blocking=False) + + synchronize_devices(mesh_device) + end = time.time() + average_inference_time = (end - start) / num_iterations + + expected_compile_time, expected_inference_time = get_expected_times(squeezebert) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=average_inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}") + logger.info(f"Average Inference time: {average_inference_time}") + logger.info(f"Samples per second: {1 / average_inference_time * batch_size}") + + assert ( + average_inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {average_inference_time}" diff --git a/models/demos/wormhole/squeezebert/tt/ttnn_functional_squeezebert.py b/models/demos/wormhole/squeezebert/tt/ttnn_functional_squeezebert.py new file mode 100644 index 000000000000..336161c7bd86 --- /dev/null +++ b/models/demos/wormhole/squeezebert/tt/ttnn_functional_squeezebert.py @@ -0,0 +1,557 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +from torch import nn +from tests.ttnn.ttnn_utility_fuction import get_shard_grid_from_num_cores +from models.experimental.functional_common.attention_mask_functions import get_extended_attention_mask + + +def transpose_for_scores(config, x, device, permute_tensor: bool): + new_x_shape = (x.shape[0], config.num_attention_heads, config.attention_head_size, x.shape[-1]) + x = ttnn.from_device(x) + x = ttnn.reshape(x, new_x_shape) + x = ttnn.to_device(x, device) + + if permute_tensor: + x = ttnn.permute(x, (0, 1, 3, 2)) + + return x + + +def transpose_output(config, x, device): + all_head_size = config.num_attention_heads * config.attention_head_size + if len(x.shape) == 4: + x = ttnn.permute(x, (0, 1, 3, 2)) + + new_x_shape = (x.shape[0], all_head_size, x.shape[3]) + x = ttnn.reshape(x, new_x_shape) + + return x + + +def permute_reshape(hidden_states, shape=(0, 2, 1), reshape=True): + bs, *_ = hidden_states.shape + hidden_states = ttnn.permute(hidden_states, (0, 2, 1)) + if reshape: + hidden_states = ttnn.reshape(hidden_states, (bs, hidden_states.shape[-2], hidden_states.shape[-1])) + + return hidden_states + + +def ttnn_conv1d( + device, + tt_input_tensor, + weights, + conv_params, + bias, + *, + output_dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat8_b, + math_fidelity=ttnn.MathFidelity.LoFi, + deallocate_activation=False, + act_block_h=None, + height_sharding=True, + use_shallow_conv_variant=False, + fp32_accum=False, + packer_l1_acc=False, + debug=False, + groups=4, + math_approx=True, + activation="", + reallocate_halo=False, + reshard=False, +): + weights = ttnn.from_torch(weights, dtype=ttnn.float32) + bias = ttnn.from_torch(bias.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype=ttnn.float32) + + conv_config = ttnn.Conv1dConfig( + dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat8_b, + math_approx_mode_enabled=math_approx, + fp32_dest_acc_enabled=fp32_accum, + packer_l1_accum_enabled=packer_l1_acc, + activation=activation, + input_channels_alignment=(16 if use_shallow_conv_variant else 32), + deallocate_activation=deallocate_activation, + reallocate_halo_output=reallocate_halo, + act_block_h_override=32, + reshard_if_not_optimal=reshard, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), + core_grid=get_shard_grid_from_num_cores(56, device), + math_fidelity=math_fidelity, + ) + + [tt_output_tensor_on_device, out_length, weights_device, bias_device] = ttnn.Conv1d( + input_tensor=tt_input_tensor, + weight_tensor=weights, + in_channels=tt_input_tensor.shape[-1], + out_channels=weights.shape[0], + device=device, + bias_tensor=bias, + kernel_size=1, + stride=1, + padding=0, + batch_size=tt_input_tensor.shape[0], + input_length=tt_input_tensor.shape[1], + conv_config=conv_config, + conv_op_cache={}, + debug=debug, + groups=groups, + ) + + tt_output_tensor_on_device = ttnn.squeeze(tt_output_tensor_on_device, 0) + tt_output_tensor_on_device = ttnn.reshape( + tt_output_tensor_on_device, (tt_input_tensor.shape[0], out_length, tt_output_tensor_on_device.shape[-1]) + ) + + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) + + return tt_output_tensor + + +def squeezebert_conv_layernorm( + config, + hidden_states, + input_tensor, + *, + state_dict, + base_addr, + parameters, + device, + cin, + cout, + groups, + mesh_mapper=None, + mesh_composer=None, +): + torch_hidden_states = ttnn.to_torch(hidden_states, mesh_composer=mesh_composer).to(torch.float32) + + self_output_conv1d_ = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups) + self_output_conv1d_.weight = nn.Parameter(state_dict[f"{base_addr}conv1d.weight"]) + self_output_conv1d_.bias = nn.Parameter(state_dict[f"{base_addr}conv1d.bias"]) + + torch_self_output = self_output_conv1d_(torch_hidden_states) + self_output = ttnn.from_torch( + torch_self_output, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=mesh_mapper, device=device + ) + + self_output_layernorm = ttnn.add(self_output, input_tensor) + self_output_layernorm = permute_reshape(self_output_layernorm) + + attention_output = ttnn.layer_norm( + self_output_layernorm, + weight=parameters.layernorm.weight, + bias=parameters.layernorm.bias, + epsilon=config.layer_norm_eps, + ) + ttnn.deallocate(self_output_layernorm) + attention_output = permute_reshape(attention_output) + + return attention_output + + +def squeezebert_attention( + config, + hidden_states, + attention_mask, + *, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, + num_cores_x=12, + mesh_mapper=None, + mesh_composer=None, +): + num_heads = config.num_attention_heads + batch_size, hidden_size, _ = hidden_states.shape + head_size = hidden_size // num_heads + config.attention_head_size = head_size + + hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT) + hidden_states = permute_reshape(hidden_states) + hidden_states = ttnn.from_device(hidden_states) + mixed_query_layer = ttnn_conv1d( + device, + hidden_states, + nn.Parameter(state_dict[f"{base_addr}query.weight"]), + conv_params=[1, 0], + bias=nn.Parameter(state_dict[f"{base_addr}query.bias"]), + ) + mixed_query_layer = ttnn.to_device(mixed_query_layer, device) + mixed_query_layer = ttnn.permute(mixed_query_layer, (0, 2, 1)) + + mixed_key_layer = ttnn_conv1d( + device, + hidden_states, + nn.Parameter(state_dict[f"{base_addr}key.weight"]), + conv_params=[1, 0], + bias=nn.Parameter(state_dict[f"{base_addr}key.bias"]), + ) + mixed_key_layer = ttnn.to_device(mixed_key_layer, device) + mixed_key_layer = ttnn.permute(mixed_key_layer, (0, 2, 1)) + + mixed_value_layer = ttnn_conv1d( + device, + hidden_states, + nn.Parameter(state_dict[f"{base_addr}value.weight"]), + conv_params=[1, 0], + bias=nn.Parameter(state_dict[f"{base_addr}value.bias"]), + ) + mixed_value_layer = ttnn.to_device(mixed_value_layer, device) + mixed_value_layer = ttnn.permute(mixed_value_layer, (0, 2, 1)) + + query = transpose_for_scores(config, mixed_query_layer, device, True) + key = transpose_for_scores(config, mixed_key_layer, device, False) + value = transpose_for_scores(config, mixed_value_layer, device, True) + + ttnn.deallocate(mixed_query_layer) + ttnn.deallocate(mixed_key_layer) + ttnn.deallocate(mixed_value_layer) + + attention_scores = ttnn.matmul( + query, + key, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + ttnn.deallocate(query) + ttnn.deallocate(key) + + attention_probs = ttnn.transformer.attention_softmax_( + attention_scores, attention_mask=attention_mask, head_size=head_size + ) + + context_layer = ttnn.matmul( + attention_probs, + value, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + context_layer = transpose_output(config, context_layer, device) + + return context_layer + + +def squeezebert_intermediate( + config, + hidden_states, + *, + state_dict, + base_addr, + parameters, + device, + num_cores_x=12, + mesh_mapper=None, + mesh_composer=None, +): + torch_hidden_states = ttnn.to_torch(hidden_states, mesh_composer=mesh_composer).to(torch.float32) + + torch_conv_ = nn.Conv1d( + in_channels=config.hidden_size, + out_channels=config.intermediate_size, + kernel_size=1, + groups=config.intermediate_groups, + ) + torch_conv_.weight = nn.Parameter(state_dict[f"{base_addr}conv1d.weight"]) + torch_conv_.bias = nn.Parameter(state_dict[f"{base_addr}conv1d.bias"]) + + torch_conv_output = torch_conv_(torch_hidden_states) + ttnn_conv_output = ttnn.from_torch( + torch_conv_output, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=mesh_mapper + ) + + output = ttnn.gelu(ttnn_conv_output) + return output + + +def squeezebert_layer( + config, + hidden_states, + attention_mask, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, + mesh_mapper=None, + mesh_composer=None, +): + multi_head_attention_output = squeezebert_attention( + config, + hidden_states=hidden_states, + attention_mask=attention_mask, + state_dict=state_dict, + base_addr=f"{base_addr}attention.", + parameters=parameters.attention, + device=device, + reader_patterns_cache=reader_patterns_cache, + mesh_mapper=mesh_mapper, + mesh_composer=mesh_composer, + ) + + attention_output = squeezebert_conv_layernorm( + config, + hidden_states=multi_head_attention_output, + input_tensor=hidden_states, + state_dict=state_dict, + base_addr=f"{base_addr}post_attention.", + parameters=parameters.post_attention, + device=device, + cin=config.hidden_size, + cout=config.hidden_size, + groups=config.post_attention_groups, + mesh_mapper=mesh_mapper, + mesh_composer=mesh_composer, + ) + ttnn.deallocate(hidden_states) + ttnn.deallocate(multi_head_attention_output) + + intermediate = squeezebert_intermediate( + config, + attention_output, + state_dict=state_dict, + base_addr=f"{base_addr}intermediate.", + parameters=parameters.intermediate, + device=device, + mesh_mapper=mesh_mapper, + mesh_composer=mesh_composer, + ) + + output = squeezebert_conv_layernorm( + config, + hidden_states=intermediate, + input_tensor=attention_output, + state_dict=state_dict, + base_addr=f"{base_addr}output.", + parameters=parameters.output, + device=device, + cin=config.intermediate_size, + cout=config.hidden_size, + groups=config.output_groups, + mesh_mapper=mesh_mapper, + mesh_composer=mesh_composer, + ) + + return output + + +def squeezebert_encoder( + config, + hidden_states, + attention_mask, + *, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, + mesh_mapper=None, + mesh_composer=None, +): + hidden_states = permute_reshape(hidden_states) + # encoder_output = None + + for layer_idx, encoder_parameters in enumerate(parameters.layers): + encoder_output = squeezebert_layer( + config, + hidden_states, + attention_mask, + state_dict, + base_addr=f"{base_addr}layers.{layer_idx}.", + parameters=encoder_parameters, + device=device, + reader_patterns_cache=reader_patterns_cache, + mesh_mapper=mesh_mapper, + mesh_composer=mesh_composer, + ) + # encoder_output = ttnn.reallocate(encoder_output) + hidden_states = encoder_output + + hidden_states = permute_reshape(hidden_states) + + return hidden_states + + +def squeezebert( + config, + input_ids, + token_type_ids, + position_ids, + attention_mask, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, + mesh_mapper=None, + mesh_composer=None, +): + word_embeddings = ttnn.embedding( + input_ids, + parameters.embeddings.word_embeddings.weight, + layout=ttnn.TILE_LAYOUT, + padding_idx=config.pad_token_id, + ) + ttnn.deallocate(input_ids) + + token_type_embeddings = ttnn.embedding( + token_type_ids, + parameters.embeddings.token_type_embeddings.weight, + layout=ttnn.TILE_LAYOUT, + ) + ttnn.deallocate(token_type_ids) + + word_plus_token_type_embeddings = word_embeddings + token_type_embeddings + ttnn.deallocate(word_embeddings) + ttnn.deallocate(token_type_embeddings) + + position_embeddings = ttnn.embedding( + position_ids, + parameters.embeddings.position_embeddings.weight, + layout=ttnn.TILE_LAYOUT, + ) + ttnn.deallocate(position_ids) + + embeddings = word_plus_token_type_embeddings + position_embeddings + ttnn.deallocate(word_plus_token_type_embeddings) + ttnn.deallocate(position_embeddings) + + encoder_input = ttnn.layer_norm( + embeddings, + weight=parameters.embeddings.LayerNorm.weight, + bias=parameters.embeddings.LayerNorm.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + ttnn.deallocate(embeddings) + + encoder_output = squeezebert_encoder( + config=config, + hidden_states=encoder_input, + attention_mask=attention_mask, + state_dict=state_dict, + base_addr=f"{base_addr}encoder.", + parameters=parameters.encoder, + device=device, + reader_patterns_cache=reader_patterns_cache, + mesh_mapper=mesh_mapper, + mesh_composer=mesh_composer, + ) + ttnn.deallocate(encoder_input) + + return encoder_output + + +def squeezebert_for_question_answering( + config, + input_ids, + token_type_ids, + position_ids, + attention_mask, + *, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, + name="transformer", + mesh_mapper=None, + mesh_composer=None, +): + # print(f"parmes: {parameters.keys()}") + squeezebert_output = squeezebert( + config, + input_ids, + token_type_ids, + position_ids, + attention_mask, + state_dict, + base_addr, + parameters=parameters.transformer, + device=device, + reader_patterns_cache=reader_patterns_cache, + mesh_mapper=mesh_mapper, + mesh_composer=mesh_composer, + ) + + qa_outputs = ttnn.linear( + squeezebert_output, + parameters.qa_outputs.weight, + bias=parameters.qa_outputs.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return qa_outputs + + +def preprocess_inputs( + input_ids, + token_type_ids, + position_ids, + attention_mask, + device, + mesh_mapper=None, +): + import torch + + batch_size, _ = input_ids.shape + + input_ids = ttnn.from_torch( + input_ids, + dtype=ttnn.uint32, + mesh_mapper=mesh_mapper, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + token_type_ids = ttnn.from_torch( + token_type_ids, + dtype=ttnn.uint32, + mesh_mapper=mesh_mapper, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + position_ids = ttnn.from_torch( + position_ids, dtype=ttnn.uint32, mesh_mapper=mesh_mapper, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + if attention_mask is not None: + attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32) + attention_mask = attention_mask.expand((batch_size, -1, -1, -1)) + attention_mask = torch.clamp(attention_mask, min=-100000) + attention_mask = ttnn.from_torch( + attention_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=mesh_mapper, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return input_ids, token_type_ids, position_ids, attention_mask + + +def preprocess_conv_parameter(parameter, *, dtype): + parameter = ttnn.from_torch(parameter, dtype=dtype) + return parameter + + +def custom_preprocessor(model, name): + parameters = {} + if isinstance(model, nn.Conv1d): + weight = model.weight + bias = model.bias + + while bias.dim() < 4: + bias = bias.unsqueeze(0).unsqueeze(0).unsqueeze(0) + parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.float32) + parameters["bias"] = preprocess_conv_parameter(bias, dtype=ttnn.float32) + + return parameters diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 7956d1c7b034..5addc898dd3c 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -21,6 +21,9 @@ run_perf_models_other() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/bert_tiny/tests/test_performance.py -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/yolov4/tests/test_perf_yolo.py -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/squeezebert/tests -m $test_marker + fi env pytest -n auto tests/ttnn/integration_tests/bert/test_performance.py -m $test_marker @@ -121,6 +124,8 @@ run_device_perf_models() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/bert_tiny/tests -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/yolov4/tests/ -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/squeezebert/tests/test_perf_device_squeezebert.py -m $test_marker fi ## Merge all the generated reports diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index 5f5642483f63..5b5b33c068ff 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -39,6 +39,9 @@ run_common_func_tests() { # Mnist pytest --disable-warnings models/demos/mnist/demo/demo.py --timeout 600; fail+=$? + # Squeezebert + pytest --disable-warnings models/demos/wormhole/squeezebert/demo/demo.py --timeout 600; fail+=$? + return $fail } diff --git a/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert_wh.py b/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert_wh.py new file mode 100644 index 000000000000..52da8ad0999d --- /dev/null +++ b/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert_wh.py @@ -0,0 +1,463 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +import pytest +import transformers +from models.utility_functions import torch_random, is_wormhole_b0, run_for_wormhole_b0 +from tests.ttnn.utils_for_testing import assert_with_pcc +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.wormhole.squeezebert.tt import ttnn_functional_squeezebert + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_attention(mesh_device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.models.squeezebert.modeling_squeezebert.SqueezeBertSelfAttention( + config, cin=config.hidden_size, q_groups=config.q_groups, k_groups=config.k_groups, v_groups=config.v_groups + ) + state_dict = model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + + model = model.eval().to(torch_dtype) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_hidden_states = torch_hidden_states.permute(0, 2, 1) + + torch_attention_mask = torch.ones(batch_size, sequence_size, dtype=torch_dtype) + torch_attention_mask = torch_attention_mask[:, None, None, :] + + torch_output = model(torch_hidden_states, attention_mask=torch_attention_mask, output_attentions=False) + + ttnn_attention_mask = ttnn.from_torch( + torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=mesh_device, mesh_mapper=inputs_mesh_mapper + ) + + tt_model_name = f"ttnn_{model_name}_optimized" + + hidden_states = ttnn.from_torch( + torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=mesh_device, mesh_mapper=inputs_mesh_mapper + ) + + output = ttnn_functional_squeezebert.squeezebert_attention( + config, + hidden_states, + attention_mask=ttnn_attention_mask, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=mesh_device, + reader_patterns_cache={}, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output["context_layer"], output, 0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_intermediate(mesh_device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.models.squeezebert.modeling_squeezebert.ConvActivation( + cin=config.hidden_size, cout=config.intermediate_size, groups=config.intermediate_groups, act=config.hidden_act + ) + state_dict = model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_hidden_states = torch_hidden_states.permute(0, 2, 1) + model = model.eval().to(torch_dtype) + torch_output = model(torch_hidden_states) + + tt_model_name = f"ttnn_{model_name}_optimized" + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + + output = ttnn_functional_squeezebert.squeezebert_intermediate( + config=config, + hidden_states=hidden_states, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_output(mesh_device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + + model = transformers.models.squeezebert.modeling_squeezebert.ConvDropoutLayerNorm( + cin=config.intermediate_size, + cout=config.hidden_size, + groups=config.output_groups, + dropout_prob=config.hidden_dropout_prob, + ) + state_dict = model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + + torch_hidden_states = torch_random( + (batch_size, sequence_size, config.intermediate_size), -0.1, 0.1, dtype=torch_dtype + ) + torch_hidden_states = torch_hidden_states.permute(0, 2, 1) + + torch_residual = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_residual = torch_residual.permute(0, 2, 1) + model = model.eval().to(torch_dtype) + torch_output = model(torch_hidden_states, torch_residual) + + tt_model_name = f"ttnn_{model_name}_optimized" + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + residual = ttnn.from_torch( + torch_residual, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, mesh_mapper=inputs_mesh_mapper + ) + + output = ttnn_functional_squeezebert.squeezebert_conv_layernorm( + config=config, + hidden_states=hidden_states, + input_tensor=residual, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=mesh_device, + cin=config.intermediate_size, + cout=config.hidden_size, + groups=config.output_groups, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_layer(mesh_device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.models.squeezebert.modeling_squeezebert.SqueezeBertModule(config) + state_dict = model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_hidden_states = torch_hidden_states.permute(0, 2, 1) + + torch_attention_mask = torch.ones(batch_size, sequence_size, dtype=torch_dtype) + torch_attention_mask = torch_attention_mask[:, None, None, :] + model = model.eval().to(torch_dtype) + torch_output = model(torch_hidden_states, attention_mask=torch_attention_mask, output_attentions=False) + + tt_model_name = f"ttnn_{model_name}_optimized" + + hidden_states = ttnn.from_torch( + torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=mesh_device, mesh_mapper=inputs_mesh_mapper + ) + ttnn_attention_mask = ttnn.from_torch( + torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=mesh_device, mesh_mapper=inputs_mesh_mapper + ) + + output = ttnn_functional_squeezebert.squeezebert_layer( + config, + hidden_states, + attention_mask=ttnn_attention_mask, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=mesh_device, + reader_patterns_cache={}, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output["feature_map"], output, 0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_encoder(mesh_device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + + model = transformers.models.squeezebert.modeling_squeezebert.SqueezeBertEncoder(config) + state_dict = model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_attention_mask = torch.ones(batch_size, sequence_size, dtype=torch_dtype) + torch_attention_mask = torch_attention_mask[:, None, None, :] + model = model.eval().to(torch_dtype) + + torch_output = model(torch_hidden_states, attention_mask=torch_attention_mask).last_hidden_state + + tt_model_name = f"ttnn_{model_name}_optimized" + + hidden_states = ttnn.from_torch( + torch_hidden_states, ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, mesh_mapper=inputs_mesh_mapper + ) + ttnn_attention_mask = ttnn.from_torch( + torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=mesh_device, mesh_mapper=inputs_mesh_mapper + ) + + output = ttnn_functional_squeezebert.squeezebert_encoder( + config, + hidden_states, + attention_mask=ttnn_attention_mask, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=mesh_device, + reader_patterns_cache={}, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output, output, 0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +def test_squeezebert_model(mesh_device, model_name, batch_size, sequence_size, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.SqueezeBertModel.from_pretrained(model_name) + state_dict = model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + + torch_input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.ones((batch_size, sequence_size), dtype=torch.int32) + torch_position_ids = torch.ones((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.ones(1, sequence_size, dtype=torch.bfloat16) + + torch_output = model( + torch_input_ids, + token_type_ids=torch_token_type_ids, + position_ids=torch_position_ids, + attention_mask=torch_attention_mask, + ).last_hidden_state + + tt_model_name = f"ttnn_{model_name}_optimized" + + ttnn_bert_inputs = ttnn_functional_squeezebert.preprocess_inputs( + torch_input_ids, + torch_token_type_ids, + torch_position_ids, + torch_attention_mask, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + + output = ttnn_functional_squeezebert.squeezebert( + config, + *ttnn_bert_inputs, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=mesh_device, + reader_patterns_cache={}, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + output = ttnn.to_torch(output, mesh_composer=output_mesh_composer) + + assert_with_pcc(torch_output, output, 0.99) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +def test_squeezebert_for_question_answering(mesh_device, model_name, batch_size, sequence_size, reset_seeds): + rf_model = transformers.SqueezeBertForQuestionAnswering.from_pretrained(model_name) + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + state_dict = rf_model.state_dict() + tt_model_name = f"ttnn_{model_name}_optimized" + + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size * 2 if mesh_device_flag else batch_size + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: rf_model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=mesh_device, + ) + + torch_squeezebert_input = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.ones(batch_size, sequence_size) + + rf_model = rf_model.eval() + torch_output = rf_model( + input_ids=torch_squeezebert_input, + token_type_ids=torch_token_type_ids, + position_ids=torch_position_ids, + attention_mask=torch_attention_mask, + ) + + ttnn_squeezebert_inputs = ttnn_functional_squeezebert.preprocess_inputs( + torch_squeezebert_input, + torch_token_type_ids, + torch_position_ids, + torch_attention_mask, + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + + tt_output = ttnn_functional_squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=mesh_device, + reader_patterns_cache={}, + mesh_mapper=inputs_mesh_mapper, + mesh_composer=output_mesh_composer, + ) + + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + + tt_start_logits = tt_output[..., :, 0] + tt_end_logits = tt_output[..., :, 1] + + assert_with_pcc(torch_output.start_logits, tt_start_logits, 0.90 if mesh_device_flag else 0.88) + assert_with_pcc(torch_output.end_logits, tt_end_logits, 0.90 if mesh_device_flag else 0.87)