-
Notifications
You must be signed in to change notification settings - Fork 0
/
dense_retriever.py
692 lines (592 loc) · 24.8 KB
/
dense_retriever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Command line tool to get dense results and validate them
"""
import glob
import json
import logging
import pickle
import time
import zlib
from typing import List, Tuple, Dict, Iterator
import hydra
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from torch import Tensor as T
from torch import nn
from dpr.utils.data_utils import RepTokenSelector
from dpr.data.qa_validation import calculate_matches, calculate_chunked_matches, calculate_matches_from_meta
from dpr.data.retriever_data import KiltCsvCtxSrc, TableChunk
from dpr.indexer.faiss_indexers import (
DenseIndexer,
)
from dpr.models import init_biencoder_components
from dpr.models.biencoder import (
BiEncoder,
_select_span_with_token,
)
from dpr.options import setup_logger, setup_cfg_gpu, set_cfg_params_from_state
from dpr.utils.data_utils import Tensorizer
from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint
# Module logger: logging.getLogger() with no name returns the root logger, which is then
# configured by the project helper setup_logger (dpr.options) — presumably installing the
# handlers/format used by all DPR scripts; see that helper for details.
logger = logging.getLogger()
setup_logger(logger)
def generate_question_vectors(
    question_encoder: torch.nn.Module,
    tensorizer: Tensorizer,
    questions: List[str],
    entities: List[str],
    entity_spans: List[Tuple[int]],
    bsz: int,
    query_token: str = None,
    selector: RepTokenSelector = None,
) -> T:
    """
    Encode questions, together with their entity annotations, into dense query vectors.

    Questions are processed in batches of ``bsz``. Each question is tokenized with
    ``tensorizer.text_to_tensor(q, entities=..., entity_spans=...)``, which returns token ids,
    entity ids and entity position ids. These are stacked per batch, moved to the current CUDA
    device and fed to the encoder together with all-zero segment ids and attention masks from
    ``tensorizer.get_attn_mask``.

    :param question_encoder: question tower of the bi-encoder
    :param tensorizer: entity-aware tensorizer
    :param questions: raw question strings
    :param entities: per-question entity annotations (parallel to ``questions``)
    :param entity_spans: per-question entity spans (parallel to ``questions``)
    :param bsz: encoding batch size
    :param query_token: special query token; not supported by this entity-aware pipeline
    :param selector: optional representation-token selector. When given, the output is computed
        via ``BiEncoder.get_representation`` at the positions the selector returns.
    :return: CPU tensor whose first dimension equals ``len(questions)``
    :raises NotImplementedError: if ``query_token`` is set or a batch consists of pre-tokenized
        tensors; neither path can produce the entity tensors the encoder requires
    """
    if query_token:
        # The query_token branches only produced token ids, while the encoder call below also
        # needs entity ids / positions, so they always failed later with an UnboundLocalError.
        raise NotImplementedError(
            "query_token=%r is not supported by the entity-aware question encoder" % query_token
        )

    n = len(questions)
    query_vectors = []
    with torch.no_grad():
        for batch_start in range(0, n, bsz):
            batch_questions = questions[batch_start : batch_start + bsz]
            batch_entities = entities[batch_start : batch_start + bsz]
            batch_spans = entity_spans[batch_start : batch_start + bsz]

            if isinstance(batch_questions[0], T):
                # Same problem as above: pre-tokenized inputs carry no entity ids / positions.
                raise NotImplementedError(
                    "Pre-tokenized question tensors cannot be encoded by the entity-aware question encoder"
                )

            batch_tensors = []
            batch_ent_tensors = []
            batch_ent_pos_ids = []
            for q, ent, span in zip(batch_questions, batch_entities, batch_spans):
                token_ids, entity_ids, ent_pos_ids = tensorizer.text_to_tensor(
                    q,
                    entities=ent,
                    entity_spans=span,
                )
                batch_tensors.append(token_ids)
                batch_ent_tensors.append(entity_ids)
                batch_ent_pos_ids.append(ent_pos_ids)

            q_ids_batch = torch.stack(batch_tensors, dim=0).cuda()
            # zeros_like keeps the device of its input, so no second .cuda() call is needed.
            q_seg_batch = torch.zeros_like(q_ids_batch)
            q_attn_mask = tensorizer.get_attn_mask(q_ids_batch)
            q_ent_ids_batch = torch.stack(batch_ent_tensors, dim=0).cuda()
            q_ent_seg_batch = torch.zeros_like(q_ent_ids_batch)
            q_ent_attn_mask = tensorizer.get_attn_mask(q_ent_ids_batch)
            q_ent_pos_ids = torch.stack(batch_ent_pos_ids, dim=0).cuda()

            encoder_inputs = (
                q_ids_batch,
                q_seg_batch,
                q_attn_mask,
                q_ent_ids_batch,
                q_ent_seg_batch,
                q_ent_attn_mask,
                q_ent_pos_ids,
            )
            if selector:
                rep_positions = selector.get_positions(q_ids_batch, tensorizer)
                _, out, _, _ = BiEncoder.get_representation(
                    question_encoder,
                    *encoder_inputs,
                    representation_token_pos=rep_positions,
                )
            else:
                _, out, _, _ = question_encoder(*encoder_inputs)

            query_vectors.extend(out.cpu().split(1, dim=0))
            if len(query_vectors) % 100 == 0:
                logger.info("Encoded queries %d", len(query_vectors))

    query_tensor = torch.cat(query_vectors, dim=0)
    logger.info("Total encoded queries tensor %s", query_tensor.size())
    assert query_tensor.size(0) == len(questions)
    return query_tensor
class DenseRetriever(object):
    """
    Base retriever holding the question encoder, its tensorizer and the encoding batch size.

    Index handling (building, loading, searching) is left to subclasses.
    """

    def __init__(self, question_encoder: nn.Module, batch_size: int, tensorizer: Tensorizer):
        self.question_encoder = question_encoder
        self.batch_size = batch_size
        self.tensorizer = tensorizer
        # Optional representation-token selector. Callers may assign it after construction;
        # it is forwarded to the module-level generate_question_vectors.
        self.selector = None

    def generate_question_vectors(
        self,
        questions: List[str],
        entities: List[str],
        entity_spans: List[Tuple[int]],
        query_token: str = None,
    ) -> T:
        """
        Encode ``questions`` with this retriever's encoder settings.

        Puts the encoder into eval mode, then delegates to the module-level
        ``generate_question_vectors`` using ``self.tensorizer``, ``self.batch_size`` and
        ``self.selector``.
        """
        encoder = self.question_encoder
        encoder.eval()
        return generate_question_vectors(
            encoder,
            self.tensorizer,
            questions,
            entities,
            entity_spans,
            self.batch_size,
            query_token=query_token,
            selector=self.selector,
        )
class LocalFaissRetriever(DenseRetriever):
    """
    Does passage retrieval over an in-process DenseIndexer using the inherited question encoder.
    """

    def __init__(
        self,
        question_encoder: nn.Module,
        batch_size: int,
        tensorizer: Tensorizer,
        index: DenseIndexer,
    ):
        super().__init__(question_encoder, batch_size, tensorizer)
        self.index = index

    def index_encoded_data(
        self,
        vector_files: List[str],
        buffer_size: int,
        path_id_prefixes: List = None,
    ):
        """
        Index encoded passages read from a list of files.

        Records are streamed from ``iterate_encoded_files`` and passed to ``self.index.index_data``
        in chunks of ``buffer_size``. A non-positive ``buffer_size`` sends everything in one call.

        :param vector_files: file names to get passages vectors from
        :param buffer_size: size of a buffer (amount of passages) to send for the indexing at once
        :param path_id_prefixes: optional id prefix per file (parallel to ``vector_files``)
        :return: None
        """
        buffer = []
        for item in iterate_encoded_files(vector_files, path_id_prefixes=path_id_prefixes):
            buffer.append(item)
            if 0 < buffer_size == len(buffer):
                self.index.index_data(buffer)
                buffer = []
        # Flush only a non-empty remainder, as DenseRPCRetriever does. The former unconditional
        # call sent an empty batch whenever the item count was a multiple of buffer_size.
        # Whether that is harmless depends on the DenseIndexer implementation.
        # NOTE(review): confirm no indexer relies on receiving an empty call.
        if buffer:
            self.index.index_data(buffer)
        logger.info("Data indexing completed.")

    def get_top_docs(self, query_vectors: np.array, top_docs: int = 100) -> List[Tuple[List[object], List[float]]]:
        """
        Retrieve the best matching passages for a batch of query vectors.

        After the search the retriever drops its reference to the index (``self.index = None``),
        presumably to free memory before validation loads all passages. The method can therefore
        be called only once per index.

        :param query_vectors: query vectors batch
        :param top_docs: number of passages to retrieve per query
        :return: result of ``self.index.search_knn`` (annotated as one (ids, scores) pair per query)
        :raises RuntimeError: if the index was already released by a previous call
        """
        if self.index is None:
            raise RuntimeError(
                "The index has already been released by a previous get_top_docs() call; "
                "assign a new index to the retriever before searching again."
            )
        time0 = time.time()
        results = self.index.search_knn(query_vectors, top_docs)
        logger.info("index search time: %f sec.", time.time() - time0)
        self.index = None
        return results
# works only with our distributed_faiss library
class DenseRPCRetriever(DenseRetriever):
    """
    Retriever backed by a remote index served through the ``distributed_faiss`` client.

    Question encoding is inherited from DenseRetriever. Index creation, loading and k-NN search
    go through an ``IndexClient`` built from ``index_cfg_path``. distributed_faiss is imported
    inside the methods, so importing this module does not require it.
    """

    def __init__(
        self,
        question_encoder: nn.Module,
        batch_size: int,
        tensorizer: Tensorizer,
        index_cfg_path: str,
        dim: int,
        use_l2_conversion: bool = False,
        nprobe: int = 256,
    ):
        from distributed_faiss.client import IndexClient

        super().__init__(question_encoder, batch_size, tensorizer)
        self.dim = dim
        self.nprobe = nprobe
        self.index_id = "dr"
        logger.info("Connecting to index server ...")
        self.index_client = IndexClient(index_cfg_path)
        self.use_l2_conversion = use_l2_conversion
        logger.info("Connected")

    def load_index(self, index_id):
        """Load the existing remote index ``index_id`` and block until the server reports it ready."""
        from distributed_faiss.index_cfg import IndexCfg

        self.index_id = index_id
        logger.info("Loading remote index %s", index_id)
        load_cfg = IndexCfg()
        load_cfg.nprobe = self.nprobe
        if self.use_l2_conversion:
            load_cfg.metric = "l2"
        self.index_client.load_index(self.index_id, cfg=load_cfg, force_reload=False)
        logger.info("Index loaded")
        self._wait_index_ready(index_id)

    def index_encoded_data(
        self,
        vector_files: List[str],
        buffer_size: int = 1000,
        path_id_prefixes: List = None,
    ):
        """
        Create a flat remote index and upload the encoded passages read from ``vector_files``.

        :param vector_files: file names to get passages vectors from
        :param buffer_size: amount of passages sent per ``add_index_data`` call (non-positive: one call)
        :param path_id_prefixes: optional id prefix per file (parallel to ``vector_files``)
        :return: None
        """
        from distributed_faiss.index_cfg import IndexCfg

        target_id = self.index_id
        create_cfg = IndexCfg()
        create_cfg.dim = self.dim
        logger.info("Index train num=%d", create_cfg.train_num)
        create_cfg.faiss_factory = "flat"
        self.index_client.create_index(target_id, create_cfg)

        def flush(chunk, client):
            # Stack each record's vector as one row; the record's first field is its metadata.
            matrix = np.concatenate([np.reshape(record[1], (1, -1)) for record in chunk], axis=0)
            client.add_index_data(target_id, matrix, [record[0] for record in chunk])

        pending = []
        for record in iterate_encoded_files(vector_files, path_id_prefixes=path_id_prefixes):
            pending.append(record)
            if 0 < buffer_size == len(pending):
                flush(pending, self.index_client)
                pending = []
        if pending:
            flush(pending, self.index_client)

        logger.info("Embeddings sent.")
        self._wait_index_ready(target_id)

    def get_top_docs(
        self, query_vectors: np.array, top_docs: int = 100, search_batch: int = 512
    ) -> List[Tuple[List[object], List[float]]]:
        """
        Search the remote index in chunks of ``search_batch`` queries.

        Results are restricted server-side with ``search_with_filter(filter_pos=3, filter_value=True)``.
        Position 3 of the hit metadata is presumably the ``is_wiki`` flag written out by
        save_results_from_meta; verify against the index contents.

        :param query_vectors: query vectors batch
        :param top_docs: number of hits per query
        :param search_batch: number of queries per remote search call
        :return: one (meta, scores) pair per query
        """
        if self.use_l2_conversion:
            # Append a zero column, presumably matching the extra dimension of the
            # inner-product -> L2 converted index.
            zero_col = np.zeros(len(query_vectors), dtype="float32").reshape(-1, 1)
            query_vectors = np.hstack((query_vectors, zero_col))
            logger.info("query_hnsw_vectors %s", query_vectors.shape)
            self.index_client.cfg.metric = "l2"

        collected = []
        total = query_vectors.shape[0]
        for offset in range(0, total, search_batch):
            started = time.time()
            chunk = query_vectors[offset : offset + search_batch]
            logger.info("query_batch: %s", chunk.shape)
            scores, meta = self.index_client.search_with_filter(
                chunk, top_docs, self.index_id, filter_pos=3, filter_value=True
            )
            logger.info("index search time: %f sec.", time.time() - started)
            collected.extend((meta[row], scores[row]) for row in range(len(scores)))
        return collected

    def _wait_index_ready(self, index_id: str):
        """Poll once a minute until ``index_id`` reaches IndexState.TRAINED, then log its size."""
        from distributed_faiss.index_state import IndexState

        # TODO: move this method into IndexClient class
        client = self.index_client
        while client.get_state(index_id) != IndexState.TRAINED:
            logger.info("Remote Index is not ready ...")
            time.sleep(60)
        logger.info(
            "Remote Index is ready. Index data size %d",
            client.get_ntotal(index_id),
        )
def validate(
    passages: Dict[object, Tuple[str, str]],
    answers: List[List[str]],
    result_ctx_ids: List[Tuple[List[object], List[float]]],
    workers_num: int,
    match_type: str,
) -> List[List[bool]]:
    """
    Score retrieved passages against gold answers and log top-k hit counts and accuracy.

    :param passages: passage id -> (text, title) lookup
    :param answers: gold answers per question
    :param result_ctx_ids: retrieved (ids, scores) per question
    :param workers_num: number of worker processes passed to calculate_matches
    :param match_type: answer matching mode passed to calculate_matches
    :return: ``questions_doc_hits`` from calculate_matches (per-question hit flags)
    """
    logger.info("validating passages. size=%d", len(passages))
    stats = calculate_matches(passages, answers, result_ctx_ids, workers_num, match_type)
    hit_counts = stats.top_k_hits
    logger.info("Validation results: top k documents hits %s", hit_counts)
    n_questions = len(result_ctx_ids)
    accuracy = [count / n_questions for count in hit_counts]
    logger.info("Validation results: top k documents hits accuracy %s", accuracy)
    return stats.questions_doc_hits
def validate_from_meta(
    answers: List[List[str]],
    result_ctx_ids: List[Tuple[List[object], List[float]]],
    workers_num: int,
    match_type: str,
    meta_compressed: bool,
) -> List[List[bool]]:
    """
    Score retrieval results whose hits carry the passage fields in their metadata (RPC meta mode).

    Delegates to calculate_matches_from_meta with titles enabled, then logs top-k hit counts
    and accuracy.

    :param answers: gold answers per question
    :param result_ctx_ids: retrieved (meta, scores) per question
    :param workers_num: number of worker processes
    :param match_type: answer matching mode
    :param meta_compressed: whether metadata text fields are zlib-compressed
    :return: ``questions_doc_hits`` (per-question hit flags)
    """
    stats = calculate_matches_from_meta(
        answers,
        result_ctx_ids,
        workers_num,
        match_type,
        use_title=True,
        meta_compressed=meta_compressed,
    )
    hit_counts = stats.top_k_hits
    logger.info("Validation results: top k documents hits %s", hit_counts)
    denominator = len(result_ctx_ids)
    logger.info(
        "Validation results: top k documents hits accuracy %s",
        [count / denominator for count in hit_counts],
    )
    return stats.questions_doc_hits
def save_results(
    passages: Dict[object, Tuple[str, str]],
    questions: List[str],
    answers: List[List[str]],
    top_passages_and_scores: List[Tuple[List[object], List[float]]],
    per_question_hits: List[List[bool]],
    out_file: str,
):
    """
    Write questions, answers and their retrieved passages (with scores and answer labels) as JSON.

    For question ``i``, every retrieved id is looked up in ``passages`` (a missing id raises
    KeyError). Only the first ``len(per_question_hits[i])`` passages are written, each as
    {id, title, text, score (stringified), has_answer}.
    """
    merged_data = []
    for idx, question in enumerate(questions):
        q_answers = answers[idx]
        retrieved = top_passages_and_scores[idx]
        hits = per_question_hits[idx]
        found = [passages[pid] for pid in retrieved[0]]
        score_strs = [str(value) for value in retrieved[1]]

        ctxs = []
        for c in range(len(hits)):
            ctxs.append(
                {
                    "id": retrieved[0][c],
                    "title": found[c][1],
                    "text": found[c][0],
                    "score": score_strs[c],
                    "has_answer": hits[c],
                }
            )
        merged_data.append({"question": question, "answers": q_answers, "ctxs": ctxs})

    with open(out_file, "w") as writer:
        writer.write(json.dumps(merged_data, indent=4) + "\n")
    logger.info("Saved results * scores to %s", out_file)
# TODO: unify with save_results
def save_results_from_meta(
    questions: List[str],
    answers: List[List[str]],
    top_passages_and_scores: List[Tuple[List[object], List[float]]],
    per_question_hits: List[List[bool]],
    out_file: str,
    rpc_meta_compressed: bool = False,
):
    """
    Write retrieval results as JSON, taking passage fields directly from each hit's metadata.

    Each hit is indexed as (id, text, title, is_wiki). When ``rpc_meta_compressed`` is set,
    text and title are zlib-decompressed and decoded with the default (UTF-8) codec. Only the
    first ``len(per_question_hits[i])`` hits of question ``i`` are written.
    """

    def unpack(field):
        return zlib.decompress(field).decode() if rpc_meta_compressed else field

    merged_data = []
    for idx, question in enumerate(questions):
        q_answers = answers[idx]
        retrieved = top_passages_and_scores[idx]
        hits = per_question_hits[idx]
        metas = list(retrieved[0])
        score_strs = [str(value) for value in retrieved[1]]

        ctxs = []
        for c in range(len(hits)):
            meta = metas[c]
            ctxs.append(
                {
                    "id": meta[0],
                    "title": unpack(meta[2]),
                    "text": unpack(meta[1]),
                    "is_wiki": meta[3],
                    "score": score_strs[c],
                    "has_answer": hits[c],
                }
            )
        merged_data.append({"question": question, "answers": q_answers, "ctxs": ctxs})

    with open(out_file, "w") as writer:
        writer.write(json.dumps(merged_data, indent=4) + "\n")
    logger.info("Saved results * scores to %s", out_file)
def iterate_encoded_files(vector_files: list, path_id_prefixes: List = None) -> Iterator[Tuple]:
    """
    Lazily yield encoded passage records from pickled embedding files, in file order.

    Each record is yielded as a list. When the file has a non-empty id prefix (from the parallel
    ``path_id_prefixes`` list) and the record's first field, as a string, does not already start
    with it, the prefix is prepended to that field.

    SECURITY: files are read with ``pickle.load``, which can execute arbitrary code; only pass
    trusted embedding files.
    """
    for file_idx, path in enumerate(vector_files):
        logger.info("Reading file %s", path)
        prefix = path_id_prefixes[file_idx] if path_id_prefixes else None
        # pickle.load reads the whole file, so it can be closed before the records are yielded.
        with open(path, "rb") as handle:
            records = pickle.load(handle)
        for record in records:
            entry = list(record)
            if prefix and not str(entry[0]).startswith(prefix):
                entry[0] = prefix + str(entry[0])
            yield entry
def validate_tables(
    passages: Dict[object, TableChunk],
    answers: List[List[str]],
    result_ctx_ids: List[Tuple[List[object], List[float]]],
    workers_num: int,
    match_type: str,
) -> List[List[bool]]:
    """
    Score table-chunk retrieval against gold answers and log chunk- and table-level top-k results.

    NOTE(review): despite the annotation, this returns ``match_stats.top_k_chunk_hits``, i.e. the
    aggregated top-k chunk hit counts (numbers, as the accuracy division below implies), not
    per-question hit lists. Callers expecting per-question hits (e.g. save_results) should be
    checked.
    """
    stats = calculate_chunked_matches(passages, answers, result_ctx_ids, workers_num, match_type)
    chunk_hits = stats.top_k_chunk_hits
    table_hits = stats.top_k_table_hits
    total = len(result_ctx_ids)

    logger.info("Validation results: top k documents hits %s", chunk_hits)
    logger.info(
        "Validation results: top k table chunk hits accuracy %s",
        [count / total for count in chunk_hits],
    )

    logger.info("Validation results: top k tables hits %s", table_hits)
    logger.info(
        "Validation results: top k tables accuracy %s",
        [count / total for count in table_hits],
    )

    return stats.top_k_chunk_hits
def get_all_passages(ctx_sources):
    """
    Merge all context sources into a single id -> passage dict.

    Each source fills the shared dict through its ``load_data_to`` method.

    :raises RuntimeError: if no passages were loaded
    """
    merged = {}
    for source in ctx_sources:
        source.load_data_to(merged)
        logger.info("Loaded ctx data: %d", len(merged))
    if not merged:
        raise RuntimeError("No passages data found. Please specify ctx_file param properly.")
    return merged
def _instantiate_ctx_sources(cfg: DictConfig):
    """
    Instantiate the context (passage) sources listed in ``cfg.ctx_datatsets``.

    :return: tuple (ctx_sources, id_prefixes) of parallel lists
    """
    id_prefixes = []
    ctx_sources = []
    for ctx_src_name in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src_name])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)
        logger.info("ctx_sources: %s", type(ctx_src))
    logger.info("id_prefixes per dataset: %s", id_prefixes)
    return ctx_sources, id_prefixes


@hydra.main(config_path="conf", config_name="dense_retriever")
def main(cfg: DictConfig):
    """
    Dense retrieval entry point.

    The steps are:
    1. Load the encoder checkpoint and encode the questions of ``cfg.qa_dataset``.
    2. Build or load an index: a local faiss index, or a remote one via ``cfg.rpc_retriever_cfg_file``.
    3. Retrieve ``cfg.n_docs`` passages per question.
    4. Validate the hits against the gold answers.
    5. Optionally write JSON (``cfg.out_file``) and KILT (``cfg.kilt_out_file``) outputs.
    """
    cfg = setup_cfg_gpu(cfg)
    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    logger.info("CFG (after gpu configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True)

    logger.info("Loading saved model state ...")
    encoder.load_state(saved_state, strict=False)

    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device, cfg.n_gpu, cfg.local_rank, cfg.fp16)
    encoder.eval()

    model_to_load = get_model_obj(encoder)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    questions_text = []  # never filled here; save_results below falls back to `questions`
    question_answers = []
    question_entities = []
    question_spans = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    total_queries = len(qa_src)
    for i in range(total_queries):
        qa_sample = qa_src[i]
        questions.append(qa_sample.query)
        question_answers.append(qa_sample.answers)
        question_entities.append(qa_sample.entities)
        question_spans.append(qa_sample.entity_spans)

    logger.info("questions len %d", len(questions))
    logger.info("questions_text len %d", len(questions_text))

    if cfg.rpc_retriever_cfg_file:
        index_buffer_sz = 1000
        retriever = DenseRPCRetriever(
            encoder,
            cfg.batch_size,
            tensorizer,
            cfg.rpc_retriever_cfg_file,
            vector_size,
            use_l2_conversion=cfg.use_l2_conversion,
        )
    else:
        index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
        logger.info("Local Index class %s ", type(index))
        index_buffer_sz = index.buffer_size
        index.init_index(vector_size)
        retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)
    # Only the local retriever has an `.index` that can be checked, (de)serialized.
    is_local_index = isinstance(retriever, LocalFaissRetriever)

    # The selector must be installed BEFORE encoding: generate_question_vectors reads
    # retriever.selector, so assigning it after the encoding call had no effect.
    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, question_entities, question_spans, query_token=qa_src.special_query_token
    )

    # Context sources are needed for indexing, passage-based validation and KILT output. They
    # were previously created only in the indexing branch, which left them undefined after an
    # index was loaded; they are now created on first use.
    ctx_sources = None

    index_path = cfg.index_path
    if cfg.rpc_retriever_cfg_file and cfg.rpc_index_id:
        retriever.load_index(cfg.rpc_index_id)
    elif is_local_index and index_path and retriever.index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        # send data for indexing
        ctx_sources, id_prefixes = _instantiate_ctx_sources(cfg)

        # index all passages
        ctx_files_patterns = cfg.encoded_ctx_files
        logger.info("ctx_files_patterns: %s", ctx_files_patterns)
        if ctx_files_patterns:
            assert len(ctx_files_patterns) == len(id_prefixes), "ctx len={} pref len={}".format(
                len(ctx_files_patterns), len(id_prefixes)
            )
        else:
            assert (
                index_path or cfg.rpc_index_id
            ), "Either encoded_ctx_files or index_path or rpc_index_id parameter should be set."

        input_paths = []
        path_id_prefixes = []
        for i, pattern in enumerate(ctx_files_patterns):
            pattern_files = glob.glob(pattern)
            pattern_id_prefix = id_prefixes[i]
            input_paths.extend(pattern_files)
            path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))

        logger.info("Embeddings files id prefixes: %s", path_id_prefixes)
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths, index_buffer_sz, path_id_prefixes=path_id_prefixes)
        if index_path:
            if is_local_index:
                retriever.index.serialize(index_path)
            else:
                logger.warning("index_path=%s ignored: the RPC retriever has no local index to serialize", index_path)

    # get top k results
    top_results_and_scores = retriever.get_top_docs(questions_tensor.numpy(), cfg.n_docs)

    if cfg.use_rpc_meta:
        questions_doc_hits = validate_from_meta(
            question_answers,
            top_results_and_scores,
            cfg.validation_workers,
            cfg.match,
            cfg.rpc_meta_compressed,
        )
        if cfg.out_file:
            save_results_from_meta(
                questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
                cfg.rpc_meta_compressed,
            )
    else:
        if ctx_sources is None:
            ctx_sources, _ = _instantiate_ctx_sources(cfg)
        all_passages = get_all_passages(ctx_sources)
        if cfg.validate_as_tables:
            # NOTE(review): validate_tables returns aggregated top-k chunk hit counts, not
            # per-question hit lists; save_results below may not accept that — verify.
            questions_doc_hits = validate_tables(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )
        else:
            questions_doc_hits = validate(
                all_passages,
                question_answers,
                top_results_and_scores,
                cfg.validation_workers,
                cfg.match,
            )

        if cfg.out_file:
            save_results(
                all_passages,
                questions_text if questions_text else questions,
                question_answers,
                top_results_and_scores,
                questions_doc_hits,
                cfg.out_file,
            )

    if cfg.kilt_out_file:
        if ctx_sources is None:
            ctx_sources, _ = _instantiate_ctx_sources(cfg)
        kilt_ctx = next((ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)), None)
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file, cfg.kilt_out_file)
# Script entry point. hydra.main (see the decorator on main) builds ``cfg`` from the
# "dense_retriever" config under conf/ plus any command-line overrides, so main() is
# called here without arguments.
if __name__ == "__main__":
    main()