# near_repair.py
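"""Training and evaluation script for a NEAR embedding decoder.

Fits an `EmbeddingDecoder` on `AlanDataset` volumes with a BCE shape loss plus
a latent L2 penalty, logs metrics to TensorBoard, and checkpoints the latest
and best (by shape loss) weights under the run's `base_path`.
"""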
import _init_paths_local
import os
import time
import importlib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from near.datasets.refine_dataset import AlanDataset, AbdomenCTDataset, AbdomenCTDistortedDataset
from near.models.nn3d.grid import GatherGridsFromVolumes
from near.models.nn3d.model import EmbeddingDecoder
from near.models.losses import latent_l2_penalty, dice_score
from near.utils.misc import to_device, to_var, write_json, Metrics, surface_dice, cal_surface_dice


def setup_cfg(cfg):
    cfg["run_flag"] += time.strftime("%y%m%d_%H%M%S")
    base_path = os.path.join(cfg["base_path"], cfg["run_flag"])
    if os.path.exists(base_path):
        raise ValueError(
            "Existing [base_path]: %s! Use another `run_flag`." % base_path)
    os.makedirs(base_path)
    write_json(cfg, os.path.join(base_path, "config.json"), verbose=True)
    return cfg, base_path
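
# Resulting run-directory layout (as produced by the calls in this script):
#   <base_path>/<run_flag><timestamp>/
#       config.json                                          written by setup_cfg
#       Tensorboard_Results/                                 SummaryWriter logs
#       latest.pth, best.pth                                 checkpoints saved in eval_epoch
#       training_loss.json, train_loss.json, eval_loss.json  saved metrics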


def train_epoch(model, optimizer, loader,
                shape_loss_fn, gather_fn,
                metrics_per_batch, metrics_per_epoch,
                l2_penalty_weight, writer, epoch):
    model.train()
    global iteration
    tmp_metrics = Metrics(*metrics_per_epoch.keys)
    for _, (indices, shape, appearance) in enumerate(loader):
        indices = to_var(indices)
        _, grids, labels = gather_fn(shape)
        # Note that grid noise is applied during training, so `appearance_label`
        # may not align exactly with `labels` (a slight shift); this acts as a
        # form of data augmentation. For high-resolution images the shift can be
        # large; if so, align them by concatenating the two volumes, gathering
        # once, and then splitting the result, as sketched below.
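        # A minimal sketch of that alignment (hypothetical; assumes `shape` and
        # `appearance` are tensors that can be stacked along the channel dim):
        #   combined = torch.cat([shape, appearance], dim=1)
        #   _, grids, combined_label = gather_fn(combined)  # one shared noisy grid
        #   labels, appearance_label = combined_label.chunk(2, dim=1)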
        _, _, appearance_label = gather_fn(appearance)
        pred_logit_shape, encoded = model(indices, grids, appearance_label)
        shape_loss = shape_loss_fn(pred_logit_shape, labels)  # shape loss
        dice_metric = dice_score(pred_logit_shape.sigmoid() > 0.5, labels)
        surface_dice_metric = cal_surface_dice(
            pred_logit_shape.sigmoid().detach().cpu().numpy() > 0.5,
            labels.cpu().numpy().astype(bool))
        l2_loss = latent_l2_penalty(encoded)
        loss = shape_loss + l2_penalty_weight * l2_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tmp_metrics.ordered_update(loss.item(),
                                   shape_loss.item(),
                                   dice_metric.item(),
                                   surface_dice_metric.item(),
                                   l2_loss.item())
        metrics_per_batch.ordered_update(loss.item(),
                                         shape_loss.item(),
                                         dice_metric.item(),
                                         surface_dice_metric.item(),
                                         l2_loss.item())
        writer.add_scalar('iter_loss', loss.item(), iteration)
        iteration += 1

    metrics_per_epoch.ordered_update(*tmp_metrics.ordered_mean())
    metrics_per_epoch.log_latest(header="Train Loss: ")
    keys = ['train_loss', 'train_shape_loss', 'train_dice', 'train_surface_dice', 'train_l2_loss']
    vals = tmp_metrics.ordered_mean()
    for key, val in zip(keys, vals):
        writer.add_scalar(key, val, epoch)


def eval_epoch(model, loader,
               shape_loss_fn,
               gather_fn, metrics_per_epoch, base_path,
               l2_penalty_weight, writer, epoch):
    model.eval()
    best_loss, _ = metrics_per_epoch.find_best("shape")
    tmp_metrics = Metrics(*metrics_per_epoch.keys)
    for _, (indices, shape, appearance) in enumerate(loader):
        indices = to_var(indices)
        _, grids, labels = gather_fn(shape)
        _, _, appearance_label = gather_fn(appearance)
        with torch.no_grad():
            pred_logit_shape, encoded = model(indices, grids, appearance_label)
            shape_loss = shape_loss_fn(pred_logit_shape, labels)  # shape loss
            dice_metric = dice_score(pred_logit_shape.sigmoid() > 0.5, labels)
            surface_dice_metric = cal_surface_dice(
                pred_logit_shape.sigmoid().cpu().numpy() > 0.5,
                labels.cpu().numpy().astype(bool))
            l2_loss = latent_l2_penalty(encoded)
            loss = shape_loss + l2_penalty_weight * l2_loss
        tmp_metrics.ordered_update(loss.item(),
                                   shape_loss.item(),
                                   dice_metric.item(),
                                   surface_dice_metric.item(),
                                   l2_loss.item())

    metrics_per_epoch.ordered_update(*tmp_metrics.ordered_mean())
    metrics_per_epoch.log_latest(header="Eval Loss: ")
    keys = ['val_loss', 'val_shape_loss', 'val_dice', 'val_surface_dice', 'val_l2_loss']
    vals = tmp_metrics.ordered_mean()
    for key, val in zip(keys, vals):
        writer.add_scalar(key, val, epoch)

    # Always keep the latest weights; additionally snapshot the best model by shape loss.
    torch.save(model.state_dict(), os.path.join(base_path, "latest.pth"))
    tmp_best_loss = tmp_metrics.ordered_mean()[1]  # shape loss
    if tmp_best_loss < best_loss:
        torch.save(model.state_dict(), os.path.join(base_path, "best.pth"))
        print("======================================================================")
        print("Found a new best model! Loss:", tmp_best_loss)
        print("======================================================================")


if __name__ == "__main__":
    cfg, base_path = setup_cfg(importlib.import_module("config_near").cfg)
    data_path = cfg['data_path']

    train_dataset = AlanDataset(root=data_path, resolution=cfg["target_resolution"],
                                n_samples=cfg["n_training_samples"])
    eval_dataset = AlanDataset(root=data_path, resolution=cfg["target_resolution"],
                               n_samples=cfg["n_training_samples"])
    train_loader = DataLoader(train_dataset, batch_size=cfg["batch_size"], shuffle=True,
                              pin_memory=torch.cuda.is_available(), num_workers=cfg["n_workers"])
    train_gather_fn = GatherGridsFromVolumes(
        cfg["training_resolution"], grid_noise=cfg["grid_noise"],
        uniform_grid_noise=cfg["uniform_grid_noise"])
    # Use `eval_dataset` here; evaluation gathers grids at the target
    # resolution without grid noise.
    eval_loader = DataLoader(eval_dataset, batch_size=cfg["eval_batch_size"], shuffle=False,
                             pin_memory=torch.cuda.is_available(), num_workers=cfg["n_workers"])
    eval_gather_fn = GatherGridsFromVolumes(
        cfg["target_resolution"], grid_noise=None)
    print('total data:', len(train_dataset))

    # Per-batch training metrics, per-epoch training metrics, and per-epoch
    # evaluation metrics, respectively.
    training_metrics = Metrics("total", "shape", "dice", "surface_dice", "l2")
    train_metrics = Metrics("total", "shape", "dice", "surface_dice", "l2")
    eval_metrics = Metrics("total", "shape", "dice", "surface_dice", "l2")
    writer = SummaryWriter(log_dir=os.path.join(base_path, 'Tensorboard_Results'))
    # define model and optimization
    model = to_device(EmbeddingDecoder(n_samples=len(train_dataset), appearance=cfg['appearance']))
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg["milestones"], gamma=0.1)
    shape_loss_fn = nn.BCEWithLogitsLoss()
    l2_penalty_weight = cfg["l2_penalty_weight"]
    # Baseline evaluation of the untrained model.
    eval_epoch(model=model, loader=eval_loader,
               shape_loss_fn=shape_loss_fn,
               gather_fn=eval_gather_fn,
               metrics_per_epoch=eval_metrics, base_path=base_path,
               l2_penalty_weight=l2_penalty_weight,
               writer=writer, epoch=0)

    iteration = 0  # global step counter, incremented in train_epoch
    for epoch in range(cfg["n_epochs"]):
        print("Epoch", epoch + 1)
        train_epoch(model=model, optimizer=optimizer,
                    loader=train_loader, shape_loss_fn=shape_loss_fn,
                    gather_fn=train_gather_fn,
                    metrics_per_batch=training_metrics,
                    metrics_per_epoch=train_metrics,
                    l2_penalty_weight=l2_penalty_weight,
                    writer=writer, epoch=epoch)
        eval_epoch(model=model, loader=eval_loader,
                   shape_loss_fn=shape_loss_fn,
                   gather_fn=eval_gather_fn,
                   metrics_per_epoch=eval_metrics, base_path=base_path,
                   l2_penalty_weight=l2_penalty_weight,
                   writer=writer, epoch=epoch)
        scheduler.step()

    training_metrics.save(os.path.join(base_path, "training_loss.json"))
    train_metrics.save(os.path.join(base_path, "train_loss.json"))
    eval_metrics.save(os.path.join(base_path, "eval_loss.json"))
    writer.close()
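
# Usage sketch (an assumption based on the imports above: `_init_paths_local`
# must put a `config_near.py` defining `cfg` on the import path):
#   python near_repair.py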