-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
254 lines (217 loc) · 12.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import os
import torch
import argparse
## HF imports
import diffusers
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import UNet2DModel
from diffusers import DDPMScheduler, DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor
import pandas as pd
import numpy as np
from torchvision import transforms
import torch.nn.functional as F
from torch import nn
from tqdm.autonotebook import tqdm
## Custom imports
from config import TrainingConfig
from transform import *
from dataset import ContourDiffDataset
from ContourDiffPipeline import ContourDiffDDPMPipeline, ContourDiffDDIMPipeline
from utils import evaluate, add_contours_to_noise
def _is_periodic_epoch(epoch, every, num_epochs):
    """Return True when a task scheduled every ``every`` epochs should run at ``epoch``.

    ``epoch`` is 0-based. The first and the last epoch always qualify, so a run
    produces output even when ``num_epochs`` is not a multiple of ``every``.
    A non-positive ``every`` disables only the periodic part (instead of raising
    ZeroDivisionError).
    """
    if epoch == 0 or epoch == num_epochs - 1:
        return True
    return every > 0 and (epoch + 1) % every == 0


def _build_noise_scheduler(config):
    """Return the DDPM or DDIM noise scheduler selected by ``config.model_type``.

    Raises:
        ValueError: if ``config.model_type`` is neither "ddpm" nor "ddim".
    """
    if config.model_type == "ddpm":
        return DDPMScheduler(num_train_timesteps=config.noise_step)
    if config.model_type == "ddim":
        return DDIMScheduler(num_train_timesteps=config.noise_step)
    raise ValueError(f"Unsupported model_type {config.model_type!r}; expected 'ddpm' or 'ddim'")


def _build_unet(config):
    """Create the denoising UNet for single-channel contour guidance.

    The input has ``config.in_channels`` image channels plus one extra channel for
    the single-channel contour map; the output predicts noise for the image
    channels only.
    """
    return UNet2DModel(
        sample_size=config.img_size,  # the target image resolution
        in_channels=config.in_channels + 1,  # image channels + 1 contour channel
        out_channels=config.in_channels,  # noise is predicted for the image channels only
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(128, 128, 256, 256, 512, 512),  # output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )


def _build_pipeline(config, unet, noise_scheduler, val_dataloader):
    """Wrap ``unet`` in the ContourDiff sampling pipeline matching ``config.model_type``.

    The pipelines are modified from diffusers.DDPMPipeline / diffusers.DDIMPipeline.
    ``config.model_type`` is assumed to be already validated by
    ``_build_noise_scheduler``.
    """
    pipeline_cls = ContourDiffDDPMPipeline if config.model_type == "ddpm" else ContourDiffDDIMPipeline
    return pipeline_cls(unet=unet, scheduler=noise_scheduler, data_loader=val_dataloader, external_config=config)


def main(args):
    """Train a contour-guided ContourDiff diffusion model.

    Trains on output-domain images and contours (e.g. MRI) with the
    noise-prediction MSE objective. On the first epoch, the last epoch, and every
    ``save_image_epochs`` / ``save_model_epochs`` epochs it samples using contours
    from the input-domain validation set (e.g. CT) and/or saves the pipeline under
    ``config.output_dir``.

    Args:
        args: parsed command-line namespace from the script's argument parser.

    Raises:
        NotImplementedError: if contour guidance is disabled or
            ``contour_channel_mode`` is not "single".
        ValueError: if ``model_type`` is neither "ddpm" nor "ddim".
    """
    ### Load the training config
    config = TrainingConfig(
        model_type=args.model_type,
        dataset=args.dataset,
        img_size=args.img_size,
        input_domain=args.input_domain,
        output_domain=args.output_domain,
        in_channels=args.in_channels,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        num_epochs=args.num_epochs,
        noise_step=args.noise_step,
        learning_rate=args.learning_rate,
        lr_warmup_steps=args.lr_warmup_steps,
        save_image_epochs=args.save_image_epochs,
        save_model_epochs=args.save_model_epochs,
        seed=args.seed,
        workers=args.workers,
        generator_seed=args.generator_seed,
        contour_guided=args.contour_guided,
        contour_channel_mode=args.contour_channel_mode,
        conditional=args.conditional,
    )

    # Only contour-guided training with a single-channel contour map has a defined
    # UNet input layout and evaluation path; reject everything else before any
    # data is loaded.
    if not config.contour_guided:
        raise NotImplementedError("Training without contour guidance is not implemented; pass --contour_guided")
    if config.contour_channel_mode != "single":
        raise NotImplementedError("Multi-channel map is not implemented")
    noise_scheduler = _build_noise_scheduler(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('running on {}'.format(device))

    if args.output_dir is not None:
        config.output_dir = args.output_dir
    else:
        config.output_dir = f'ContourDiff-{config.input_domain}-{config.output_domain}-{config.model_type}-{config.dataset}'

    ### Load transforms for images and contours
    train_transform_img = load_train_transform_img(config)
    train_transform_contour = load_train_transform_contour(config)
    val_transform_img = load_val_transform_img(config)
    val_transform_contour = load_val_transform_contour(config)

    ### Load the meta csv files (generated by preprocess.py)
    ### training: output_domain (e.g. MRI); validation: input_domain (e.g. CT)
    df_train_meta = pd.read_csv(os.path.join(args.data_directory, args.output_domain_meta_path), index_col=0)
    df_val_meta = pd.read_csv(os.path.join(args.data_directory, args.input_domain_meta_path), index_col=0)

    ### Datasets: image_directory holds the images, contour_directory the contours
    train_dataset = ContourDiffDataset(
        df_train_meta,
        image_directory=os.path.join(args.data_directory, args.output_domain_img_folder),
        contour_directory=os.path.join(args.data_directory, args.output_domain_contour_folder),
        transform_img=train_transform_img,
        transform_contour=train_transform_contour,
        generator_seed=config.generator_seed,
        config=config,
    )
    val_dataset = ContourDiffDataset(
        df_val_meta,
        image_directory=os.path.join(args.data_directory, args.input_domain_img_folder),
        contour_directory=os.path.join(args.data_directory, args.input_domain_contour_folder),
        transform_img=val_transform_img,
        transform_contour=val_transform_contour,
        generator_seed=config.generator_seed,
        config=config,
    )
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.workers)
    # Shuffled so each evaluation samples a different validation batch.
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=config.eval_batch_size, shuffle=True, num_workers=config.workers)

    ### Model, wrapped for multi-GPU; pipelines/checkpoints use the unwrapped .module
    model = nn.DataParallel(_build_unet(config))
    model.to(device)

    ### Optimizer and learning-rate schedule over the full run
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs),
    )

    # Counts optimizer steps across all epochs (used for logging only).
    global_step = 0
    for epoch in range(config.num_epochs):
        ## Training
        model.train()
        with tqdm(total=len(train_dataloader)) as progress_bar:
            progress_bar.set_description(f"Epoch {epoch}")
            for batch in train_dataloader:
                clean_images = batch["images"].to(device)
                bs = clean_images.shape[0]
                ## Sample noise (float32, directly on the target device) and a random timestep per image
                noise = torch.randn(clean_images.shape, device=clean_images.device)
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device).long()
                # Forward diffusion: noise the clean images according to each timestep
                noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
                ### Add contours to guide the reverse process (always enabled here)
                noisy_images = add_contours_to_noise(noisy_images, batch, config, device)

                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                progress_bar.update(1)
                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
                progress_bar.set_postfix(**logs)
                global_step += 1

        run_eval = _is_periodic_epoch(epoch, config.save_image_epochs, config.num_epochs)
        run_save = _is_periodic_epoch(epoch, config.save_model_epochs, config.num_epochs)
        if not (run_eval or run_save):
            continue

        ### Build the ContourDiff pipeline only when it is actually used
        pipeline = _build_pipeline(config, model.module, noise_scheduler, val_dataloader)
        model.eval()

        ### Evaluate using contours generated from input_domain images
        if run_eval:
            data_batch = next(iter(val_dataloader))
            evaluate(config, epoch + 1, pipeline, noise_step=config.noise_step, contour=True, data_batch=data_batch)

        ### Save checkpoints (a single "model" folder when overwriting, else one per epoch)
        if run_save:
            checkpoint_name = "model" if args.overwrite else f"model_epoch_{epoch + 1}"
            pipeline.save_pretrained(os.path.join(config.output_dir, checkpoint_name))
def build_arg_parser():
    """Build the command-line parser for ContourDiff training.

    Returns:
        argparse.ArgumentParser: parser whose namespace is consumed by ``main``.
    """
    parser = argparse.ArgumentParser(description="Train a contour-guided ContourDiff diffusion model.")
    parser.add_argument('--dataset', type=str, default=None, help="name of the dataset")
    parser.add_argument('--model_type', type=str, default="ddpm", choices=["ddpm", "ddim"], help="type of diffusion models (ddpm or ddim)")
    parser.add_argument('--img_size', type=int, default=256, help="size of the input images")
    parser.add_argument('--train_batch_size', type=int, default=4, help="training batch size")
    parser.add_argument('--eval_batch_size', type=int, default=16, help="validation batch size")
    parser.add_argument('--num_epochs', type=int, default=400, help="number of epochs for training")
    parser.add_argument('--noise_step', type=int, default=1000, help="maximum number of steps to add noise during training")
    parser.add_argument('--learning_rate', type=float, default=1e-4, help="initial learning rate used for training")
    parser.add_argument('--lr_warmup_steps', type=int, default=500, help="number of steps to warmup the training")
    parser.add_argument('--save_image_epochs', type=int, default=20, help="frequency (in epochs) to save the translated samples")
    parser.add_argument('--save_model_epochs', type=int, default=20, help="frequency (in epochs) to save the model checkpoints")
    parser.add_argument('--overwrite', action='store_true', help="overwrite previous checkpoints if specified, otherwise, save checkpoints separately")
    parser.add_argument('--output_dir', type=str, default=None, help="directory to save the output samples and checkpoints. If not specified, it will use the default name as ContourDiff-{input_domain}-{output_domain}-{model_type}-{dataset}")
    parser.add_argument('--seed', type=int, default=0, help="seed for the random noise generator")
    parser.add_argument('--workers', type=int, default=0, help="number of data loader workers")
    parser.add_argument('--generator_seed', type=int, default=42, help="seed to ensure identical transformation applying to images and contours")
    parser.add_argument('--contour_guided', action='store_true', help="enable contour guided diffusion if specified")
    parser.add_argument('--contour_channel_mode', type=str, default="single", help="channel mode of the contour map (only 'single' is implemented)")
    parser.add_argument('--conditional', action='store_true', help="add other conditions (besides contours) to generate images if specified")
    parser.add_argument('--data_directory', type=str, required=True, help="directory of the dataset")
    parser.add_argument('--input_domain', type=str, required=True, help="name of the input domain (e.g. CT)")
    parser.add_argument('--output_domain', type=str, required=True, help="name of the output domain (e.g. MRI)")
    parser.add_argument('--input_domain_img_folder', type=str, required=True, help="name of the folder containing images from input domain (e.g. CT)")
    parser.add_argument('--input_domain_contour_folder', type=str, required=True, help="name of the folder containing contours from input domain (e.g. CT)")
    parser.add_argument('--output_domain_img_folder', type=str, required=True, help="name of the folder containing images from output domain (e.g. MRI)")
    parser.add_argument('--output_domain_contour_folder', type=str, required=True, help="name of the folder containing contours from output domain (e.g. MRI)")
    parser.add_argument('--input_domain_meta_path', type=str, required=True, help="path of input domain meta under data_directory")
    parser.add_argument('--output_domain_meta_path', type=str, required=True, help="path of output domain meta under data_directory")
    parser.add_argument('--in_channels', type=int, default=1, help="number of channels of the input image")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())