-
Notifications
You must be signed in to change notification settings - Fork 7
/
main.py
54 lines (46 loc) · 2.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import sys
import os
import numpy as np
import HoughCNN as HCNN
# Every path below is joined onto the current working directory rather than
# this file's location, so the script must be launched from the directory
# that contains Prototxt/, PromiseNormalised/, Results/ and Models/
# (presumably the repository root -- NOTE(review): confirm).
basePath = os.getcwd()

params = dict()
params['DataManagerParams'] = dict()
params['ModelParams'] = dict()

# params of the algorithm
params['ModelParams']['numcontrolpoints'] = 2
params['ModelParams']['sigma'] = 15
params['ModelParams']['device'] = 0
params['ModelParams']['prototxtTrain'] = os.path.join(basePath, 'Prototxt/train_HCNN_promise2012.prototxt')
params['ModelParams']['prototxtTest'] = os.path.join(basePath, 'Prototxt/test_HCNN_promise2012.prototxt')
params['ModelParams']['snapshot'] = 5000
params['ModelParams']['dirTrain'] = os.path.join(basePath, 'PromiseNormalised/Train')
params['ModelParams']['dirTest'] = os.path.join(basePath, 'PromiseNormalised/Test')
params['ModelParams']['dirResult'] = os.path.join(basePath, 'Results')  # where to save the results (under basePath)
params['ModelParams']['dirSnapshots'] = os.path.join(basePath, 'Models/HCNN/')  # where to save the models while training
params['ModelParams']['batchsize'] = 400  # the batchsize
params['ModelParams']['numIterations'] = 100000  # the number of iterations
params['ModelParams']['baseLR'] = 0.0001  # the learning rate, initial one
params['ModelParams']['nProc'] = 8  # the number of threads to do data augmentation
params['ModelParams']['solver'] = None
params['ModelParams']['patchSize'] = 33
params['ModelParams']['SamplingStep'] = 4
params['ModelParams']['featLength'] = 128
params['ModelParams']['numNeighs'] = 10
params['ModelParams']['maxDist'] = 2.0
params['ModelParams']['centrtol'] = 4
params['ModelParams']['SegPatchRadius'] = [21, 21, 21]

# params of the DataManager
params['DataManagerParams']['dstRes'] = np.asarray([1, 1, 1.5], dtype=float)
params['DataManagerParams']['VolSize'] = np.asarray([128, 128, 64], dtype=int)
params['DataManagerParams']['normDir'] = False
params['DataManagerParams']['rebuildDbase'] = True
params['DataManagerParams']['databasePklLoadPath'] = './database.pkl'
params['DataManagerParams']['databasePklSavePath'] = './database.pkl'


def dispatch_modes(model, argv):
    """Run ``model.train()`` and/or ``model.test()`` as requested by *argv*.

    Args:
        model: any object exposing ``train()`` and ``test()`` methods.
        argv: the command-line arguments to scan for ``-train`` / ``-test``.

    Returns:
        True if at least one mode was run, False if neither flag was present.

    Training always runs before testing when both flags are given, matching
    the original script's statement order.
    """
    ran = False
    if '-train' in argv:
        model.train()
        ran = True
    if '-test' in argv:
        model.test()
        ran = True
    return ran


def main(argv=None):
    """Build the HoughCNN model from ``params`` and run the requested modes.

    Args:
        argv: command-line arguments; defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    model = HCNN.HoughCNN(params)
    if not dispatch_modes(model, argv):
        # Previously the script exited silently here; tell the user why.
        print("usage: python main.py [-train] [-test]", file=sys.stderr)


if __name__ == "__main__":
    main()