-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
278 lines (237 loc) · 10.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import sys
import os
# current_dir = os.path.dirname(os.path.abspath(__file__))
# top_level_dir = os.path.abspath(os.path.join(current_dir, "../../../")) # Adjust as needed
# if top_level_dir not in sys.path:
# sys.path.insert(0, top_level_dir)
import warnings
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")
import hydra
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from accelerate import Accelerator, DistributedDataParallelKwargs
from utils.train_helper import sample, update_ema, requires_grad
from copy import deepcopy
from time import time
import torch.distributed as dist
from fid import calc
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
import glob
import numpy as np
from itertools import islice
import webdataset as wds
import pickle
import os
# WebDataset Helper Function
def nodesplitter(src, group=None):
    """Yield only the items of ``src`` assigned to the current rank.

    The rank and world size are read from ``wds.utils.pytorch_worker_info()``.
    With more than one process, every ``world_size``-th item starting at
    index ``rank`` is yielded; otherwise every item is passed through.

    Args:
        src: Iterable of shard descriptors to split.
        group: Accepted for interface compatibility; not used here.

    Yields:
        The items of ``src`` selected for this rank, in their original order.
    """
    rank, world_size, _worker, _num_workers = wds.utils.pytorch_worker_info()
    if world_size <= 1:
        # Single process: nothing to split.
        yield from src
        return
    # Strided selection: rank, rank + world_size, rank + 2 * world_size, ...
    yield from islice(src, rank, None, world_size)
def get_file_paths(dir):
    """Return the paths of all entries in ``dir``, sorted by entry name.

    ``os.listdir`` returns entries in arbitrary order. The result is sorted so
    that every process sees the same list, which matters when the list is
    later split between ranks by position (as ``nodesplitter`` does).

    Args:
        dir: Directory to list.

    Returns:
        A list of ``os.path.join(dir, name)`` for every entry in ``dir``
        (files and sub-directories alike), in sorted name order.
    """
    return [os.path.join(dir, name) for name in sorted(os.listdir(dir))]
def split_by_proc(data_list, global_rank, total_size):
    """Return the strided share of ``data_list`` owned by ``global_rank``.

    The share consists of the elements at positions ``global_rank``,
    ``global_rank + total_size``, ``global_rank + 2 * total_size``, ...

    Args:
        data_list: Sequence to split; must hold at least ``total_size`` items.
        global_rank: Index of the requested share; must be below ``total_size``.
        total_size: Number of shares the sequence is divided into.

    Returns:
        ``data_list[global_rank::total_size]``.

    Raises:
        AssertionError: If either precondition above is violated.
    """
    has_enough_items = total_size <= len(data_list)
    assert has_enough_items
    rank_in_range = total_size > global_rank
    assert rank_in_range
    share = slice(global_rank, None, total_size)
    return data_list[share]
def decode_data(item):
    """Decode one raw sample into a dict with ``'latent'`` and ``'label'``.

    Args:
        item: Mapping with a pickled payload under ``'latent'`` and a UTF-8
            encoded integer string (bytes) under ``'cls'``.

    Returns:
        ``{'latent': <unpickled object>, 'label': <int>}``.

    Note:
        ``pickle.loads`` can execute arbitrary code while deserialising, so
        only shards from a trusted source should be fed through this function.
    """
    latent = pickle.loads(item['latent'])
    label = int(item['cls'].decode('utf-8'))
    return {'latent': latent, 'label': label}
def make_loader(root, mode='train', batch_size=32,
                num_workers=4, cache_dir=None,
                resampled=False, world_size=1, total_num=1281167,
                bufsize=1000, initial=100):
    """Build a WebDataset-backed loader yielding ``(latent, label)`` batches.

    Every entry of ``root`` (as listed by ``get_file_paths``) is used as a
    shard. Samples are shuffled, decoded with ``decode_data``, converted to
    ``(latent, label)`` tuples and batched inside the dataset pipeline, so
    the ``WebLoader`` itself is created with ``batch_size=None``.

    Args:
        root: Directory containing the shards.
        mode: Currently unused; kept for interface compatibility.
        batch_size: Number of samples per batch on this process; incomplete
            batches are dropped (``partial=False``).
        num_workers: Number of loader worker processes.
        cache_dir: Forwarded to ``wds.WebDataset``.
        resampled: If true, shards are resampled with repetition, no node
            splitter is installed, and the loader's epoch length is fixed to
            ``total_num // (batch_size * world_size)`` batches. If false, the
            shards are split between ranks with ``nodesplitter``.
        world_size: Number of processes; only used for the epoch length.
        total_num: Number of samples in the dataset; only used for the epoch
            length.
        bufsize: Size of the sample shuffle buffer.
        initial: Initial fill of the shuffle buffer before samples are yielded.

    Returns:
        The configured ``wds.WebLoader``.
    """
    data_list = get_file_paths(root)
    num_batches_in_total = total_num // (batch_size * world_size)

    if resampled:
        repeat = True
        splitter = False
    else:
        repeat = False
        splitter = nodesplitter

    dataset = (
        wds.WebDataset(
            data_list,
            cache_dir=cache_dir,
            repeat=repeat,
            resampled=resampled,
            handler=wds.handlers.warn_and_stop,
            nodesplitter=splitter,
            shardshuffle=True,
        )
        .shuffle(bufsize, initial=initial)
        .map(decode_data, handler=wds.handlers.warn_and_stop)
        .to_tuple('latent label')
        .batched(batch_size, partial=False)
    )

    loader = wds.WebLoader(
        dataset,
        batch_size=None,
        num_workers=num_workers,
        shuffle=False,
        # PyTorch's DataLoader raises ValueError for persistent_workers=True
        # when num_workers == 0, so only request it when workers exist.
        persistent_workers=num_workers > 0,
    )
    if resampled:
        loader = loader.with_epoch(num_batches_in_total)
    return loader
def rzprint(*args, **kwargs):
    """Print only on rank zero (or whenever no process group is active).

    All positional and keyword arguments are forwarded unchanged to ``print``.
    Output is suppressed only when ``torch.distributed`` is available, a
    process group has been initialised, and this process's rank is not 0.
    """
    in_process_group = dist.is_available() and dist.is_initialized()
    if in_process_group and dist.get_rank() != 0:
        return
    print(*args, **kwargs)
@hydra.main(config_path="configs", config_name="config")
def train(cfg: DictConfig):
    """Run the training loop described by the Hydra config ``cfg``.

    The function instantiates the model, optimizer and diffuser from the
    config, optionally resumes from ``<results_dir>/<run_name>/latest.pt``,
    then iterates over the WebDataset loader for ``max_epochs`` epochs. At
    configured intervals it logs the running loss and throughput, evaluates
    FID for each configured CFG scale using the EMA weights, and writes
    checkpoints.

    Args:
        cfg: Hydra/OmegaConf configuration. The keys read here include
            ``dataset``, ``model``, ``train``, ``results_dir``, ``run_name``,
            ``load_ckpt``, ``save_ckpt``, ``enable_eval``, ``log_wandb``,
            ``wandb.project`` and ``global_seed``.
    """
    print(OmegaConf.to_yaml(cfg))
    data_config = cfg.dataset
    model_config = cfg.model
    experiment_dir = os.path.join(cfg.results_dir, cfg.run_name)

    ##############################################################
    # INIT
    ##############################################################
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
    device = accelerator.device
    size = accelerator.num_processes
    rank = accelerator.process_index
    rzprint("Init Accelerator.")

    model = hydra.utils.instantiate(model_config.model).to(device)
    optimizer = hydra.utils.instantiate(cfg.train.optimizer, model.parameters())
    rzprint("Init model and optimizer.")

    diffuser = hydra.utils.instantiate(cfg.train.diffuser)
    model = diffuser.wrap_model_with_precond(model)
    ema = deepcopy(model)
    requires_grad(ema, False)

    step = 0
    if cfg.load_ckpt:
        load_path = os.path.join(experiment_dir, 'latest.pt')
        if os.path.exists(load_path):
            checkpoint = torch.load(load_path, map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])
            ema.load_state_dict(checkpoint['ema_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            saved_step = checkpoint['step']
            rzprint(f"Loaded checkpoint from {load_path} at step {saved_step}")
            # The checkpoint is written after the optimizer update of
            # `saved_step`, so training must continue with the next index.
            # Resuming at `saved_step` would repeat that step and trigger
            # eval/checkpointing again on the first iteration.
            step = saved_step + 1
        else:
            rzprint(f"No checkpoint found at {load_path}")

    model, ema, optimizer = accelerator.prepare(model, ema, optimizer)
    model.train()
    ema.eval()
    rzprint("Init diffuser.")

    ##############################################################
    # DATA
    ##############################################################
    total_batch_size = cfg.train.general.batch_size
    batch_size_per_device = total_batch_size // size
    rzprint(f"Batch size per device: {batch_size_per_device}")
    rzprint(f"Total batch size: {total_batch_size}")
    loader = make_loader(
        cfg.dataset.train_path,
        mode='train',
        batch_size=batch_size_per_device,
        num_workers=data_config.num_workers,
        resampled=False,
        total_num=data_config.total_num
    )
    rzprint("Init data loader.")

    if cfg.log_wandb and rank == 0:
        wandb.init(project=cfg.wandb.project, config=OmegaConf.to_container(cfg))

    running_loss = 0
    start_time = time()
    to_tensor = transforms.ToTensor()
    rzprint("Starting training loop...")
    for epoch in range(cfg.train.general.max_epochs):
        rzprint(f"Epoch {epoch + 1}/{cfg.train.general.max_epochs}")
        for x, cond in loader:
            ##############################################################
            # TRAIN STEP
            ##############################################################
            x = x.to(accelerator.device)
            cond = cond.to(accelerator.device)
            with accelerator.autocast():
                x = sample(x)
                loss = diffuser.get_training_loss(
                    model,
                    x,
                    cond.to(torch.long),
                    mask_ratio=cfg.train.general.mask_ratio,
                    class_drop_prob=cfg.train.general.class_drop_prob,
                )
                loss = loss.mean()
            optimizer.zero_grad()
            # NOTE(review): backward runs once per step, so retain_graph=True
            # looks unnecessary and keeps the graph alive longer than needed;
            # confirm nothing in the diffuser relies on it before removing.
            accelerator.backward(loss, retain_graph=True)
            optimizer.step()
            # unwrap_model works whether or not the model was wrapped for
            # distributed training; `model.module` only exists when it was.
            update_ema(ema, accelerator.unwrap_model(model))
            running_loss += loss.item()

            ##############################################################
            # LOGGING
            ##############################################################
            if step % cfg.train.logging.log_interval == 0 and step > 0:
                elapsed = time() - start_time
                steps_per_sec = cfg.train.logging.log_interval / elapsed
                avg_loss = running_loss / cfg.train.logging.log_interval
                rzprint(f"Step {step}: Loss: {avg_loss:.4f}, Steps/sec: {steps_per_sec:.2f} \n")
                if cfg.log_wandb and rank == 0:
                    wandb.log({"loss": avg_loss, "steps_per_sec": steps_per_sec}, step=step)
                running_loss = 0
                start_time = time()

            ##############################################################
            # EVAL
            ##############################################################
            if step % cfg.train.eval.eval_interval == 0 and cfg.enable_eval and step > 0:
                eval_begin = time()
                for cfg_scale in cfg.train.eval.cfg_scales:
                    cfg.train.eval.cfg_scale = cfg_scale
                    outdir = os.path.join(experiment_dir, 'fid')
                    os.makedirs(outdir, exist_ok=True)
                    rzprint(f"FID Folder: {outdir}")
                    rzprint(f"EMA device: {next(ema.parameters()).device}")
                    # Dedicated timer: the throughput timer `start_time` must
                    # not be overwritten by the evaluation code.
                    gen_start = time()
                    diffuser.generate(cfg.train.eval, ema, device, rank, size, outdir=outdir)
                    accelerator.wait_for_everyone()
                    gen_elapsed = time() - gen_start
                    rzprint(f"Time taken to generate samples: {gen_elapsed:.2f}s")
                    fid = calc(outdir, data_config.ref_path, cfg.train.eval.fid_num_samples, cfg.global_seed, cfg.train.eval.fid_batch_size, cfg.train.eval.inception_path)
                    accelerator.wait_for_everyone()
                    cfg.train.eval.cfg_scale = None
                    if rank == 0:
                        rzprint(f"FID (CFG:{cfg_scale}): {fid}")
                        if cfg.log_wandb:
                            wandb.log({f"FID (CFG:{cfg_scale})": fid}, step=step)
                            # Log a small grid of the generated samples.
                            num_samples = 16
                            image_files = sorted(glob.glob(os.path.join(outdir, '*.png')))
                            image_list = []
                            for img_file in image_files[:num_samples]:
                                # Context manager closes the file handle.
                                with Image.open(img_file) as opened:
                                    img = opened.convert('RGB')
                                image_list.append(to_tensor(img))
                            if len(image_list) > 0:
                                grid = vutils.make_grid(image_list, nrow=int(np.sqrt(num_samples)), normalize=True)
                                wandb.log({f"FID (CFG:{cfg_scale}) Samples": [wandb.Image(grid, caption="Generated Samples")]}, step=step)
                # Shift the throughput timer so that time spent evaluating is
                # not counted in the next steps/sec measurement.
                start_time += time() - eval_begin

            ##############################################################
            # CHECKPOINT
            ##############################################################
            if cfg.save_ckpt and step % cfg.train.eval.eval_interval == 0 and step > 0:
                save_path = os.path.join(experiment_dir, f'step_{step:06d}.pt')
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_ema = accelerator.unwrap_model(ema)
                if accelerator.is_main_process:
                    # The directory is otherwise only created as a side effect
                    # of evaluation, which may be disabled.
                    os.makedirs(experiment_dir, exist_ok=True)
                    checkpoint = {
                        'model_state_dict': unwrapped_model.state_dict(),
                        'ema_state_dict': unwrapped_ema.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'step': step,
                    }
                    torch.save(checkpoint, save_path)
                    # Save latest checkpoint
                    latest_path = os.path.join(experiment_dir, 'latest.pt')
                    torch.save(checkpoint, latest_path)

            step += 1
# Script entry point. `train` is wrapped by @hydra.main above, which is why it
# is called here without passing `cfg` explicitly.
if __name__ == '__main__':
    train()