# ocrvqa_dataset.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import copy

import torch
from datasets import load_dataset


# Check whether a system-prompt or user-prompt header token sequence appears
# anywhere in the current token list.
def check_header(targets, seq):
    for i in range(len(seq) - 2):  # -2 (not -3) so the final 3-token window is also checked
        if seq[i : i + 3] in targets:
            return True
    return False


# Replace every occurrence of the 3-token target sequence with -100 so those
# positions are ignored by the cross-entropy loss.
def replace_target(target, seq):
    for i in range(len(seq) - 2):
        if seq[i : i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq
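

# Illustrative sketch (not part of the original file): how the two helpers above
# behave on a toy token list. 128006/128007 are the Llama 3 header delimiters,
# 9125/78191 encode "system"/"assistant", and 128009 is <|eot_id|>, as noted in
# the comments below; 271 and 1234 simply stand in for ordinary text tokens.
#
#   seq = [128006, 78191, 128007, 271, 1234, 128009]
#   check_header([[128006, 9125, 128007]], seq)   # -> False: no system header present
#   replace_target([128006, 78191, 128007], seq)  # -> [-100, -100, -100, 271, 1234, 128009]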


def tokenize_dialogs(dialogs, images, processor):
    text_prompt = processor.apply_chat_template(dialogs)
    text_prompt = [prompt.replace("<|begin_of_text|>", "") for prompt in text_prompt]
    batch = processor(
        images=images,
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )
    label_list = []
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
        # 128009 is the <|eot_id|> token, which closes every turn.
        eot_indices = [idx for idx, token_id in enumerate(labels) if token_id == 128009]
        last_idx = 0
        # The system prompt header "<|start_header_id|>system<|end_header_id|>" is tokenized
        # to [128006, 9125, 128007]; the user prompt header
        # "<|start_header_id|>user<|end_header_id|>" is tokenized to [128006, 882, 128007].
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for idx in eot_indices:
            current_seq = labels[last_idx : idx + 1]
            if check_header(prompt_header_seqs, current_seq):
                # Found a prompt header, so this segment is part of the prompt and should be masked.
                labels[last_idx : idx + 1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx + 1
        # Mask every assistant header "<|start_header_id|>assistant<|end_header_id|>",
        # which is tokenized to [128006, 78191, 128007].
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        # Mask padding tokens and the image token (index 128256).
        for j in range(len(labels)):
            if labels[j] == processor.tokenizer.pad_token_id or labels[j] == 128256:
                labels[j] = -100
        label_list.append(labels)
    batch["labels"] = torch.tensor(label_list)
    return batch
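

# Usage sketch (illustrative): how tokenize_dialogs is meant to be called. The
# checkpoint name and image path are assumptions; any Llama 3.2 vision processor
# whose chat template emits the header/eot tokens above should behave the same.
#
#   from PIL import Image
#   from transformers import AutoProcessor
#
#   processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
#   dialogs = [[
#       {"role": "user", "content": [{"type": "image"},
#                                    {"type": "text", "text": "What is the title?"}]},
#       {"role": "assistant", "content": [{"type": "text", "text": "Moby Dick"}]},
#   ]]
#   images = [[Image.open("cover.jpg")]]
#   batch = tokenize_dialogs(dialogs, images, processor)
#   # batch["labels"] is -100 everywhere except the assistant answer tokens,
#   # so the loss is computed only on the answers.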


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict whose "train" split holds all the data.
    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
    dataset = dataset_dict["train"]
    # For quick testing we keep only 2000 samples; comment out the following
    # line to use the full dataset.
    dataset = dataset.select(range(2000))
    dataset = dataset.train_test_split(
        test_size=1 - split_ratio, shuffle=True, seed=42
    )[split]
    return dataset
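

# Usage sketch (illustrative): with the default split_ratio of 0.9 and the
# 2000-sample cap above, this yields 1800 training and 200 test samples.
# dataset_config is unused here, so None is fine.
#
#   train_ds = get_custom_dataset(None, processor, "train")
#   eval_ds = get_custom_dataset(None, processor, "test")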


class OCRVQADataCollator:
    def __init__(self, processor):
        self.processor = processor
        # During training, padding is always applied on the right.
        self.processor.tokenizer.padding_side = "right"

    def __call__(self, samples):
        dialogs, images = [], []
        for sample in samples:
            image_list, sample_list = sample["images"], sample["texts"]
            if len(image_list) > 1:
                raise ValueError("Only one image per sample is supported")
            image = image_list[0].convert("RGB")  # only use the first image
            dialog = []
            for sample_dict in sample_list:
                if not dialog:
                    # Only attach the image to the first user turn.
                    dialog += [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image"},
                                {"type": "text", "text": sample_dict["user"].strip()},
                            ],
                        },
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": sample_dict["assistant"].strip(),
                                }
                            ],
                        },
                    ]
                else:
                    dialog += [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sample_dict["user"].strip()}
                            ],
                        },
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": sample_dict["assistant"].strip(),
                                }
                            ],
                        },
                    ]
            dialogs.append(dialog)
            images.append([image])
        return tokenize_dialogs(dialogs, images, self.processor)
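

# For reference, each the_cauldron "ocrvqa" sample handed to __call__ above has
# roughly this shape (keys as used by the code; values are illustrative):
#
#   {
#       "images": [<PIL.Image.Image>],
#       "texts": [
#           {"user": "What is the title of this book?", "assistant": "Moby Dick"},
#       ],
#   }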


def get_data_collator(processor):
    return OCRVQADataCollator(processor)
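

# End-to-end sketch (illustrative, not from the original file): wiring the
# dataset and collator into a PyTorch DataLoader. The checkpoint name is an
# assumption, as above.
#
#   from torch.utils.data import DataLoader
#   from transformers import AutoProcessor
#
#   processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
#   train_ds = get_custom_dataset(None, processor, "train")
#   loader = DataLoader(train_ds, batch_size=2, collate_fn=get_data_collator(processor))
#   batch = next(iter(loader))  # input_ids, attention_mask, pixel values, labels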