-
Notifications
You must be signed in to change notification settings - Fork 112
/
siamese_model.py
220 lines (177 loc) · 8.7 KB
/
siamese_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright © 2017 bily Huazhong University of Science and Technology
#
# Distributed under terms of the MIT license.
"""Construct the computational graph of siamese model for training. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
from datasets.dataloader import DataLoader
from embeddings.convolutional_alexnet import convolutional_alexnet_arg_scope, convolutional_alexnet
from metrics.track_metrics import center_dist_error, center_score_error
from utils.train_utils import construct_gt_score_maps, load_mat_model
# Short module-level alias for the TF-Slim package exposed under tf.contrib;
# used below for `slim.arg_scope`.
slim = tf.contrib.slim
class SiameseModel:
    """Builds the TensorFlow graph of a siamese model.

    The model can be built in one of three modes: 'train', 'validation' or
    'inference'. Call `build()` to create all ops; the resulting tensors are
    exposed as attributes (all of them are None until the corresponding
    build step has run).

    Attributes:
      model_config: configuration passed to the constructor. This class reads
        the keys 'embed_config' and 'adjust_response_config' from it.
      train_config: configuration passed to the constructor. This class reads
        the keys 'train_data_config', 'validation_data_config' and
        'gt_config' from it.
      mode: one of 'train', 'validation', 'inference'.
      data_config: data configuration selected by `mode`; None in inference.
      dataloader: DataLoader instance (train / validation only).
      examplar_feed, instance_feed: uint8 placeholders (inference only).
      exemplars, instances: float image batches fed to the embedding network.
      exemplar_embeds, instance_embeds: outputs of the embedding network.
      templates: template features matched against `instance_embeds`.
      response: score maps produced by `build_detection`.
      batch_loss, total_loss: loss tensors produced by `build_loss`.
      init_fn: optional callable `init_fn(sess)` restoring embedding weights.
      global_step: global step variable (train only).
    """

    def __init__(self, model_config, train_config, mode='train'):
        """Stores the configuration and declares all graph handles.

        Args:
          model_config: model configuration (indexed by string keys).
          train_config: training configuration (indexed by string keys).
          mode: 'train', 'validation' or 'inference'.

        Raises:
          AssertionError: if `mode` is not one of the supported values.
        """
        # Validate before touching any state so a bad mode never leaves a
        # half-initialised object behind.
        assert mode in ['train', 'validation', 'inference'], \
            'Unknown mode: %r' % (mode,)

        self.model_config = model_config
        self.train_config = train_config
        self.mode = mode

        if self.mode == 'train':
            self.data_config = self.train_config['train_data_config']
        elif self.mode == 'validation':
            self.data_config = self.train_config['validation_data_config']
        else:
            # Inference feeds images through placeholders and therefore has no
            # data configuration; define the attribute anyway so that reading
            # it does not raise AttributeError.
            self.data_config = None

        # Handles populated by the build_* / setup_* methods.
        self.dataloader = None
        self.examplar_feed = None
        self.instance_feed = None
        self.exemplars = None
        self.instances = None
        self.exemplar_embeds = None
        self.instance_embeds = None
        self.templates = None
        self.response = None
        self.batch_loss = None
        self.total_loss = None
        self.init_fn = None
        self.global_step = None

    def is_training(self):
        """Returns true if the model is built for training mode"""
        return self.mode == 'train'

    def build_inputs(self):
        """Input fetching and batching

        In 'train' / 'validation' mode the batches come from a DataLoader;
        in 'inference' mode they come from two uint8 placeholders
        (`self.examplar_feed`, `self.instance_feed`).

        Outputs:
          self.exemplars: image batch of shape [batch, hz, wz, 3]
          self.instances: image batch of shape [batch, hx, wx, 3]
        """
        if self.mode in ['train', 'validation']:
            # Put data loading and preprocessing in CPU is substantially faster
            with tf.device("/cpu:0"):
                self.dataloader = DataLoader(self.data_config, self.is_training())
                self.dataloader.build()
                exemplars, instances = self.dataloader.get_one_batch()

                exemplars = tf.to_float(exemplars)
                instances = tf.to_float(instances)
        else:
            self.examplar_feed = tf.placeholder(shape=[None, None, None, 3],
                                                dtype=tf.uint8,
                                                name='examplar_input')
            self.instance_feed = tf.placeholder(shape=[None, None, None, 3],
                                                dtype=tf.uint8,
                                                name='instance_input')
            exemplars = tf.to_float(self.examplar_feed)
            instances = tf.to_float(self.instance_feed)

        self.exemplars = exemplars
        self.instances = instances

    def build_image_embeddings(self, reuse=False):
        """Builds the image model subgraph and generates image embeddings

        Both branches go through the same `convolutional_alexnet` function;
        the instance branch is always built with reuse=True so that it shares
        the variables created (or reused) by the exemplar branch.

        Args:
          reuse: forwarded to the exemplar branch of the embedding network.

        Inputs:
          self.exemplars: A tensor of shape [batch, hz, wz, 3]
          self.instances: A tensor of shape [batch, hx, wx, 3]

        Outputs:
          self.exemplar_embeds: A Tensor of shape [batch, hz_embed, wz_embed, embed_dim]
          self.instance_embeds: A Tensor of shape [batch, hx_embed, wx_embed, embed_dim]
        """
        config = self.model_config['embed_config']
        arg_scope = convolutional_alexnet_arg_scope(config,
                                                    trainable=config['train_embedding'],
                                                    is_training=self.is_training())

        @functools.wraps(convolutional_alexnet)
        def embedding_fn(images, reuse=False):
            with slim.arg_scope(arg_scope):
                return convolutional_alexnet(images, reuse=reuse)

        # Only the first element of the returned pair is used here.
        self.exemplar_embeds, _ = embedding_fn(self.exemplars, reuse=reuse)
        self.instance_embeds, _ = embedding_fn(self.instances, reuse=True)

    def build_template(self):
        """Sets `self.templates` from the exemplar embeddings."""
        # The template is simply the feature of the exemplar image in SiamFC.
        self.templates = self.exemplar_embeds

    def build_detection(self, reuse=False):
        """Matches every instance embedding against its template.

        For each example in the batch, the template is used as the filter of
        a VALID-padded, stride-1 `tf.nn.conv2d` over the instance embedding.
        The raw scores are then scaled by `config['scale']` and shifted by a
        scalar bias variable named 'biases' in the 'detection' variable scope.

        Args:
          reuse: forwarded to `tf.variable_scope('detection', ...)`.

        Outputs:
          self.response: score maps, one per example in the batch.
        """
        with tf.variable_scope('detection', reuse=reuse):
            def _translation_match(x, z):  # translation match for one example within a batch
                x = tf.expand_dims(x, 0)  # [1, in_height, in_width, in_channels]
                z = tf.expand_dims(z, -1)  # [filter_height, filter_width, in_channels, 1]
                return tf.nn.conv2d(x, z, strides=[1, 1, 1, 1], padding='VALID',
                                    name='translation_match')

            # map_fn iterates over the batch dimension, pairing each instance
            # embedding with the template at the same index.
            output = tf.map_fn(lambda x: _translation_match(x[0], x[1]),
                               (self.instance_embeds, self.templates),
                               dtype=self.instance_embeds.dtype)
            # Drop the singleton batch / channel dims added per example.
            output = tf.squeeze(output, [1, 4])  # of shape e.g., [8, 15, 15]

            # Adjust score, this is required to make training possible.
            config = self.model_config['adjust_response_config']
            bias = tf.get_variable('biases', [1],
                                   dtype=tf.float32,
                                   initializer=tf.constant_initializer(0.0, dtype=tf.float32),
                                   trainable=config['train_bias'])
            response = config['scale'] * output + bias
            self.response = response

    def build_loss(self):
        """Builds the class-balanced logistic loss and its summaries.

        Reads `self.response`, `self.exemplars`, `self.instances` and
        `self.data_config['batch_size']`, so it must run after
        `build_detection` and only in 'train' / 'validation' mode.

        Outputs:
          self.batch_loss: mean over the batch of the per-example loss.
          self.total_loss: result of `tf.losses.get_total_loss()` after
            `batch_loss` has been added to the losses collection.
        """
        response = self.response
        response_size = response.get_shape().as_list()[1:3]  # [height, width]

        gt = construct_gt_score_maps(response_size,
                                     self.data_config['batch_size'],
                                     self.model_config['embed_config']['stride'],
                                     self.train_config['gt_config'])

        with tf.name_scope('Loss'):
            loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=response,
                                                           labels=gt)

            with tf.name_scope('Balance_weights'):
                # Positive / negative counts are taken from the first ground
                # truth map only and then applied to the whole batch.
                n_pos = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 1)))
                n_neg = tf.reduce_sum(tf.to_float(tf.equal(gt[0], 0)))
                w_pos = 0.5 / n_pos
                w_neg = 0.5 / n_neg
                # Labels equal to 1 get w_pos, labels equal to 0 get w_neg and
                # any other label value keeps a weight of 1.
                class_weights = tf.where(tf.equal(gt, 1),
                                         w_pos * tf.ones_like(gt),
                                         tf.ones_like(gt))
                class_weights = tf.where(tf.equal(gt, 0),
                                         w_neg * tf.ones_like(gt),
                                         class_weights)
                loss = loss * class_weights

            # Note that we use reduce_sum instead of reduce_mean since the loss has
            # already been normalized by class_weights in spatial dimension.
            loss = tf.reduce_sum(loss, [1, 2])

            batch_loss = tf.reduce_mean(loss, name='batch_loss')
            tf.losses.add_loss(batch_loss)

            total_loss = tf.losses.get_total_loss()
            self.batch_loss = batch_loss
            self.total_loss = total_loss

            tf.summary.image('exemplar', self.exemplars, family=self.mode)
            tf.summary.image('instance', self.instances, family=self.mode)

            mean_batch_loss, update_op1 = tf.metrics.mean(batch_loss)
            mean_total_loss, update_op2 = tf.metrics.mean(total_loss)
            # Create the scalar summaries under a control dependency on the
            # metric update ops so the running means are refreshed whenever
            # the summaries are evaluated.
            with tf.control_dependencies([update_op1, update_op2]):
                tf.summary.scalar('batch_loss', mean_batch_loss, family=self.mode)
                tf.summary.scalar('total_loss', mean_total_loss, family=self.mode)

            if self.mode == 'train':
                tf.summary.image('GT', tf.reshape(gt[0], [1] + response_size + [1]),
                                 family='GT')
            tf.summary.image('Response', tf.expand_dims(tf.sigmoid(response), -1),
                             family=self.mode)
            tf.summary.histogram('Response', self.response, family=self.mode)

            # Two more metrics to monitor the performance of training
            tf.summary.scalar('center_score_error', center_score_error(response),
                              family=self.mode)
            tf.summary.scalar('center_dist_error', center_dist_error(response),
                              family=self.mode)

    def setup_global_step(self):
        """Creates the non-trainable 'global_step' variable.

        The variable starts at 0 and is registered in both the GLOBAL_STEP
        and GLOBAL_VARIABLES collections.
        """
        global_step = tf.Variable(
            initial_value=0,
            name='global_step',
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])

        self.global_step = global_step

    def setup_embedding_initializer(self):
        """Sets up the function to restore embedding variables from checkpoint.

        When `embed_config['embedding_checkpoint_file']` is truthy,
        `self.init_fn` is set to a callable taking a session; otherwise
        `self.init_fn` is left unchanged.
        """
        embed_config = self.model_config['embed_config']
        if embed_config['embedding_checkpoint_file']:
            # Restore Siamese FC models from .mat model files
            initialize = load_mat_model(embed_config['embedding_checkpoint_file'],
                                        'convolutional_alexnet/', 'detection/')

            def restore_fn(sess):
                tf.logging.info("Restoring embedding variables from checkpoint file %s",
                                embed_config['embedding_checkpoint_file'])
                sess.run([initialize])

            self.init_fn = restore_fn

    def build(self, reuse=False):
        """Creates all ops for training and evaluation

        Everything is created inside a name scope named after `self.mode`.
        The loss is only built in 'train' / 'validation' mode and the global
        step only in 'train' mode.

        Args:
          reuse: forwarded to `build_image_embeddings` and `build_detection`.
        """
        with tf.name_scope(self.mode):
            self.build_inputs()
            self.build_image_embeddings(reuse=reuse)
            self.build_template()
            self.build_detection(reuse=reuse)
            self.setup_embedding_initializer()

            if self.mode in ['train', 'validation']:
                self.build_loss()

            if self.is_training():
                self.setup_global_step()