-
Notifications
You must be signed in to change notification settings - Fork 63
/
engine_mar.py
250 lines (202 loc) · 9.63 KB
/
engine_mar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import math
import sys
from typing import Iterable
import torch
import util.misc as misc
import util.lr_sched as lr_sched
from models.vae import DiagonalGaussianDistribution
import torch_fidelity
import shutil
import cv2
import numpy as np
import os
import copy
import time
def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    Each target tensor is modified in place as
    ``target = rate * target + (1 - rate) * source``. Pairs are taken with
    ``zip``, so iteration stops at the shorter of the two sequences.

    :param target_params: the target parameter sequence (updated in place).
    :param source_params: the source parameter sequence (read only).
    :param rate: the EMA rate (closer to 1 means slower).
    """
    # The update is pure bookkeeping: run it without autograd so the in-place
    # add does not record graph history when the source tensors require grad.
    with torch.no_grad():
        for targ, src in zip(target_params, source_params):
            # detach() shares storage with `targ`, so the in-place ops below
            # modify the target parameter itself.
            targ.detach().mul_(rate).add_(src, alpha=1 - rate)
def train_one_epoch(model, vae,
                    model_params, ema_params,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    """
    Train ``model`` for one pass over ``data_loader`` and return the averaged meters.

    For every batch the learning rate is adjusted, a latent is sampled from
    the posterior (built from the batch itself when ``args.use_cached`` is
    set, otherwise obtained from ``vae.encode``), scaled by 0.2325 and passed
    to ``model`` together with the labels under CUDA autocast. The resulting
    loss is handed to ``loss_scaler``, gradients are zeroed and ``ema_params``
    is updated from ``model_params``. The process exits with status 1 as soon
    as a non-finite loss is seen.

    :param model: the network being trained; called as ``model(x, labels)``.
    :param vae: encoder used when latents are not cached.
    :param model_params: parameter sequence used as the EMA source.
    :param ema_params: parameter sequence updated as the EMA target.
    :param data_loader: yields ``(samples, labels)`` batches.
    :param optimizer: optimizer whose learning rate is scheduled per iteration.
    :param device: device the batches are moved to.
    :param epoch: index of the current epoch.
    :param loss_scaler: callable invoked with the loss and the optimizer.
    :param log_writer: optional writer receiving ``train_loss`` and ``lr`` scalars.
    :param args: namespace providing ``use_cached``, ``grad_clip`` and ``ema_rate``.
    :return: dict mapping each meter name to its global average.
    """
    model.train(True)

    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    batch_stream = metric_logger.log_every(data_loader, print_freq, header)
    for data_iter_step, (samples, labels) in enumerate(batch_stream):
        # Fractional epoch: drives the per-iteration (instead of per-epoch)
        # lr schedule and, further down, the tensorboard x-axis.
        epoch_progress = data_iter_step / len(data_loader) + epoch
        lr_sched.adjust_learning_rate(optimizer, epoch_progress, args)

        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.no_grad():
            # Cached batches already hold the posterior moments.
            posterior = DiagonalGaussianDistribution(samples) if args.use_cached else vae.encode(samples)

            # normalize the std of latent to be 1. Change it if you use a different tokenizer
            x = posterior.sample().mul_(0.2325)

        # forward
        with torch.cuda.amp.autocast():
            loss = model(x, labels)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
        optimizer.zero_grad()

        torch.cuda.synchronize()

        update_ema(ema_params, model_params, rate=args.ema_rate)

        metric_logger.update(loss=loss_value)
        current_lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=current_lr)

        # Called on every iteration regardless of whether a writer exists.
        reduced_loss = misc.all_reduce_mean(loss_value)
        if log_writer is None:
            continue

        # We use epoch_1000x as the x-axis in tensorboard.
        # This calibrates different curves when batch size changes.
        epoch_1000x = int(epoch_progress * 1000)
        log_writer.add_scalar('train_loss', reduced_loss, epoch_1000x)
        log_writer.add_scalar('lr', current_lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
def _barrier_if_distributed():
    """Wait for all ranks at a barrier; do nothing when no process group is initialized."""
    if misc.is_dist_avail_and_initialized():
        torch.distributed.barrier()


def evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=16, log_writer=None, cfg=1.0,
             use_ema=True):
    """
    Generate ``args.num_images`` class-conditional images and, when a writer is
    given, compute FID / Inception Score on them and log both.

    Images are sampled in batches with ``model_without_ddp.sample_tokens``,
    decoded by ``vae.decode`` and written as PNG files into a folder under
    ``args.output_dir``. Each rank handles its own slice of the label list.
    When ``use_ema`` is set the model weights are replaced by ``ema_params``
    for the duration of the sampling and restored afterwards. Barriers are
    skipped when no distributed process group is initialized, so the function
    also works in a single process.

    :param model_without_ddp: model exposing ``sample_tokens``, ``state_dict``,
        ``load_state_dict`` and ``named_parameters``.
    :param vae: decoder used to turn sampled tokens into images.
    :param ema_params: EMA parameters, indexed in ``named_parameters()`` order.
    :param args: namespace providing ``num_images``, ``output_dir``,
        ``num_iter``, ``num_sampling_steps``, ``temperature``, ``cfg_schedule``,
        ``evaluate``, ``class_num`` and ``img_size``.
    :param epoch: step value under which the metrics are logged.
    :param batch_size: number of images generated per rank per step.
    :param log_writer: optional writer; metrics are computed only when given.
    :param cfg: value forwarded to ``sample_tokens`` as ``cfg``.
    :param use_ema: whether to sample with the EMA weights.
    :raises NotImplementedError: if metrics are requested and
        ``args.img_size`` is not 256.
    """
    model_without_ddp.eval()
    world_size = misc.get_world_size()
    local_rank = misc.get_rank()

    # One step more than the integer quotient so a remainder is covered;
    # surplus images are dropped by the img_id check in the save loop.
    num_steps = args.num_images // (batch_size * world_size) + 1
    save_folder = os.path.join(args.output_dir, "ariter{}-diffsteps{}-temp{}-{}cfg{}-image{}".format(args.num_iter,
                                                                                                     args.num_sampling_steps,
                                                                                                     args.temperature,
                                                                                                     args.cfg_schedule,
                                                                                                     cfg,
                                                                                                     args.num_images))
    if use_ema:
        save_folder = save_folder + "_ema"
    if args.evaluate:
        save_folder = save_folder + "_evaluate"
    print("Save to:", save_folder)
    if local_rank == 0:
        os.makedirs(save_folder, exist_ok=True)

    # switch to ema params
    if use_ema:
        # Keep a copy of the current weights so they can be restored later.
        model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
            assert name in ema_state_dict
            ema_state_dict[name] = ema_params[i]
        print("Switch to ema")
        model_without_ddp.load_state_dict(ema_state_dict)

    class_num = args.class_num
    assert args.num_images % class_num == 0  # number of images per class must be the same
    class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
    # Zero padding so the label slices of the final step never run short.
    class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])
    used_time = 0
    gen_img_cnt = 0

    for i in range(num_steps):
        print("Generation step {}/{}".format(i, num_steps))

        # This rank's contiguous slice of the labels for step i.
        step_offset = world_size * batch_size * i
        labels_gen = class_label_gen_world[step_offset + local_rank * batch_size:
                                           step_offset + (local_rank + 1) * batch_size]
        labels_gen = torch.Tensor(labels_gen).long().cuda()

        torch.cuda.synchronize()
        start_time = time.time()

        # generation
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                sampled_tokens = model_without_ddp.sample_tokens(bsz=batch_size, num_iter=args.num_iter, cfg=cfg,
                                                                 cfg_schedule=args.cfg_schedule, labels=labels_gen,
                                                                 temperature=args.temperature)
                # Undo the 0.2325 latent scaling before decoding.
                sampled_images = vae.decode(sampled_tokens / 0.2325)

        # measure speed after the first generation batch
        if i >= 1:
            torch.cuda.synchronize()
            used_time += time.time() - start_time
            gen_img_cnt += batch_size
            print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))

        _barrier_if_distributed()
        sampled_images = sampled_images.detach().cpu()
        # presumably maps decoder output from [-1, 1] to [0, 1] -- TODO confirm
        sampled_images = (sampled_images + 1) / 2

        # distributed save
        per_rank = sampled_images.size(0)
        for b_id in range(per_rank):
            img_id = i * per_rank * world_size + local_rank * per_rank + b_id
            if img_id >= args.num_images:
                break
            gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
            # Reverse the channel axis before handing the array to cv2.imwrite.
            gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
            cv2.imwrite(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)

    _barrier_if_distributed()
    # NOTE(review): fixed pause kept from the original; its purpose is not
    # evident from this file -- presumably lets file writes settle.
    time.sleep(10)

    # back to no ema
    if use_ema:
        print("Switch back from ema")
        model_without_ddp.load_state_dict(model_state_dict)

    # compute FID and IS
    if log_writer is not None:
        if args.img_size == 256:
            input2 = None
            fid_statistics_file = 'fid_stats/adm_in256_stats.npz'
        else:
            raise NotImplementedError
        metrics_dict = torch_fidelity.calculate_metrics(
            input1=save_folder,
            input2=input2,
            fid_statistics_file=fid_statistics_file,
            cuda=True,
            isc=True,
            fid=True,
            kid=False,
            prc=False,
            verbose=False,
        )
        fid = metrics_dict['frechet_inception_distance']
        inception_score = metrics_dict['inception_score_mean']
        postfix = ""
        if use_ema:
            postfix = postfix + "_ema"
        if cfg != 1.0:
            postfix = postfix + "_cfg{}".format(cfg)
        log_writer.add_scalar('fid{}'.format(postfix), fid, epoch)
        log_writer.add_scalar('is{}'.format(postfix), inception_score, epoch)
        print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
        # remove temporal saving folder
        shutil.rmtree(save_folder)

    _barrier_if_distributed()
    time.sleep(10)
def cache_latents(vae,
                  data_loader: Iterable,
                  device: torch.device,
                  args=None):
    """
    Encode every batch with ``vae`` and store the posterior parameters on disk.

    For each sample two arrays are written to
    ``<args.cached_path>/<path>.npz``: ``moments`` (parameters of the
    posterior of the sample) and ``moments_flip`` (parameters of the posterior
    of the sample flipped along dim 3 -- the width axis if batches are NCHW,
    TODO confirm). Parent directories are created as needed.

    :param vae: encoder whose ``encode`` result exposes ``parameters``.
    :param data_loader: yields ``(samples, _, paths)`` batches.
    :param device: device the samples are moved to before encoding.
    :param args: namespace providing ``cached_path``.
    """
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = 'Caching: '
    print_freq = 20

    for samples, _, paths in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)

        with torch.no_grad():
            moments = vae.encode(samples).parameters
            moments_flip = vae.encode(samples.flip(dims=[3])).parameters

        # Move each batch to host memory once instead of once per sample.
        # assumes `parameters` is a single batched tensor -- TODO confirm
        moments = moments.cpu().numpy()
        moments_flip = moments_flip.cpu().numpy()

        for i, path in enumerate(paths):
            save_path = os.path.join(args.cached_path, path + '.npz')
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            np.savez(save_path, moments=moments[i], moments_flip=moments_flip[i])

        if misc.is_dist_avail_and_initialized():
            torch.cuda.synchronize()