Skip to content

Commit

Permalink
#0: Port functional_roberta to n300
Browse files Browse the repository at this point in the history
  • Loading branch information
kkeerthana0573 committed Sep 19, 2024
1 parent bd0b7c6 commit 1209486
Show file tree
Hide file tree
Showing 10 changed files with 218 additions and 38 deletions.
2 changes: 1 addition & 1 deletion models/demos/bert/tt/ttnn_optimized_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def bert_attention(
attention_scores = ttnn.matmul(
query,
key,
memory_config=ttnn.L1_MEMORY_CONFIG,
# memory_config=ttnn.L1_MEMORY_CONFIG,
dtype=ttnn.bfloat16,
core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
)
Expand Down
21 changes: 21 additions & 0 deletions models/demos/roberta/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## functional_roberta Demo
## How to Run

Use `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_optimized_bert-deepset/roberta-large-squad2-models/demos/roberta/demo/input_data.json]` to run the demo.

If you wish to run the demo for ttnn_optimized_functional_roberta, use `pytest --disable-warnings --input-path="models/demos/roberta/demo/input_data.json" models/demos/roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_optimized_bert-deepset/roberta-large-squad2-models/demos/roberta/demo/input_data.json]` to run the demo.

If you wish to run the demo with a different input use `pytest --disable-warnings --input-path="<address_to_your_json_file.json>" models/demos/roberta/demo/demo.py::test_demo[models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`. This file is expected to have exactly 8 inputs.

Our second demo is designed to run SQuADV2 dataset, run this with `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo_squadv2[3-models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`.

If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/roberta/demo/demo.py::test_demo_squadv2[<n_iterations>-models.demos.bert.tt.ttnn_bert-deepset/roberta-large-squad2]`


# Inputs
Inputs by default are provided from `input_data.json`. If you wish you to change the inputs, provide a different path to test_demo.

We do not recommend modifying `input_data.json` file.

# Details
The entry point to functional_roberta model is bert_for_question_answering in `models/demos/bert/tt/ttnn_bert.py` (`models/demos/bert/tt/ttnn_optimized_bert.py` for optimized version). The model picks up certain configs and weights from huggingface pretrained model. We have used `deepset/roberta-large-squad2` version from huggingface as our reference.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
disable_persistent_kernel_cache,
profiler,
)
from models.demos.bert.tt import ttnn_bert
from models.demos.bert.tt import ttnn_optimized_bert
from models.demos.bert.tt import ttnn_bert, ttnn_optimized_bert

from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch
from ttnn.model_preprocessing import (
Expand All @@ -42,6 +41,12 @@ def load_inputs(input_path, batch):
return context, question


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx


def run_roberta_question_and_answering_inference(
device,
use_program_cache,
Expand Down Expand Up @@ -105,10 +110,14 @@ def run_roberta_question_and_answering_inference(

profiler.start(f"preprocessing_input")

position_ids = create_position_ids_from_input_ids(
input_ids=roberta_input.input_ids, padding_idx=config.pad_token_id
)
ttnn_roberta_inputs = bert.preprocess_inputs(
roberta_input["input_ids"],
roberta_input["token_type_ids"],
torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
position_ids,
roberta_input["attention_mask"],
device=device,
)
profiler.end(f"preprocessing_input")
Expand Down Expand Up @@ -208,10 +217,14 @@ def run_roberta_question_and_answering_inference_squad_v2(
if i < n_iterations:
batch_data = batch[0]
curr_batch_size = batch_data["input_ids"].shape[0]
position_ids = create_position_ids_from_input_ids(
input_ids=batch_data.input_ids, padding_idx=config.pad_token_id
)
ttnn_roberta_inputs = bert.preprocess_inputs(
batch_data["input_ids"],
batch_data["token_type_ids"],
torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None,
position_ids,
batch_data["attention_mask"],
device=device,
)

Expand Down Expand Up @@ -253,10 +266,13 @@ def run_roberta_question_and_answering_inference_squad_v2(
logger.info(f"\tTT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}")


@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
@pytest.mark.parametrize(
"model_name, input_loc",
((["deepset/roberta-large-squad2", "models/demos/roberta/demo/input_data.json"]),),
)
@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
def test_demo(
input_path,
input_loc,
model_name,
bert,
device,
Expand All @@ -272,12 +288,12 @@ def test_demo(
batch_size=8,
sequence_size=384,
bert=bert,
input_path=input_path,
input_path=input_loc,
)


@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@pytest.mark.parametrize("bert", [ttnn_bert, ttnn_optimized_bert])
@pytest.mark.parametrize("bert", [ttnn_optimized_bert, ttnn_bert])
@pytest.mark.parametrize(
"n_iterations",
((3),),
Expand Down
File renamed without changes.
39 changes: 39 additions & 0 deletions models/demos/roberta/tests/test_perf_device_roberta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
from models.utility_functions import skip_for_wormhole_b0, is_grayskull
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report


@pytest.mark.models_device_performance_bare_metal
@pytest.mark.parametrize(
"batch_size, test, expected_perf",
[
[
8,
"sequence_size=384-batch_size=8-model_name=deepset/roberta-large-squad2",
66.7 if is_grayskull() else 129.8,
],
],
)
def test_perf_device_bare_metal(batch_size, test, expected_perf):
subdir = "ttnn_roberta_"
num_iterations = 3
margin = 0.03
command = f"pytest tests/ttnn/integration_tests/roberta/test_ttnn_opitmized_roberta.py::test_roberta_for_question_answering[{test}]"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
expected_perf_cols = {inference_time_key: expected_perf}

post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
prep_device_perf_report(
model_name=f"ttnn_roberta_{batch_size}",
batch_size=batch_size,
post_processed_results=post_processed_results,
expected_results=expected_results,
comments=test.replace("/", "_"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
from ttnn.model_preprocessing import preprocess_model_parameters

from models.utility_functions import (
is_wormhole_b0,
is_blackhole,
skip_for_wormhole_b0,
enable_persistent_kernel_cache,
disable_persistent_kernel_cache,
)
Expand All @@ -29,12 +28,11 @@

def get_expected_times(bert):
return {
ttnn_bert: (13, 32),
ttnn_optimized_bert: (12, 0.092),
ttnn_bert: (13, 5.5),
ttnn_optimized_bert: (12, 0.12),
}[bert]


@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH")
@pytest.mark.models_performance_bare_metal
@pytest.mark.models_performance_virtual_machine
@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
Expand All @@ -52,9 +50,9 @@ def test_performance(device, use_program_cache, model_name, batch_size, sequence
torch_attention_mask = torch.zeros(1, sequence_size) if bert == ttnn_optimized_bert else None

if bert == ttnn_bert:
tt_model_name = f"ttnn_{model_name}"
tt_model_name = f"ttnn_roberta_{model_name}"
elif bert == ttnn_optimized_bert:
tt_model_name = f"ttnn_{model_name}_optimized"
tt_model_name = f"ttnn_roberta_{model_name}_optimized"
else:
raise ValueError(f"Unknown functional_roberta: {bert}")

Expand Down
21 changes: 0 additions & 21 deletions models/experimental/functional_roberta/README.md

This file was deleted.

8 changes: 8 additions & 0 deletions tests/scripts/run_performance.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ run_perf_models_other() {

env pytest -n auto models/demos/metal_BERT_large_11/tests -m $test_marker

env pytest -n auto models/demos/roberta/tests/test_performance.py -m $test_marker

## Merge all the generated reports
env python models/perf/merge_perf_results.py
}
Expand Down Expand Up @@ -82,6 +84,9 @@ run_device_perf_models() {
env pytest models/demos/ttnn_falcon7b/tests -m $test_marker --timeout=360

env pytest models/demos/bert/tests -m $test_marker

env pytest models/demos/roberta/tests/test_perf_device_roberta.py -m $test_marker

fi

if [ "$tt_arch" == "wormhole_b0" ]; then
Expand All @@ -94,6 +99,9 @@ run_device_perf_models() {
env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/metal_BERT_large_11/tests -m $test_marker

env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b_common/tests -m $test_marker

env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/roberta/tests/test_perf_device_roberta.py -m $test_marker

fi

## Merge all the generated reports
Expand Down
3 changes: 3 additions & 0 deletions tests/scripts/single_card/run_single_card_demo_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ run_common_func_tests() {
# Resnet
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto --disable-warnings models/demos/wormhole/resnet50/demo/demo.py; fail+=$?

# RoBERTa
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto --disable-warnings -q -s --input-method=json --input-path='models/demos/roberta/demo/input_data.json' models/demos/roberta/demo/demo.py --timeout 420; fail+=$?

return $fail
}

Expand Down
116 changes: 116 additions & 0 deletions tests/ttnn/integration_tests/roberta/test_ttnn_opitmized_roberta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
import torch
import pytest
import tt_lib
import transformers

from models.demos.bert.tt import ttnn_optimized_bert
from ttnn.model_preprocessing import preprocess_model_parameters
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("sequence_size", [384])
def test_roberta(device, use_program_cache, reset_seeds, model_name, batch_size, sequence_size):
config = transformers.RobertaConfig.from_pretrained(model_name)
model = transformers.RobertaModel.from_pretrained(model_name)

input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32)
torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
torch_attention_mask = torch.ones(batch_size, sequence_size)

torch_output = model(
input_ids=input_ids,
attention_mask=torch_attention_mask,
token_type_ids=torch_token_type_ids,
position_ids=torch_position_ids,
)
torch_output = torch_output.last_hidden_state

tt_model_name = f"ttnn_{model_name}_optimized"

parameters = preprocess_model_parameters(
model_name=tt_model_name,
initialize_model=lambda: transformers.RobertaModel.from_pretrained(model_name, torchscript=False).eval(),
custom_preprocessor=ttnn_optimized_bert.custom_preprocessor,
device=device,
)

ttnn_roberta_inputs = ttnn_optimized_bert.preprocess_inputs(
input_ids,
torch_token_type_ids,
torch_position_ids,
torch_attention_mask,
device=device,
)

tt_output = ttnn_optimized_bert.bert(
config,
*ttnn_roberta_inputs,
parameters=parameters,
)

tt_output = ttnn.to_torch(tt_output)

assert_with_pcc(torch_output, tt_output, 0.94)


@pytest.mark.parametrize("model_name", ["deepset/roberta-large-squad2"])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("sequence_size", [384])
def test_roberta_for_question_answering(device, use_program_cache, reset_seeds, model_name, batch_size, sequence_size):
config = transformers.RobertaConfig.from_pretrained(model_name)
model = transformers.RobertaForQuestionAnswering.from_pretrained(model_name)

input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32)
torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
torch_attention_mask = torch.ones(batch_size, sequence_size)

torch_output = model(
input_ids=input_ids,
attention_mask=torch_attention_mask,
token_type_ids=torch_token_type_ids,
position_ids=torch_position_ids,
)
torch_output_start_logits = torch_output.start_logits
torch_output_end_logits = torch_output.end_logits

tt_model_name = f"ttnn_{model_name}_optimized"

parameters = preprocess_model_parameters(
model_name=tt_model_name,
initialize_model=lambda: transformers.RobertaForQuestionAnswering.from_pretrained(
model_name, torchscript=False
).eval(),
custom_preprocessor=ttnn_optimized_bert.custom_preprocessor,
device=device,
)

ttnn_roberta_inputs = ttnn_optimized_bert.preprocess_inputs(
input_ids,
torch_token_type_ids,
torch_position_ids,
torch_attention_mask,
device=device,
)

tt_output = ttnn_optimized_bert.bert_for_question_answering(
config,
*ttnn_roberta_inputs,
parameters=parameters,
name="roberta",
)
tt_output = ttnn.to_torch(tt_output)

tt_output_start_logits = tt_output[..., :, 0]
tt_output_end_logits = tt_output[..., :, 1]

assert_with_pcc(torch_output_start_logits, tt_output_start_logits, 0.88)
assert_with_pcc(torch_output_end_logits, tt_output_end_logits, 0.89)

0 comments on commit 1209486

Please sign in to comment.