-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmknet.py
84 lines (65 loc) · 2.85 KB
/
mknet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from __future__ import print_function
import sys, os, math
import h5py
import numpy as np
from numpy import float32, int32, uint8, dtype
from config import fmap_start, net_input_shape, net_output_shape, use_deconvolution_uppath
# Load PyGreentea
# Relative path to where PyGreentea resides
pygt_path = '../../PyGreentea'
sys.path.append(pygt_path)
import PyGreentea as pygt
# Create the network we want.
# Base configuration shared by every prototxt this script generates.
netconf = pygt.netgen.NetConf()
netconf.malis_split_component_phases = True
netconf.ignore_conv_buffer = True
netconf.use_batchnorm = False
netconf.dropout = 0.0
netconf.fmap_start = fmap_start
netconf.u_netconfs[0].use_deconvolution_uppath = use_deconvolution_uppath
# Feature-map rules: the "inc" rule triples the count, the "dec" rule divides
# it by three and rounds up.
# NOTE(review): only print_function is imported from __future__, so under
# Python 2 `fmaps / 3` floors for integer fmaps and the ceil is a no-op;
# results match Python 3 only when fmaps is a multiple of 3 -- confirm.
netconf.u_netconfs[0].unet_fmap_inc_rule = lambda fmaps: int(math.ceil(fmaps * 3))
netconf.u_netconfs[0].unet_fmap_dec_rule = lambda fmaps: int(math.ceil(fmaps / 3))
netconf.input_shape = net_input_shape
netconf.output_shape = net_output_shape

# Wrap each value in a one-element tuple: with the `%` operator a bare tuple
# operand is treated as the argument list, so a tuple-valued shape such as
# (132, 132, 132) would raise "not all arguments converted". Output for
# list-valued shapes is unchanged.
print('Input shape: %s' % (netconf.input_shape,))
print('Output shape: %s' % (netconf.output_shape,))
print('Feature maps: %s' % (netconf.fmap_start,))

# Generate one training net per loss function. The test net returned by
# these calls is not written out in this section.
netconf.loss_function = "euclid"
train_net_conf_euclid, test_net_conf = pygt.netgen.create_nets(netconf)
netconf.loss_function = "malis"
train_net_conf_malis, test_net_conf = pygt.netgen.create_nets(netconf)

# Store both training nets.
with open('net_train_euclid.prototxt', 'w') as f:
    print(train_net_conf_euclid, file=f)
with open('net_train_malis.prototxt', 'w') as f:
    print(train_net_conf_malis, file=f)
# Test prototxt: generated without zero-padding / deconvolution, using a
# fixed cubic input and output size.
test_input_edge = 132
test_output_edge = 44
netconf.input_shape = [test_input_edge] * 3
netconf.output_shape = [test_output_edge] * 3
netconf.u_netconfs[0].use_deconvolution_uppath = False
netconf.ignore_conv_buffer = True

# Only the test net of the generated pair is written to disk here.
train_net_conf_malis, test_net_conf = pygt.netgen.create_nets(netconf)
with open('net_test.prototxt', 'w') as f:
    print(test_net_conf, file=f)
#### Make a big test proto
# Biggest possible network for testing on 12 GB.
# The shape search below runs with an 8 GiB global memory limit.
netconf.mem_global_limit = 8 * 1024 * 1024 * 1024
mode = pygt.netgen.caffe_pb2.TEST
shape_min = [100, 100, 100]
shape_max = [300, 300, 300]
# One entry per axis: None for the first axis, and a function of the shape
# for the other two (presumably tying them to the preceding axis so the
# volume stays cubic -- confirm against compute_valid_io_shapes).
constraints = [None, lambda x: x[0], lambda x: x[1]]
inshape, outshape, fmaps = pygt.netgen.compute_valid_io_shapes(
    netconf, mode, shape_min, shape_max, constraints=constraints)

# Choose the last candidate that still provides at least
# netconf.fmap_start feature maps (presumably the largest one, assuming the
# candidates are ordered by size -- confirm).
candidates = [n for n, i in enumerate(fmaps) if i >= netconf.fmap_start]
if not candidates:
    # Same exception type as the bare `[...][-1]` lookup would raise, but
    # with a message that explains what went wrong.
    raise IndexError(
        "no valid I/O shape between %s and %s keeps at least %s feature maps"
        % (shape_min, shape_max, netconf.fmap_start))
index = candidates[-1]
print("Index to use: %s" % index)

# Some patching to allow our new parameters
netconf.input_shape = inshape[index]
netconf.output_shape = outshape[index]
# Workaround to allow any size (train net unusably big)
netconf.mem_global_limit = 200 * 1024 * 1024 * 1024
netconf.mem_buf_limit = 200 * 1024 * 1024 * 1024
# Define some loss function (irrelevant for testing though)
netconf.loss_function = "euclid"
# Generate the network and store it; only the test net is written out.
train_net_big_conf, test_net_big_conf = pygt.netgen.create_nets(netconf)
with open('net_test_big.prototxt', 'w') as f:
    print(test_net_big_conf, file=f)