-
Notifications
You must be signed in to change notification settings - Fork 1
/
keras_to_tensorflow.py
129 lines (82 loc) · 3.59 KB
/
keras_to_tensorflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Taken from https://github.com/amir-abdi/keras_to_tensorflow
# coding: utf-8
# # Set parameters
# In[ ]:
"""
Copyright (c) 2017, by the Authors: Amir H. Abdi
This software is freely available under the MIT Public License.
Please see the License file in the root for details.
The following code snippet will convert the keras model file,
which is saved using model.save('kerasmodel_weight_file'),
to the freezed .pb tensorflow weight file which holds both the
network architecture and its associated weights.
""";
# In[ ]:
# setting input arguments
import argparse
parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store",
dest='input_fld', type=str, default='.')
parser.add_argument('-output_fld', action="store",
dest='output_fld', type=str, default='.')
parser.add_argument('-input_model_file', action="store",
dest='input_model_file', type=str, default='model.h5')
parser.add_argument('-output_model_file', action="store",
dest='output_model_file', type=str, default='model.pb')
parser.add_argument('-output_graphdef_file', action="store",
dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store",
dest='num_outputs', type=int, default=1)
parser.add_argument('-graph_def', action="store",
dest='graph_def', type=bool, default=False)
parser.add_argument('-output_node_prefix', action="store",
dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-f')
args = parser.parse_args()
print('input args: ', args)
# In[ ]:
# uncomment the following lines to alter the default values set above
# args.input_fld = '.'
# args.output_fld = '.'
# args.input_model_file = 'model.h5'
# args.output_model_file = 'model.pb'
# num_output: this value has nothing to do with the number of classes, batch_size, etc.,
# and it is mostly equal to 1.
# If you have a multi-stream network (forked network with multiple outputs),
# set the value to the number of outputs.
num_output = args.num_outputs
# # initialize
# In[ ]:
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
output_fld = args.output_fld
if not os.path.isdir(output_fld):
os.mkdir(output_fld)
weight_file_path = osp.join(args.input_fld, args.input_model_file)
# # Load keras model and rename output
# In[ ]:
K.set_learning_phase(0)
net_model = load_model(weight_file_path)
pred = [None]*num_output
pred_node_names = [None]*num_output
for i in range(num_output):
pred_node_names[i] = args.output_node_prefix+str(i)
pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i])
print('output nodes names are: ', pred_node_names)
# #### [optional] write graph definition in ascii
# In[ ]:
sess = K.get_session()
if args.graph_def:
f = args.output_graphdef_file
tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
print('saved the graph definition in ascii format at: ', osp.join(output_fld, f))
# #### convert variables to constants and save
# In[ ]:
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the freezed graph (ready for inference) at: ', osp.join(output_fld, args.output_model_file))