forked from kingoflolz/mesh-transformer-jax

prepare_dataset.py
#!/usr/bin/env python
import argparse
import os
from itertools import islice
from pathlib import Path
from typing import List

# import ftfy
import tensorflow as tf
from transformers import GPT2TokenizerFast
from tqdm import tqdm
import datasets

def iter_tokens(input_ids, eos_token_id):
    """Flattens batched token ids into a single token stream, emitting the EOS
    token after each document so documents can be packed back to back."""
    for token_ids in input_ids:
        for token_id in token_ids:
            yield token_id
        yield eos_token_id

def split_every_with_padding(n, iterable, pad_token_type_id=None):
    """Splits iterable every n and fills the last chunk with pad_token_type_id
    if necessary."""
    i = iter(iterable)
    piece = list(islice(i, n))
    while piece:
        if len(piece) < n:
            piece += [pad_token_type_id] * (n - len(piece))
        yield piece
        piece = list(islice(i, n))

def split_every(n, iterable):
    """Splits iterable into chunks of n, dropping the last chunk if it is too short."""
    i = iter(iterable)
    piece = list(islice(i, n))
    while piece and len(piece) == n:
        yield piece
        piece = list(islice(i, n))

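# Example (comments only, not executed): split_every packs the flat token
# stream into fixed-length sequences and drops the final partial chunk, while
# split_every_with_padding pads it instead:
#
#   >>> list(split_every(3, range(8)))
#   [[0, 1, 2], [3, 4, 5]]
#   >>> list(split_every_with_padding(3, range(8), pad_token_type_id=0))
#   [[0, 1, 2], [3, 4, 5], [6, 7, 0]]
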
def _int64_feature(value):
    """
    Returns an int64_list from a bool / enum / int / uint.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def write_to_file(writer, data):
    """
    Writes data to a tfrecord file
    """
    feature = {"text": _int64_feature(data)}
    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(tf_example.SerializeToString())


def write_tfrecord(sequences, fp):
    """Writes all sequences to a single tfrecord file and returns how many were written."""
    count = 0
    with tf.io.TFRecordWriter(fp) as writer:
        for seq in sequences:
            write_to_file(writer, seq)
            count += 1
    return count

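# Sketch (comments only, not executed): each record stores one sequence of
# token ids under the "text" feature, so a file produced by write_tfrecord can
# be read back roughly like this (the filename is a placeholder):
#
#   ds = tf.data.TFRecordDataset("my_data_1234.tfrecords")
#   parsed = ds.map(lambda rec: tf.io.parse_single_example(
#       rec, {"text": tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True)}))
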
def generate_sample(dataset, epochs, key, preserve_data_order=False):
    """Yields the `key` field of every sample, repeating the streaming dataset
    for the requested number of epochs."""
    for epoch in range(epochs):
        if not preserve_data_order:
            dataset.set_epoch(epoch)
        for sample in dataset:
            yield sample[key]

def main(args):
    GPT2TokenizerFast.max_model_input_sizes["gpt2"] = 1e20  # disables a misleading warning
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    epochs = args.n_repack_epochs
    seq_length = args.sequence_length

    ncc = datasets.load_dataset(
        args.dataset,
        name=args.dataset_config or None,
        split=args.dataset_split,
        streaming=True,
        use_auth_token=True,
    )
    if not args.preserve_data_order:
        ncc = ncc.shuffle(args.dataset_buffer_size, seed=args.seed)
    ncc = ncc.map(lambda x: tokenizer(x[args.dataset_text_column]), batched=True)

    seqs = tqdm(
        split_every(
            seq_length,
            iter_tokens(
                generate_sample(ncc, epochs, "input_ids", args.preserve_data_order),
                tokenizer.eos_token_id,
            ),
        ),
        desc="Writing token ids as TF records",
    )

    filepath = args.output_dir / f"{args.name}.tfrecords"
    seq_count = write_tfrecord(seqs, filepath.as_posix())
    filepath_seq = args.output_dir / f"{args.name}_{seq_count}.tfrecords"
    os.rename(filepath.as_posix(), filepath_seq.as_posix())

def parse_args():
    parser = argparse.ArgumentParser(
        description="""
Converts a text dataset from Huggingface into the training data format expected by the model.
This script creates a single .tfrecords file as output
- Why: the model's data loader ignores "trailing" data (< 1 batch) at the end of a .tfrecords file
  - this causes data loss if you have many .tfrecords files
- This is probably not appropriate for very large datasets
""",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "name",
        type=str,
        help="Name of the output file will be {name}_{seqnum}.tfrecords, where seqnum is the total sequence count",
    )
    parser.add_argument(
        "dataset",
        type=str,
        help="Dataset path or hub name.",
    )
    parser.add_argument(
        "--dataset_config",
        type=str, default="",
        help="Dataset config.",
    )
    parser.add_argument(
        "--dataset_split",
        type=str, default="train",
        help="Dataset split. It accepts any Huggingface datasets expression for splits.",
    )
    parser.add_argument(
        "--dataset_text_column",
        type=str, default="text",
        help="Dataset text field name.",
    )
    parser.add_argument(
        "--dataset_buffer_size",
        type=int, default=10_000,
        help="Dataset buffer size for shuffling.",
    )
    parser.add_argument(
        "--sequence_length",
        type=int, default=2048,
        help="Sequence length of each TF record.",
    )
    parser.add_argument(
        "--output-dir",
        type=str, default="",
        help="Output directory (default: current directory)",
    )

    # cleaning_args = parser.add_argument_group('data cleaning arguments')
    # cleaning_args.add_argument("--normalize-with-ftfy", action="store_true", help="Normalize text with ftfy")
    # cleaning_args.add_argument("--normalize-with-wikitext-detokenize",
    #                            action="store_true", help="Use wikitext detokenizer")
    # minu_help = "Exclude repetitive documents made up of < MIN_UNIQUE_TOKENS unique tokens. These can produce large gradients."
    # minu_help += " Set <= 0 to disable. If enabled, 200 is a good default value. (Default: 0)"
    # cleaning_args.add_argument("--min-unique-tokens", type=int, default=0,
    #                            help=minu_help)

    shuffle_pack_args = parser.add_argument_group("data shuffling/packing arguments")
    repack_ep_help = (
        "Repeat the data N_REPACK_EPOCHS times, shuffled differently in each repetition. "
        "Recommended for multi-epoch training (set this to your intended number of epochs)."
    )
    shuffle_pack_args.add_argument(
        "--n-repack-epochs",
        type=int, default=1,
        help=repack_ep_help,
    )
    shuffle_pack_args.add_argument(
        "--seed",
        type=int, default=10,
        help="Random seed for shuffling data (default: 10)",
    )
    shuffle_pack_args.add_argument(
        "--preserve-data-order",
        default=False, action="store_true",
        help="Disables shuffling, so the input and output data have the same order.",
    )

    args = parser.parse_args()

    # convert output_dir to a pathlib.Path
    args.output_dir = Path(args.output_dir)

    return args

if __name__ == "__main__":
    main(parse_args())
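
# Sketch of a typical invocation (comments only; the dataset and output names
# are purely illustrative, not from the original repository):
#
#   python prepare_dataset.py my_data some_org/some_corpus \
#       --dataset_split train --dataset_text_column text \
#       --sequence_length 2048 --n-repack-epochs 2 --output-dir ./data
#
# This writes a single file such as ./data/my_data_<seq_count>.tfrecords, where
# <seq_count> is the number of 2048-token sequences that were packed.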