forked from patil-suraj/question_generation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_collator.py
64 lines (52 loc) · 2.41 KB
/
data_collator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Dict, List, Optional
import torch
def trim_batch(input_ids, pad_token_id, attention_mask=None):
    """Drop every column of ``input_ids`` that holds only ``pad_token_id``.

    A column survives when at least one row has a value different from
    ``pad_token_id`` in it, so all-pad columns are removed wherever they
    occur, not just at the end.

    Args:
        input_ids: tensor indexed as ``[row, column]``.
        pad_token_id: the value treated as padding.
        attention_mask: optional tensor; when given, the same columns are
            selected from it.

    Returns:
        The trimmed ``input_ids`` alone when ``attention_mask`` is ``None``,
        otherwise the tuple ``(trimmed_input_ids, trimmed_attention_mask)``.
    """
    # One boolean per column: True when any row holds a non-pad value there.
    has_real_token = input_ids.ne(pad_token_id).any(dim=0)
    trimmed_ids = input_ids[:, has_real_token]
    if attention_mask is not None:
        return (trimmed_ids, attention_mask[:, has_real_token])
    return trimmed_ids
# prepares lm_labels from target_ids, returns examples with keys as expected by the forward method
# this is necessary because the trainer directly passes this dict as arguments to the model
# so make sure the keys match the parameter names of the forward method
class T2TDataCollator():
    """Collate tokenized text-to-text examples into keyword arguments for a model.

    Each example must provide tensors under the keys ``'source_ids'``,
    ``'target_ids'`` and ``'attention_mask'``. The returned dict is keyed
    ``"input_ids"``, ``"attention_mask"`` and ``"labels"``, plus
    ``"decoder_input_ids"`` whenever ``model_type`` is not ``"t5"``.
    """

    def __init__(self, tokenizer, model_type="t5", mode='training', using_tpu=False):
        """Store the collation settings.

        Args:
            tokenizer: object exposing ``pad_token_id``.
            model_type: ``"t5"`` uses the targets directly as labels; any
                other value produces shifted decoder inputs and labels.
            mode: pad positions in the labels are replaced with ``-100``
                only when this equals ``'training'``.
            using_tpu: when true, all-pad columns are not trimmed.
        """
        self.tokenizer, self.model_type = tokenizer, model_type
        self.mode, self.using_tpu = mode, using_tpu

    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        """
        Stack the per-example tensors of ``batch`` and derive the labels.

        Returns:
            A dictionary of tensors
        """
        def stack_field(key):
            # One row per example, in batch order.
            return torch.stack([sample[key] for sample in batch])

        input_ids = stack_field('source_ids')
        target_ids = stack_field('target_ids')
        attention_mask = stack_field('attention_mask')

        pad_id = self.tokenizer.pad_token_id

        # don't trim on tpu, for some reason trimming leads to slower training on TPU
        if not self.using_tpu:
            input_ids, attention_mask = trim_batch(
                input_ids, pad_id, attention_mask=attention_mask
            )
            target_ids = trim_batch(target_ids, pad_id)

        if self.model_type == "t5":
            # No explicit decoder inputs are passed for t5.
            decoder_input_ids = None
            labels = target_ids.clone()
        else:
            # Decoder sees tokens [0, n-1); labels are the tokens [1, n).
            decoder_input_ids = target_ids[:, :-1].contiguous()
            labels = target_ids[:, 1:].clone()

        if self.mode == 'training':
            # Pad positions become -100 (presumably the loss's ignore index
            # -- confirm against the model in use).
            labels[labels == pad_id] = -100

        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        if decoder_input_ids is not None:
            model_kwargs["decoder_input_ids"] = decoder_input_ids
        return model_kwargs