-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#14423: Add data parallel support for RoBERTa model
- Loading branch information
1 parent
bd026b1
commit df03748
Showing
8 changed files
with
1,060 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# RoBERTa demo | ||
|
||
Demo showcasing Data Parallel implementation of RoBERTa running on Wormhole - n150, n300 using ttnn. | ||
|
||
## Introduction | ||
RoBERTa builds on BERT and modifies key hyperparameters, removing the next-sentence pretraining objective and training with much larger mini-batches and learning rates. | ||
RoBERTa is similar to BERT but with better pretraining techniques like Dynamic Masking, Sentence Packing, Larger Batches, Byte-level BPE vocabulary. The RoBERTa model was proposed in [RoBERTa: A Robustly Optimized BERT](https://arxiv.org/abs/1907.11692) Pretraining Approach based on Google’s BERT model released in 2018. | ||
|
||
## Details | ||
The entry point to ttnn_optimized_roberta model is bert_for_question_answering in `models/demos/wormhole/roberta/tt/ttnn_optimized_roberta.py`. The model picks up certain configs and weights from huggingface pretrained model. We have used `deepset/roberta-large-squad2` version from huggingface as our reference. | ||
|
||
### Sequence Size: 384 | ||
Sequence size determines the maximum length of input sequences processed by the model, optimizing performance and compatibility. It's recommended to set the sequence_size to 384 | ||
|
||
### Batch size: 16 | ||
Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. On each device, the batch size will be 8, as the operations run in parallel. It's recommended to set the batch_size to 16 | ||
|
||
## How to Run | ||
|
||
Use `pytest --disable-warnings models/demos/wormhole/roberta/demo/demo.py::test_demo[wormhole_b0-True-models.demos.wormhole.roberta.tt.ttnn_optimized_roberta-8-384-deepset/roberta-large-squad2-models/demos/wormhole/roberta/demo/input_data.json]` to run the demo. | ||
|
||
If you wish to run the demo with a different input use `pytest --disable-warnings models/demos/wormhole/roberta/demo/demo.py::test_demo[wormhole_b0-True-models.demos.wormhole.roberta.tt.ttnn_optimized_roberta-8-384-deepset/roberta-large-squad2-<address_to_your_customized_inputs_file.json>]`. This file is expected to have exactly 16 inputs. | ||
|
||
Our second demo is designed to run SQuADV2 dataset, run this with `pytest --disable-warnings models/demos/wormhole/roberta/demo/demo.py::test_demo_squadv2[wormhole_b0-True-models.demos.wormhole.roberta.tt.ttnn_optimized_roberta-8-384-3-deepset/roberta-large-squad2]`. | ||
|
||
If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/wormhole/roberta/demo/demo.py::test_demo_squadv2[wormhole_b0-True-models.demos.wormhole.roberta.tt.ttnn_optimized_roberta-8-384-<n_iterations>-deepset/roberta-large-squad2]` | ||
|
||
## Inputs | ||
The demo receives inputs from respective `input_data.json` by default. To modify the inputs or specify a different path, adjust the input_path parameter in the command accordingly. It's recommended to avoid direct modifications to the input_data.json file. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,328 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
import pytest | ||
import torch | ||
from loguru import logger | ||
import transformers | ||
import ttnn | ||
import evaluate | ||
from models.utility_functions import ( | ||
disable_compilation_reports, | ||
disable_persistent_kernel_cache, | ||
profiler, | ||
is_wormhole_b0, | ||
run_for_wormhole_b0, | ||
) | ||
from models.demos.wormhole.roberta.tt import ttnn_optimized_roberta | ||
|
||
from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch | ||
from ttnn.model_preprocessing import ( | ||
preprocess_model_parameters, | ||
) | ||
|
||
from ttnn.model_preprocessing import * | ||
from transformers import RobertaForQuestionAnswering, pipeline, RobertaTokenizer | ||
|
||
import evaluate | ||
|
||
|
||
def load_inputs(input_path, batch): | ||
with open(input_path) as f: | ||
input_data = json.load(f) | ||
assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries." | ||
|
||
context = [] | ||
question = [] | ||
for i in range(batch): | ||
context.append(input_data[i]["context"]) | ||
question.append(input_data[i]["question"]) | ||
|
||
return context, question | ||
|
||
|
||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): | ||
mask = input_ids.ne(padding_idx).int() | ||
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask | ||
return incremental_indices.long() + padding_idx | ||
|
||
|
||
def run_roberta_question_and_answering_inference( | ||
device, | ||
use_program_cache, | ||
model_name, | ||
batch_size, | ||
sequence_size, | ||
bert, | ||
input_path, | ||
): | ||
disable_persistent_kernel_cache() | ||
|
||
hugging_face_reference_model = RobertaForQuestionAnswering.from_pretrained(model_name) | ||
hugging_face_reference_model.eval() | ||
|
||
# set up tokenizer | ||
tokenizer = RobertaTokenizer.from_pretrained(model_name) | ||
config = hugging_face_reference_model.config | ||
nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer) | ||
config.use_dram = True | ||
|
||
tt_model_name = f"ttnn_{model_name}_optimized" | ||
mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 | ||
batch_size = 16 if mesh_device_flag else 8 | ||
inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) | ||
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device) | ||
output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0) | ||
profiler.start(f"preprocessing_parameter") | ||
with ttnn.distribute(ttnn.ReplicateTensorToMesh(device)): | ||
parameters = preprocess_model_parameters( | ||
model_name=tt_model_name, | ||
initialize_model=lambda: transformers.RobertaForQuestionAnswering.from_pretrained( | ||
model_name, torchscript=False | ||
), | ||
custom_preprocessor=ttnn_optimized_roberta.custom_preprocessor, | ||
device=device, | ||
) | ||
profiler.end(f"preprocessing_parameter") | ||
|
||
context, question = load_inputs(input_path, batch_size) | ||
|
||
preprocess_params, _, postprocess_params = nlp._sanitize_parameters() | ||
preprocess_params["max_seq_len"] = sequence_size | ||
inputs = nlp._args_parser({"context": context, "question": question}) | ||
preprocessed_inputs = [] | ||
for i in range(batch_size): | ||
model_input = next(nlp.preprocess(inputs[0][i], **preprocess_params)) | ||
single_input = { | ||
"example": model_input["example"], | ||
"inputs": model_input, | ||
} | ||
preprocessed_inputs.append(single_input) | ||
|
||
roberta_input = tokenizer.batch_encode_plus( | ||
zip(question, context), | ||
max_length=sequence_size, | ||
padding="max_length", | ||
truncation=True, | ||
return_attention_mask=True, | ||
return_token_type_ids=True, | ||
return_tensors="pt", | ||
) | ||
|
||
profiler.start(f"preprocessing_input") | ||
|
||
position_ids = create_position_ids_from_input_ids( | ||
input_ids=roberta_input.input_ids, padding_idx=config.pad_token_id | ||
) | ||
ttnn_roberta_inputs = bert.preprocess_inputs( | ||
roberta_input["input_ids"], | ||
roberta_input["token_type_ids"], | ||
position_ids, | ||
roberta_input["attention_mask"], | ||
device=device, | ||
inputs_mesh_mapper=inputs_mesh_mapper, | ||
) | ||
profiler.end(f"preprocessing_input") | ||
|
||
profiler.start(f"inference_time") | ||
tt_output = bert.bert_for_question_answering( | ||
config, | ||
*ttnn_roberta_inputs, | ||
parameters=parameters, | ||
name="roberta", | ||
) | ||
profiler.end(f"inference_time") | ||
|
||
tt_output = ( | ||
ttnn.to_torch(ttnn.from_device(tt_output), mesh_composer=output_mesh_composer) | ||
.reshape(batch_size, 1, sequence_size, -1) | ||
.to(torch.float32) | ||
) | ||
|
||
tt_start_logits = tt_output[..., :, 0].squeeze(1) | ||
tt_end_logits = tt_output[..., :, 1].squeeze(1) | ||
|
||
model_answers = {} | ||
profiler.start("post_processing_output_to_string") | ||
for i in range(batch_size): | ||
tt_res = { | ||
"start": tt_start_logits[i], | ||
"end": tt_end_logits[i], | ||
"example": preprocessed_inputs[i]["example"], | ||
**preprocessed_inputs[i]["inputs"], | ||
} | ||
|
||
tt_answer = nlp.postprocess([tt_res], **postprocess_params) | ||
|
||
logger.info(f"answer: {tt_answer['answer']}\n") | ||
model_answers[i] = tt_answer["answer"] | ||
|
||
profiler.end("post_processing_output_to_string") | ||
|
||
measurements = { | ||
"preprocessing_parameter": profiler.get("preprocessing_parameter"), | ||
"preprocessing_input": profiler.get("preprocessing_input"), | ||
"inference_time": profiler.get("inference_time"), | ||
"post_processing": profiler.get("post_processing_output_to_string"), | ||
} | ||
logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s") | ||
logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") | ||
logger.info(f"inference_time: {measurements['inference_time']} s") | ||
logger.info(f"post_processing : {measurements['post_processing']} s") | ||
|
||
return measurements | ||
|
||
|
||
def run_roberta_question_and_answering_inference_squad_v2( | ||
device, | ||
use_program_cache, | ||
model_name, | ||
batch_size, | ||
sequence_size, | ||
bert, | ||
n_iterations, | ||
): | ||
disable_persistent_kernel_cache() | ||
|
||
hugging_face_reference_model = RobertaForQuestionAnswering.from_pretrained(model_name) | ||
hugging_face_reference_model.eval() | ||
|
||
# set up tokenizer | ||
tokenizer = RobertaTokenizer.from_pretrained(model_name) | ||
config = hugging_face_reference_model.config | ||
config.use_dram = True | ||
|
||
tt_model_name = f"ttnn_{model_name}_optimized" | ||
mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 | ||
batch_size = 16 if mesh_device_flag else 8 | ||
inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) | ||
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(device) | ||
output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0) | ||
with ttnn.distribute(ttnn.ReplicateTensorToMesh(device)): | ||
parameters = preprocess_model_parameters( | ||
model_name=tt_model_name, | ||
initialize_model=lambda: transformers.RobertaForQuestionAnswering.from_pretrained( | ||
model_name, torchscript=False | ||
), | ||
custom_preprocessor=ttnn_optimized_roberta.custom_preprocessor, | ||
device=device, | ||
) | ||
|
||
nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer) | ||
|
||
attention_mask = True | ||
token_type_ids = True | ||
inputs_squadv2 = squadv2_1K_samples_input(tokenizer, sequence_size, attention_mask, token_type_ids, batch_size) | ||
squad_metric = evaluate.load("squad_v2") | ||
|
||
with torch.no_grad(): | ||
pred_labels = [] | ||
cpu_pred_labels = [] | ||
true_labels = [] | ||
i = 0 | ||
for batch in inputs_squadv2: | ||
if i < n_iterations: | ||
batch_data = batch[0] | ||
curr_batch_size = batch_data["input_ids"].shape[0] | ||
position_ids = create_position_ids_from_input_ids( | ||
input_ids=batch_data.input_ids, padding_idx=config.pad_token_id | ||
) | ||
ttnn_roberta_inputs = bert.preprocess_inputs( | ||
batch_data["input_ids"], | ||
batch_data["token_type_ids"], | ||
position_ids, | ||
batch_data["attention_mask"], | ||
device=device, | ||
inputs_mesh_mapper=inputs_mesh_mapper, | ||
) | ||
|
||
tt_output = bert.bert_for_question_answering( | ||
config, | ||
*ttnn_roberta_inputs, | ||
parameters=parameters, | ||
name="roberta", | ||
) | ||
tt_output = ( | ||
ttnn.to_torch(ttnn.from_device(tt_output), mesh_composer=output_mesh_composer) | ||
.reshape(batch_size, 1, sequence_size, -1) | ||
.to(torch.float32) | ||
) | ||
cpu_output = hugging_face_reference_model(**batch_data) | ||
references = batch[1] | ||
question = batch[2] | ||
context = batch[3] | ||
|
||
cpu_predictions, tt_predictions = squadv2_answer_decode_batch( | ||
hugging_face_reference_model, | ||
tokenizer, | ||
nlp, | ||
references, | ||
cpu_output, | ||
tt_output, | ||
curr_batch_size, | ||
question, | ||
context, | ||
) | ||
pred_labels.extend(tt_predictions) | ||
cpu_pred_labels.extend(cpu_predictions) | ||
true_labels.extend(references) | ||
|
||
del tt_output | ||
i += 1 | ||
eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels) | ||
cpu_eval_score = squad_metric.compute(predictions=cpu_pred_labels, references=true_labels) | ||
logger.info(f"TT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}") | ||
|
||
|
||
@run_for_wormhole_b0() | ||
@pytest.mark.parametrize( | ||
"model_name, input_loc", | ||
((["deepset/roberta-large-squad2", "models/demos/wormhole/roberta/demo/input_data.json"]),), | ||
) | ||
@pytest.mark.parametrize( | ||
("bert", "batch_size", "sequence_size"), | ||
((ttnn_optimized_roberta, 8, 384),), | ||
) | ||
def test_demo(mesh_device, use_program_cache, model_name, input_loc, bert, batch_size, sequence_size): | ||
disable_persistent_kernel_cache() | ||
disable_compilation_reports() | ||
|
||
return run_roberta_question_and_answering_inference( | ||
device=mesh_device, | ||
use_program_cache=use_program_cache, | ||
model_name=model_name, | ||
batch_size=batch_size, | ||
sequence_size=sequence_size, | ||
bert=bert, | ||
input_path=input_loc, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"]) | ||
@pytest.mark.parametrize( | ||
("bert", "batch_size", "sequence_size", "n_iterations"), | ||
((ttnn_optimized_roberta, 8, 384, 3),), | ||
) | ||
def test_demo_squadv2( | ||
mesh_device, | ||
use_program_cache, | ||
model_name, | ||
bert, | ||
batch_size, | ||
sequence_size, | ||
n_iterations, | ||
): | ||
disable_persistent_kernel_cache() | ||
disable_compilation_reports() | ||
|
||
return run_roberta_question_and_answering_inference_squad_v2( | ||
device=mesh_device, | ||
use_program_cache=use_program_cache, | ||
model_name=model_name, | ||
batch_size=batch_size, | ||
sequence_size=sequence_size, | ||
bert=bert, | ||
n_iterations=n_iterations, | ||
) |
Oops, something went wrong.