-
Notifications
You must be signed in to change notification settings - Fork 6
/
average_checkpoints.py
186 lines (155 loc) · 6.79 KB
/
average_checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#!/usr/bin/env python3
import argparse
import collections
import torch
import os
import re
from fairseq.utils import import_user_module
def default_avg_params(params_dict):
    """Average every parameter uniformly over its per-checkpoint values.

    Args:
        params_dict: Mapping from parameter name to a non-empty list of
            values (expected to be torch Tensors), one per checkpoint.

    Returns:
        An OrderedDict with the same key order as ``params_dict`` mapping
        each name to the arithmetic mean of its list. Integer tensors are
        floor-divided so that they keep their integer dtype.

    Raises:
        ValueError: If a parameter has an empty list of values.
    """
    # True division would silently promote integer tensors (e.g. counters)
    # to float, so these dtypes are averaged with floor division instead.
    integer_dtypes = (
        torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
    )
    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        if not v:
            raise ValueError(
                'Cannot average parameter {!r}: no values given'.format(k)
            )
        values = iter(v)
        summed_v = next(values)
        for x in values:
            # Out-of-place add: the input tensors are never mutated.
            summed_v = summed_v + x
        if getattr(summed_v, 'dtype', None) in integer_dtypes:
            averaged_params[k] = summed_v // len(v)
        else:
            averaged_params[k] = summed_v / len(v)
    return averaged_params
def ema_avg_params(params_dict, ema_decay):
    """Combine per-checkpoint parameters with an exponential moving average.

    Each value list is expected to be ordered newest to oldest. It is
    walked in reverse, so the oldest value seeds the average and each newer
    value ``x_t`` is folded in as::

        y_t = x_t * ema_decay + y_{t-1} * (1 - ema_decay)

    Note that ``ema_decay`` here is the weight given to the *newer* value.

    Args:
        params_dict: Mapping from parameter name to a list of values, one
            per checkpoint; every list must have the same length.
        ema_decay: Weight applied to each newer value in the recurrence.

    Returns:
        An OrderedDict with the same key order as ``params_dict`` mapping
        each name to its averaged value. Empty input yields an empty
        OrderedDict.

    Raises:
        ValueError: If the value lists do not all have the same length.
    """
    lens = [len(v) for v in params_dict.values()]
    # Explicit raise rather than assert so the check survives `python -O`.
    if len(set(lens)) > 1:
        raise ValueError(f'lens params: {lens}')

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        total_v = None
        # Input order is newest to oldest; reverse it to go oldest first.
        for x in reversed(v):
            if total_v is None:
                total_v = x
            else:
                total_v = x * ema_decay + total_v * (1.0 - ema_decay)
        averaged_params[k] = total_v
    return averaged_params
def average_checkpoints(inputs, ema_decay=1.0):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
        inputs: An iterable of string paths of checkpoints to load from.
            For exponential moving averaging the paths are expected to be
            ordered newest to oldest.
        ema_decay: When strictly below 1.0, parameters are combined with
            ``ema_avg_params`` using this decay; otherwise they are averaged
            uniformly with ``default_avg_params``.

    Returns:
        A dict of string keys mapping to various values. It is the state
        loaded from the first checkpoint, with its 'model' entry replaced
        by an OrderedDict mapping string parameter names to the averaged
        torch Tensors.

    Raises:
        ValueError: If ``inputs`` is empty.
        KeyError: If a checkpoint's parameter names (or their order) differ
            from those of the first checkpoint.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    for f in inputs:
        # NOTE(security): torch.load unpickles the file, so only trusted
        # checkpoints should be passed in. Tensors are restored onto CPU.
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            # Accumulate half-precision weights in float32.
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            params_dict.setdefault(k, []).append(p)

    if new_state is None:
        # Without this guard the failure would be an opaque TypeError when
        # assigning into None below.
        raise ValueError('No input checkpoints were given to average.')

    if ema_decay < 1.0:
        print(f'Exponential moving averaging, decay={ema_decay}')
        averaged_params = ema_avg_params(params_dict, ema_decay)
    else:
        print('Default averaging')
        averaged_params = default_avg_params(params_dict)
    new_state['model'] = averaged_params
    return new_state
def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    """Return the paths of the ``n`` highest-numbered checkpoints in a directory.

    Args:
        paths: A sequence containing exactly one directory path.
        n: Number of checkpoints to return.
        update_based: If true, match files named ``checkpoint_<a>_<b>.pt``
            and rank them by ``<b>``; otherwise match ``checkpoint<a>.pt``
            and rank them by ``<a>``.
        upper_bound: If given, ignore checkpoints whose ranking number is
            greater than this value.

    Returns:
        A list of ``n`` file paths (directory joined with file name),
        ordered from highest to lowest ranking number.

    Raises:
        Exception: If fewer than ``n`` matching checkpoints are found.
    """
    assert len(paths) == 1, 'expected exactly one checkpoint directory'
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')

    entries = []
    for f in os.listdir(path):
        m = pt_regexp.fullmatch(f)
        if m is None:
            continue
        sort_key = int(m.group(1))
        if upper_bound is None or sort_key <= upper_bound:
            entries.append((sort_key, m.group(0)))

    if len(entries) < n:
        # The template must be formatted explicitly; passing the values as
        # extra Exception arguments leaves the placeholders unfilled.
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
def main():
    """Command-line entry point: average checkpoints and write the result.

    Parses the command line, optionally replaces ``--inputs`` (a single
    directory) with the last N epoch- or update-based checkpoints found in
    it, averages them with ``average_checkpoints`` and saves the resulting
    state to ``--output`` with ``torch.save``.

    Invalid option combinations are reported through ``parser.error``, which
    prints a usage message and exits.
    """
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    # argparse itself rejects passing both of the options in this group.
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpointxx.pt in the path specified by input, '
                                'and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
                                'and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, '
                             'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.')
    # NOTE: --ema is accepted and echoed in the log below, but it does not
    # influence the averaging; only --ema_decay is passed on.
    parser.add_argument('--ema', default='False', type=str, metavar='BOOL', help='ema')
    parser.add_argument('--ema_decay', type=float, default=1.0, help='exponential moving average decay')
    parser.add_argument('--user-dir', default=None)
    # fmt: on
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts vanish under
    # `python -O` and show a traceback instead of a usage message.
    if args.checkpoint_upper_bound is not None and args.num_epoch_checkpoints is None:
        parser.error('--checkpoint-upper-bound requires --num-epoch-checkpoints')

    import_user_module(args)
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    if num is not None:
        # In this mode --inputs names a directory; expand it to the newest
        # `num` checkpoint files inside it.
        args.inputs = last_n_checkpoints(
            args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
        )

    print('averaging checkpoints: ')
    for checkpoint in args.inputs:
        print(checkpoint)
    print('-' * 40)

    print(f'Start averaging with ema={args.ema}, ema_decay={args.ema_decay}')
    new_state = average_checkpoints(args.inputs, args.ema_decay)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()