search_passage.py
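"""Dense passage retrieval over a FAISS-indexed collection.

Encodes test queries with an mDPR query encoder, retrieves the top-K
passages per query via an inner-product FAISS index attached to a
Hugging Face `datasets` collection, writes the results as a TREC run
file, and optionally scores the run with trec_eval.
"""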
import torch
from transformers import AutoTokenizer, AutoModel
import datasets
import logging
import os
from arguments import get_index_parser
from tqdm import tqdm
import faiss  # type: ignore
from models.dpr import mDPRBase
from util.dataset import read_queries, QueryDataset
from util.util import query_tokenizer, set_seed, test_trec_eval
from collections import defaultdict
import json
import time
from datasets import disable_caching

disable_caching()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

def batch_search(model, tokenizer, args, query_loader, ds):
    """Encode each query batch and retrieve the top-K passages per query.

    Returns a nested dict runs[qid][pid] = score.
    """
    model.eval()
    runs = defaultdict(dict)
    with torch.no_grad():
        for item in tqdm(query_loader, desc="batch search ..."):
            qids, query = item
            q_ids, q_mask = query_tokenizer(query, args, tokenizer)
            # (batch, dim) query embeddings, moved to CPU for the FAISS search
            q_reps = model.query(q_ids, q_mask).detach().cpu().numpy()
            scores, ranklists = ds.get_nearest_examples_batch(args.index_name, q_reps, k=args.topK)
            for qid, plist, slist in zip(qids, ranklists, scores):
                for pid, sc in zip(plist[args.pid_name], slist):
                    runs[qid][pid] = sc
    return runs
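
# For reference, query_tokenizer (imported from util/util.py, not shown here)
# is assumed to batch-tokenize the query strings and move the tensors to the
# GPU, roughly along these lines (max_query_len is a hypothetical flag name):
#
#   enc = tokenizer(list(queries), padding=True, truncation=True,
#                   max_length=args.max_query_len, return_tensors="pt")
#   q_ids = enc["input_ids"].to(args.device)
#   q_mask = enc["attention_mask"].to(args.device)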

def main(args):
    set_seed(args.seed)
    args.rank = 0  # single GPU, set rank to 0
    args.device = torch.cuda.current_device()
    os.makedirs(args.output_dir, exist_ok=True)
    args.num_langs = len(args.langs)
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.base_model_name)
    except Exception:
        # fall back to converting a slow tokenizer; a fast tokenizer is required below
        tokenizer = AutoTokenizer.from_pretrained(args.base_model_name, from_slow=True)
    assert tokenizer.is_fast
    # add the pooling layer only when its output is actually used
    base_encoder = AutoModel.from_pretrained(args.base_model_name, add_pooling_layer=args.use_pooler)
    model = mDPRBase(base_encoder, args)
    model.to(args.device)
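    # mDPRBase (models/dpr.py, not shown here) is assumed to wrap the HF
    # encoder and expose query(input_ids, attention_mask) -> (batch, dim)
    # embeddings, plus a load() helper for restoring fine-tuned checkpoints.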
    # load checkpoint
    if args.checkpoint is not None:
        model.load(args.checkpoint)
        logger.info("model loaded")
    # read collection; the first column is assumed to hold the passage ids
    ds = datasets.load_from_disk(args.collection)
    args.pid_name = ds.column_names[0]
    logger.info("dataset loaded")
    # load faiss index
    logger.info("loading faiss index ...")
    ds.load_faiss_index(args.index_name, args.faiss_index, device=args.device)
    # make sure the metric is inner product, matching how DPR scores query/passage pairs
    assert ds.get_index(args.index_name).faiss_index.metric_type == faiss.METRIC_INNER_PRODUCT
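    # (The FAISS index is expected to have been built offline over this same
    # collection's passage embeddings, e.g. with datasets'
    # ds.add_faiss_index(column=..., index_name=args.index_name,
    # metric_type=faiss.METRIC_INNER_PRODUCT), then saved to args.faiss_index.)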
    # read queries: read_queries is assumed to return a {qid: query_text} dict
    queries = read_queries(args.test_queries)
    query_list = [[qid, qtxt] for qid, qtxt in queries.items()]
    dataset_query = QueryDataset(query_list)
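    # QueryDataset (util/dataset.py, not shown here) is assumed to be a thin
    # torch Dataset over the [qid, text] pairs, roughly:
    #
    #   class QueryDataset(torch.utils.data.Dataset):
    #       def __init__(self, query_list): self.query_list = query_list
    #       def __len__(self): return len(self.query_list)
    #       def __getitem__(self, idx): return tuple(self.query_list[idx])
    #
    # so that the default collate yields (qids, query_texts) batches.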
    query_loader = torch.utils.data.DataLoader(
        dataset_query,
        batch_size=args.batch_size,
        drop_last=False,
        shuffle=False,
    )
    # encode queries and search
    start_time = time.time()
    runs = batch_search(model, tokenizer, args, query_loader, ds)
    logger.info(f"batch search finished, {round((time.time() - start_time) / len(query_list), 3)} sec/query.")
    # write to file
    runf = os.path.join(args.output_dir, "test.run")
    with open(runf, "wt") as runfile:
        for qid in runs:
            # sort by score (ties broken by pid) and keep the top-K
            scores = list(sorted(runs[qid].items(), key=lambda x: (x[1], x[0]), reverse=True))[:args.topK]
            for i, (did, score) in enumerate(scores):
                runfile.write(f"{qid} 0 {did} {i+1} {score} run\n")
    if args.test_qrel:
        # evaluation
        trec_out = test_trec_eval(args.test_qrel, runf, args.metrics, args.trec_eval)
        # write the trec_eval output, followed by the run arguments, into a file
        trec_eval_outfile = os.path.join(args.output_dir, "test.trec_eval")
        with open(trec_eval_outfile, "w") as trec_file:
            for line in trec_out:
                trec_file.write(line + "\n")
            json.dump(vars(args), trec_file)
    logger.info("done!")
if __name__ == "__main__":
parser = get_index_parser()
args = parser.parse_args()
main(args)
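
# Example invocation (flag names assumed to mirror the args attributes used
# above; check arguments.py / get_index_parser for the authoritative names):
#
#   python search_passage.py \
#       --base_model_name bert-base-multilingual-cased \
#       --checkpoint <path/to/checkpoint> \
#       --collection <path/to/hf_dataset> \
#       --faiss_index <path/to/index.faiss> \
#       --index_name embeddings \
#       --test_queries <path/to/queries> \
#       --test_qrel <path/to/qrels> \
#       --output_dir runs/ \
#       --batch_size 32 --topK 100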