# main.py — forked from lfz/DSB2017
# (GitHub page chrome and line-number gutter removed from the scraped copy;
#  original listing: 107 lines / 82 loc / 3.36 KB)
# stdlib
import os
from importlib import import_module

# third-party
import numpy as np
import pandas
import torch
from torch import optim
from torch.autograd import Variable
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.utils.data import DataLoader

# project
from config_submit import config as config_submit
from data_classifier import DataBowl3Classifier
from data_detector import DataBowl3Detector, collate
from layers import acc
from preprocessing import full_prep
from split_combine import SplitComb
from test_detect import test_detect
from utils import *  # NOTE(review): historically also supplied `os`/`np`; now imported explicitly above
# ---- configuration -------------------------------------------------------
datapath = config_submit['datapath']                        # directory of raw input scans
prep_result_path = config_submit['preprocess_result_path']  # where full_prep writes its outputs
skip_prep = config_submit['skip_preprocessing']
skip_detect = config_submit['skip_detect']

# Preprocess the raw scans (segmentation/resampling) unless results already exist.
if not skip_prep:
    testsplit = full_prep(datapath,prep_result_path,
                          n_worker = config_submit['n_worker_preprocessing'],
                          use_existing=config_submit['use_exsiting_preprocessing'])
else:
    # NOTE(review): assumes each entry under `datapath` is a case id that matches
    # the preprocessed file names — confirm against full_prep's return value.
    testsplit = os.listdir(datapath)

# ---- nodule detector -----------------------------------------------------
# The model module is named by its file, e.g. 'net_detector.py' -> module 'net_detector'.
nodmodel = import_module(config_submit['detector_model'].split('.py')[0])
config1, nod_net, loss, get_pbb = nodmodel.get_model()
checkpoint = torch.load(config_submit['detector_param'])
nod_net.load_state_dict(checkpoint['state_dict'])

torch.cuda.set_device(0)
nod_net = nod_net.cuda()
cudnn.benchmark = True          # inputs have fixed size, so let cudnn autotune kernels
nod_net = DataParallel(nod_net)

# Detector proposals (one file per case) are written here by test_detect.
bbox_result_path = './bbox_result'
if not os.path.exists(bbox_result_path):
    os.mkdir(bbox_result_path)
#testsplit = [f.split('_clean')[0] for f in os.listdir(prep_result_path) if '_clean' in f]

# ---- detection pass ------------------------------------------------------
if not skip_detect:
    # Each volume is split into overlapping cubes so it fits in GPU memory,
    # and the per-cube outputs are recombined by SplitComb.
    margin = 32     # overlap between neighbouring cubes (voxels)
    sidelen = 144   # side length of each cube before the margin is added
    config1['datadir'] = prep_result_path
    split_comber = SplitComb(sidelen,config1['max_stride'],config1['stride'],margin,pad_value= config1['pad_value'])

    dataset = DataBowl3Detector(testsplit,config1,phase='test',split_comber=split_comber)
    test_loader = DataLoader(dataset,batch_size = 1,
        shuffle = False,num_workers = 32,pin_memory=False,collate_fn =collate)

    # Runs the detector and writes the proposed-bounding-box files for each
    # case into bbox_result_path (consumed later by the classifier).
    test_detect(test_loader, nod_net, get_pbb, bbox_result_path,config1,n_gpu=config_submit['n_gpu'])
# ---- case classifier -----------------------------------------------------
# Same file-name -> module-name convention as the detector.
casemodel = import_module(config_submit['classifier_model'].split('.py')[0])
casenet = casemodel.CaseNet(topk=5)     # aggregates the top-5 detected nodules per case
config2 = casemodel.config
checkpoint = torch.load(config_submit['classifier_param'])
casenet.load_state_dict(checkpoint['state_dict'])

torch.cuda.set_device(0)
casenet = casenet.cuda()
cudnn.benchmark = True
casenet = DataParallel(casenet)

filename = config_submit['outputfile']  # path of the final CSV of predictions
def test_casenet(model, testset):
    """Run the case classifier over every sample in *testset*.

    Parameters
    ----------
    model : the (already CUDA-resident, DataParallel-wrapped) CaseNet.
    testset : dataset yielding ``(x, coord)`` pairs, one case per item.

    Returns
    -------
    numpy.ndarray
        The per-case predictions, concatenated in dataset order.
    """
    data_loader = DataLoader(
        testset,
        batch_size=1,
        shuffle=False,
        num_workers=32,
        pin_memory=True)

    model.eval()
    predlist = []

    # Inference only: disable autograd so each forward pass does not build
    # and retain a computation graph (the original version leaked GPU memory
    # this way on long test sets).
    with torch.no_grad():
        for i, (x, coord) in enumerate(data_loader):
            coord = Variable(coord).cuda()
            x = Variable(x).cuda()
            nodulePred, casePred, _ = model(x, coord)
            predlist.append(casePred.data.cpu().numpy())

    predlist = np.concatenate(predlist)
    return predlist
# ---- final prediction ----------------------------------------------------
config2['bboxpath'] = bbox_result_path   # classifier reads the detector's proposals
config2['datadir'] = prep_result_path

dataset = DataBowl3Classifier(testsplit, config2, phase = 'test')
# One cancer score per case; .T is a no-op on a 1-D result but preserves row
# orientation if test_casenet ever returns a 2-D array — TODO confirm shape.
predlist = test_casenet(casenet,dataset).T
df = pandas.DataFrame({'id':testsplit, 'cancer':predlist})
df.to_csv(filename,index=False)