forked from DengPingFan/Inf-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MyTrain_MulClsLungInf_UNet.py
84 lines (64 loc) · 3.26 KB
/
MyTrain_MulClsLungInf_UNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf-8 -*-
"""Preview
Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from CT Scans'
submitted to Transactions on Medical Imaging, 2020.
First Version: Created on 2020-05-13 (@author: Ge-Peng Ji)
"""
import os
import numpy as np
import torch.optim as optim
from Code.utils.dataloader_MulClsLungInf_UNet import LungDataset
from torchvision import transforms
# from LungData import test_dataloader, train_dataloader # pls change batch_size
from torch.utils.data import DataLoader
from Code.model_lung_infection.InfNet_UNet import *
def train(epo_num, num_classes, input_channels, batch_size, lr, save_path):
    """Train ``Inf_Net_UNet`` on the multi-class infection training split.

    Every batch concatenates the image with its prior along dim 1, runs the
    model, applies a sigmoid and optimises a binary cross-entropy loss with
    SGD. Weights are written to
    ``./Snapshots/save_weights/<save_path>/unet_model_<epoch>.pkl`` every
    10 epochs.

    Args:
        epo_num: Number of training epochs.
        num_classes: Second argument of ``Inf_Net_UNet``.
        input_channels: First argument of ``Inf_Net_UNet``. Presumably this
            has to equal the channel count of ``cat(img, pseudo)`` -- confirm
            against the dataset and model definitions.
        batch_size: Mini-batch size of the training ``DataLoader``.
        lr: Learning rate for SGD (momentum is fixed at 0.7).
        save_path: Name of the sub-directory below ``./Snapshots/save_weights/``
            that receives the checkpoints.
    """
    # NOTE(review): `torch` and `nn` are not imported explicitly in this file;
    # they presumably arrive through the wildcard import of InfNet_UNet.
    train_dataset = LungDataset(
        imgs_path='./Dataset/TrainingSet/MultiClassInfection-Train/Imgs/',
        # NOTES: prior is borrowed from the object-level label of train split
        pseudo_path='./Dataset/TrainingSet/MultiClassInfection-Train/Prior/',
        label_path='./Dataset/TrainingSet/MultiClassInfection-Train/GT/',
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    num_batches = len(train_dataloader)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    lung_model = Inf_Net_UNet(input_channels, num_classes)
    print(lung_model)
    lung_model = lung_model.to(device)

    # BCELoss expects probabilities, hence the explicit sigmoid in the loop.
    criterion = nn.BCELoss().to(device)
    optimizer = optim.SGD(lung_model.parameters(), lr=lr, momentum=0.7)

    # Create the directory that torch.save() actually writes to. The original
    # code created an unrelated './checkpoints/...' folder, so the first
    # checkpoint failed whenever this directory did not exist yet.
    weights_dir = './Snapshots/save_weights/{}'.format(save_path)
    os.makedirs(weights_dir, exist_ok=True)

    print("#" * 20, "\nStart Training (Inf-Net)\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung "
                    "Infection Segmentation from CT Scans', 2020, TMI.\n"
                    "----\nPlease cite the paper if you use this code and dataset. "
                    "And any questions feel free to contact me "
                    "via E-mail ([email protected])\n----\n", "#" * 20)

    for epo in range(epo_num):
        train_loss = 0
        lung_model.train()

        # The dataset yields 4-tuples; the last element is not used here.
        for index, (img, pseudo, img_mask, _) in enumerate(train_dataloader):
            img = img.to(device)
            pseudo = pseudo.to(device)
            img_mask = img_mask.to(device)

            optimizer.zero_grad()
            # Image and prior are stacked along the channel dimension.
            output = lung_model(torch.cat((img, pseudo), dim=1))
            output = torch.sigmoid(output)
            # img_mask presumably has the same shape as output -- TODO confirm.
            loss = criterion(output, img_mask)
            loss.backward()
            iter_loss = loss.item()
            train_loss += iter_loss
            optimizer.step()

            # Progress line every 20 iterations.
            if index % 20 == 0:
                print('Epoch: {}/{}, Step: {}/{}, Train loss is {}'.format(epo, epo_num, index, num_batches, iter_loss))

        # Report the accumulated loss; guard against an empty loader.
        if num_batches > 0:
            print('Epoch: {}/{}, Mean train loss is {}'.format(epo, epo_num, train_loss / num_batches))

        # Snapshot the weights every 10th epoch (1-based epoch number).
        if (epo + 1) % 10 == 0:
            torch.save(lung_model.state_dict(),
                       '{}/unet_model_{}.pkl'.format(weights_dir, epo + 1))
            print('Saving checkpoints: unet_model_{}.pkl'.format(epo + 1))
if __name__ == "__main__":
    # Settings for the semi-supervised multi-class UNet run.
    # input_channels=6 presumably covers a 3-channel image plus a 3-channel
    # prior concatenated inside train() -- confirm against the dataset.
    run_config = {
        'epo_num': 200,
        'num_classes': 3,
        'input_channels': 6,
        'batch_size': 16,
        'lr': 1e-2,
        'save_path': 'Semi-Inf-Net_UNet',
    }
    train(**run_config)