forked from tenstorrent/tt-buda-demos
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pytorch_dpr_reader.py
51 lines (39 loc) · 1.54 KB
/
pytorch_dpr_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
# SPDX-License-Identifier: Apache-2.0
# DPR Demo Script - Reader
import pybuda
from transformers import DPRReader, DPRReaderTokenizer
def run_dpr_reader_pytorch(variant="facebook/dpr-reader-multiset-base"):
    """Run the DPR reader demo through PyBuda and print the resulting logits.

    The function loads the tokenizer and reader model for ``variant``, sets
    the global compiler's default data-format override to ``Float16_b``,
    tokenizes a single hard-coded question/title/text triple, runs inference
    via ``pybuda.run_inference`` and prints the first three outputs.

    Args:
        variant: Checkpoint name passed to ``from_pretrained`` for both the
            tokenizer and the model. Known variants:
            ``facebook/dpr-reader-single-nq-base`` and
            ``facebook/dpr-reader-multiset-base``.

    Returns:
        None. The start, end and relevance logits are only printed.
    """
    # Tokenizer and model are both loaded from the same checkpoint name.
    reader_tokenizer = DPRReaderTokenizer.from_pretrained(variant)
    reader_model = DPRReader.from_pretrained(variant)

    # Adjust the process-wide compiler configuration before compiling.
    global_cfg = pybuda.config._get_global_compiler_config()
    global_cfg.default_df_override = pybuda._C.DataFormat.Float16_b

    # Encode the sample passage; padded/truncated to a fixed length of 128
    # and returned as PyTorch tensors.
    encoded_inputs = reader_tokenizer(
        questions=["What is love?"],
        titles=["Haddaway"],
        texts=["'What Is Love' is a song recorded by the artist Haddaway"],
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Wrap the PyTorch model for PyBuda and run inference on the device.
    device_module = pybuda.PyTorchModule("pt_dpr_reader", reader_model)
    result_queue = pybuda.run_inference(device_module, inputs=[encoded_inputs])
    results = result_queue.get()

    # The first three outputs are read, in order, before anything is printed.
    start_logits, end_logits, relevance_logits = (
        results[idx].value() for idx in range(3)
    )

    print(f"Start Logits: {start_logits}")
    print(f"End Logits: {end_logits}")
    print(f"Relevance Logits: {relevance_logits}")
# Script entry point: run the demo with the default checkpoint variant.
if __name__ == "__main__":
    run_dpr_reader_pytorch()