# utils.py
import glob
import logging
import os
import pickle
import random
import re
import shutil
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import (
    MODEL_WITH_LM_HEAD_MAPPING,
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)

logger = logging.getLogger(__name__)
def construct_conv(row, tokenizer, eos=True):
    """Flatten one conversation row into a single token sequence, optionally
    appending the EOS token after each turn. The row is reversed so the
    oldest turn comes first in the flattened sequence."""
    flatten = lambda l: [item for sublist in l for item in sublist]
    conv = list(reversed([tokenizer.encode(x) + ([tokenizer.eos_token_id] if eos else []) for x in row]))
    return flatten(conv)
class ConversationDataset(Dataset):
    def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):
        # Reserve room for any special tokens the tokenizer adds to a sequence.
        block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)

        directory = args.cache_dir
        cached_features_file = os.path.join(
            directory, args.model_type + "_cached_lm_" + str(block_size)
        )

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", directory)

            self.examples = []
            for _, row in df.iterrows():
                conv = construct_conv(row, tokenizer)
                self.examples.append(conv)

            logger.info("Saving features into cached file %s", cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return torch.tensor(self.examples[item], dtype=torch.long)
# Caching and storing of data/checkpoints
def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):
    return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)
def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    ordering_and_checkpoint_path = []

    glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted
def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    if not args.save_total_limit or args.save_total_limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    if len(checkpoints_sorted) <= args.save_total_limit:
        return

    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info("Deleting older checkpoint [%s] due to args.save_total_limit", checkpoint)
        shutil.rmtree(checkpoint)
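
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how these helpers could be wired together for
# DialoGPT-style fine-tuning. The `args` namespace, the toy DataFrame, and
# the padding collate function below are assumptions made for illustration:
# ConversationDataset.__getitem__ returns variable-length 1-D tensors, so a
# DataLoader needs a collate_fn that pads them to a common length.
if __name__ == "__main__":
    from argparse import Namespace

    import pandas as pd
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader

    args = Namespace(
        cache_dir=".",
        model_type="gpt2",
        overwrite_cache=True,
        seed=42,
        n_gpu=0,
        output_dir="output",
        save_total_limit=2,
    )
    set_seed(args)

    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")

    # Each row is one conversation, most recent turn first (construct_conv
    # reverses it so the oldest turn leads the flattened sequence).
    df_trn = pd.DataFrame([["i am fine, thanks", "how are you?", "hello"]])
    dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val=None)

    def collate(examples):
        # Pad with the EOS id; a real training loop would also mask these
        # padded positions out of the language-modeling loss.
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.eos_token_id)

    loader = DataLoader(dataset, batch_size=2, collate_fn=collate)
    batch = next(iter(loader))
    print(batch.shape)  # (batch, padded_seq_len)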