data.py
import json
import math
import os.path
import random
from dataclasses import dataclass
from typing import List, Tuple

import datasets
from streaming import LocalDataset
from streaming.base.format.mds.encodings import Encoding, _encodings
from torch.utils.data import Dataset
from transformers import BatchEncoding, DataCollatorWithPadding, PreTrainedTokenizer

from arguments import DataArguments

class ListStr(Encoding):
    """MDS column encoding that stores a Python list as a JSON-encoded string."""

    def encode(self, obj):
        return json.dumps(obj).encode()

    def decode(self, data):
        return json.loads(data)


# Register the encoding so MDS shards can declare 'liststr' columns.
_encodings['liststr'] = ListStr
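
# A minimal sketch of how an MDS training shard with these columns might be
# written using mosaicml-streaming's MDSWriter. The output path, column types,
# and sample values below are assumptions for illustration, not part of this repo:
#
#   from streaming import MDSWriter
#
#   columns = {'query': 'str', 'pos': 'liststr', 'neg': 'liststr'}
#   with MDSWriter(out='./train_mds', columns=columns) as writer:
#       writer.write({
#           'query': 'what is a mutual fund?',
#           'pos': ['A mutual fund pools money from many investors ...'],
#           'neg': ['A hedge fund is a pooled investment vehicle ...'],
#       })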

class TrainDatasetForEmbedding(Dataset):
    """Yields (query, passage group) training examples from a local MDS dataset."""

    def __init__(
            self,
            args: DataArguments,
            tokenizer: PreTrainedTokenizer
    ):
        self.dataset = LocalDataset(local=args.train_data)
        self.tokenizer = tokenizer
        self.args = args
        self.total_len = len(self.dataset)

    def __len__(self):
        return self.total_len

    def __getitem__(self, item) -> Tuple[str, List[str]]:
        query = self.dataset[item]['query']
        if self.args.query_instruction_for_retrieval is not None:
            query = self.args.query_instruction_for_retrieval + query

        passages = []

        # One positive passage, chosen at random from the 'pos' list.
        pos = random.choice(self.dataset[item]['pos'])
        passages.append(pos)

        # Sample train_group_size - 1 negatives; if there are not enough distinct
        # negatives, repeat the 'neg' list so sampling never raises.
        if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
            num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
            negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
        else:
            negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
        passages.extend(negs)

        if self.args.passage_instruction_for_retrieval is not None:
            passages = [self.args.passage_instruction_for_retrieval + p for p in passages]
        return query, passages
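
# Illustrative usage of the dataset (assumes a DataArguments instance whose
# train_data points at a local MDS directory; the field values are hypothetical):
#
#   args = DataArguments(train_data='./train_mds', train_group_size=8)
#   train_dataset = TrainDatasetForEmbedding(args, tokenizer)
#   query, passages = train_dataset[0]
#   # query: one query string; passages: 8 strings (1 positive + 7 sampled negatives)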

@dataclass
class EmbedCollator(DataCollatorWithPadding):
    """
    Converts a batch of (query, passages) tuples into a flat list of queries and a
    flat list of passages, tokenizes each side separately, and returns both so the
    data details stay abstracted away from the model.
    """
    query_max_len: int = 32
    passage_max_len: int = 128

    def padding_score(self, teacher_score):
        # Infer the group size from the first example that carries teacher scores.
        group_size = None
        for scores in teacher_score:
            if scores is not None:
                group_size = len(scores)
                break
        if group_size is None:
            return None

        # Examples without teacher scores get a placeholder distribution that puts
        # all of the weight on the (first) positive passage.
        padding_scores = [100.0] + [0.0] * (group_size - 1)
        new_teacher_score = []
        for scores in teacher_score:
            if scores is None:
                new_teacher_score.append(padding_scores)
            else:
                new_teacher_score.append(scores)
        return new_teacher_score

    def __call__(self, features):
        query = [f[0] for f in features]
        passage = [f[1] for f in features]

        # Flatten nested lists so every query/passage entry is a single string.
        if isinstance(query[0], list):
            query = sum(query, [])
        if isinstance(passage[0], list):
            passage = sum(passage, [])

        q_collated = self.tokenizer(
            query,
            padding=True,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors="pt",
        )
        d_collated = self.tokenizer(
            passage,
            padding=True,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors="pt",
        )
        return {"query": q_collated, "passage": d_collated}