-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
305 lines (256 loc) · 13.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import torch as t
from torch import nn
from resnet import resnet18
class FC(nn.Module):
    """Per-correspondence linear layer with optional BatchNorm1d and ReLU."""

    def __init__(self, in_features, out_features, is_relu, is_bn, num_of_correspondence):
        """
        :param in_features: number of features of the input data
        :param out_features: number of features of the output data
        :param is_relu: append a ReLU when True
        :param is_bn: append BatchNorm1d (and drop the linear bias) when True
        :param num_of_correspondence: correspondences per set; used as the
            BatchNorm1d channel count, i.e. normalization runs over dim 1 of (N, C, F)
        """
        super(FC, self).__init__()
        # BatchNorm supplies its own affine shift, so the linear bias is redundant then.
        self.block = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=out_features, bias=not is_bn)
        )
        # Optional layers keep their original registration names ("bn", "relu")
        # so state_dict keys stay stable.
        if is_bn:
            self.block.add_module("bn", nn.BatchNorm1d(num_features=num_of_correspondence))
        if is_relu:
            self.block.add_module("relu", nn.ReLU())

    def forward(self, x):
        """
        :param x: (N, C, F) — N correspondence sets, C correspondences each,
            F input features (6 when x is the raw paired point cloud)
        :return: (N, C, out_features)
        """
        return self.block(x)
class Res(nn.Module):
    """Chain of resnet trunks; forward returns the output of every stage."""

    def __init__(self, resblock_count):
        """
        :param resblock_count: number of resnet blocks to chain
        """
        super(Res, self).__init__()
        self.res_blocks = nn.Sequential()
        for idx in range(resblock_count):
            # resnet18 with its last two children (presumably the pooling/fc head)
            # stripped off, leaving only the convolutional trunk.
            trunk = nn.Sequential(*list(resnet18().children())[:-2])
            self.res_blocks.add_module("res_%d" % (idx,), trunk)

    def forward(self, x):
        """
        :param x: (N, C, F) — N correspondence sets, C correspondences each,
            F input features (6 when x is the raw paired point cloud)
        :return: list of (N, C, F) tensors, one per resnet stage
        """
        stage_outputs = []  # each item is (N, C, F)
        # Add a singleton channel axis so the 2D convs see (N, 1, C, F);
        # assumes each trunk preserves that shape — the squeeze(1) below relies on it.
        x = x.unsqueeze(1)
        for _, trunk in self.res_blocks._modules.items():
            x = trunk(x)
            stage_outputs.append(x.squeeze(1))
        return stage_outputs
# class Res(nn.Module):
#
# def __init__(self, resblock_count):
# """
#
# :param resblock_count: number of resnet block
# """
# super(Res, self).__init__()
# self.res_blocks = nn.Sequential()
# for i in range(resblock_count):
# self.res_blocks.add_module("res_block_%d" % (i,), nn.Sequential(
# nn.Linear(in_features=128, out_features=128),
# nn.ReLU(),
# nn.Linear(in_features=128, out_features=128),
# nn.ReLU()
# ))
#
# def forward(self, x):
# """
#
# :param x: shape like (N, C, F), N represents the number of point correspondence set or Batch Size,
# C represents the number of point correspondence of one point correspondence set, F represents the in_features,
# in_features is 6 if x is original point cloud
# :return: shape like (N, C, out_features)
# """
# res_results = [] # shape of item of the list is [N, C, F]
# x = x.unsqueeze(1) # (N, 1, C, F)
# for n, m in self.res_blocks._modules.items():
# x = m(x)
# res_results.append(x.squeeze(1))
# return res_results
class CLSNet(nn.Module):
    """Classification branch: scores each point correspondence as an inlier weight."""

    def __init__(self, res_block_count, num_of_correspondence):
        """
        :param res_block_count: number of resnet blocks
        :param num_of_correspondence: correspondences per point set
        """
        super(CLSNet, self).__init__()
        self.fc1 = FC(in_features=6, out_features=128, is_relu=True, is_bn=False,
                      num_of_correspondence=num_of_correspondence)
        self.res = Res(res_block_count)
        self.fc2 = FC(in_features=128, out_features=1, is_relu=False, is_bn=False,
                      num_of_correspondence=num_of_correspondence)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """
        :param x: (N, num_of_correspondence, 6) paired point sets
        :return: (weights (N, C), per-stage feature list, raw scores (N, C))
        """
        embedded = self.fc1(x)
        stage_outputs = self.res(embedded)
        scores = self.fc2(stage_outputs[-1]).squeeze(2)  # (N, C)
        # tanh(relu(.)) squashes the scores into [0, 1) — soft inlier weights.
        weights = self.tanh(self.relu(scores))  # (N, C)
        features = [embedded] + stage_outputs
        return weights, features, scores
class ContextBN(nn.Module):
    """Context normalization: standardize along dim 1 (zero mean, unit std)."""

    def __init__(self):
        super(ContextBN, self).__init__()

    def forward(self, x):
        """
        :param x: tensor whose dim 1 is the axis to normalize over
        :return: x standardized along dim 1; 1e-10 guards against zero std
        """
        mean = t.mean(x, dim=1, keepdim=True)
        std = t.std(x, dim=1, keepdim=True)
        return (x - mean) / (std + 1e-10)
class RegNet(nn.Module):
    """Regression branch: pools stage features and predicts M rotation + 3 translation params."""

    def __init__(self, M, res_block_count):
        super(RegNet, self).__init__()
        self.context_bn = ContextBN()
        # Stride 2 on the stage axis: (res_block_count + 1) stages shrink to
        # res_block_count // 2 + 1 rows, matching linear1's input size.
        self.conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3,
                              stride=(2, 1), padding=(1, 1))
        self.linear1 = nn.Sequential(
            nn.Linear(in_features=8 * (res_block_count // 2 + 1) * 128, out_features=256),
            nn.ReLU()
        )
        self.linear2 = nn.Linear(in_features=256, out_features=M + 3)

    def forward(self, cls_features):
        """
        :param cls_features: list of (N, C, F) feature tensors from the classification branch
        :return: (N, M + 3) regression output
        """
        # Max-pool each stage over the correspondence axis, then context-normalize: (N, F).
        pooled = [self.context_bn(t.max(f, dim=1).values) for f in cls_features]
        stacked = t.cat(pooled, dim=1)
        batch = stacked.size()[0]
        stacked = stacked.view((batch, len(pooled), -1))  # (N, stages, F)
        stacked = stacked.unsqueeze(1)                    # (N, 1, stages, F)
        conv_out = self.conv(stacked).view((batch, -1))   # flatten (N, 8, stages // 2 + 1, F)
        return self.linear2(self.linear1(conv_out))
# class RegNet(nn.Module):
#
# def __init__(self, M, res_block_count):
# super(RegNet, self).__init__()
# self.context_bn = ContextBN()
# self.linear1 = nn.Sequential(
# nn.Linear(in_features=(res_block_count + 1) * 128, out_features=1024),
# nn.ReLU(),
# nn.Linear(in_features=1024, out_features=512),
# nn.ReLU(),
# nn.Linear(in_features=512, out_features=256),
# nn.ReLU(),
# nn.Linear(in_features=256, out_features=M + 3)
# )
#
# def forward(self, cls_features):
# pool_results = []
# for cls_feature in cls_features:
# max_pool_result = t.max(cls_feature, dim=1).values # shape like: (N, F)
# context_bn_result = self.context_bn(max_pool_result)
# pool_results.append(context_bn_result)
# concate_results = t.cat(pool_results, dim=1)
# # concate_results = concate_results.view((concate_results.size()[0], len(pool_results), -1)) # shape like: (N, res_block_count + 1, F)
# # concate_results = concate_results.unsqueeze(1) # (N, 1, res_block_count + 1, F)
# # conv_result = self.conv(concate_results).view((concate_results.size()[0], -1)) # (N, 8, res_block_count // 2 + 1, F)
# # linear1_result = self.linear1(conv_result)
# # reg_result = self.linear2(linear1_result)
# reg_result = self.linear1(concate_results)
# return reg_result
class ThreeDRegNet(nn.Module):
    """One registration stage: classification weights plus pose regression."""

    def __init__(self, res_block_count, num_of_correspondence, M):
        """
        :param res_block_count: resnet block count of the classification branch
        :param num_of_correspondence: correspondences per point set
        :param M: number of rotation parameters (9 for a matrix, 3 for lie algebra)
        """
        super(ThreeDRegNet, self).__init__()
        self.cls = CLSNet(res_block_count, num_of_correspondence)
        self.reg = RegNet(M, res_block_count)

    def forward(self, x):
        """
        :param x: (N, num_of_correspondence, 6); x[:, :, :3] is the point set being registered
        :return: (cls_out (N, C), reg_out (N, M + 3), raw classification scores (N, C))
        """
        cls_out, cls_features, use_for_cls_loss = self.cls(x)
        reg_out = self.reg(cls_features)
        return cls_out, reg_out, use_for_cls_loss
class RefineNet(nn.Module):
    """Cascade of ThreeDRegNet stages; each stage refines the previous alignment."""

    def __init__(self, threeDRegNet_count, res_block_counts, num_of_correspondence, M, use_lie):
        """
        :param threeDRegNet_count: number of cascaded ThreeDRegNet stages
        :param res_block_counts: list of int, res-block count for each stage
        :param num_of_correspondence: correspondences per point set
        :param M: rotation parameter count (9 = rotation matrix, 3 = lie algebra)
        :param use_lie: True to predict lie-algebra parameters, False for a raw matrix
        """
        super(RefineNet, self).__init__()
        assert len(res_block_counts) == threeDRegNet_count, "number of item of res_block_counts should equal to threeDRegNet_count"
        self.M = M
        self.use_lie = use_lie
        self.block = nn.Sequential()
        for i in range(threeDRegNet_count):
            self.block.add_module("regnet_%d" % (i,),
                                  ThreeDRegNet(res_block_counts[i], num_of_correspondence, M))

    def forward(self, x):
        """
        :param x: (N, num_of_correspondence, 6); [:, :, :3] source points, [:, :, 3:6] template points
        :return: six lists, one entry per stage: rotation (N, 3, 3), translation (N, 3),
            cls weights (N, C), reg outputs (N, M + 3), raw cls scores (N, C),
            transformed points (N, C, 3)
        """
        cls_outs, reg_outs, use_for_cls_losses = [], [], []
        points_preds, rotation_mats, trans_mats = [], [], []
        half = x.size()[2] // 2
        points_pred = x[:, :, :half]  # current source points (N, C, 3)
        dest = x[:, :, half:]         # fixed registration target (N, C, 3)
        for _, stage in self.block._modules.items():
            cls_out, reg_out, use_for_cls_loss = stage(x)
            points_pred, rotation_mat, trans_mat = registration(reg_out, points_pred, self.M, self.use_lie)
            points_preds.append(points_pred)
            # Pair the freshly transformed points with the target for the next stage.
            x = t.cat([points_pred, dest], dim=2)
            cls_outs.append(cls_out)
            reg_outs.append(reg_out)
            use_for_cls_losses.append(use_for_cls_loss)
            rotation_mats.append(rotation_mat)
            trans_mats.append(trans_mat)
        return rotation_mats, trans_mats, cls_outs, reg_outs, use_for_cls_losses, points_preds
def registration(reg_out, point_set, M, use_lie):
    """
    Apply the regressed rigid transform to a point set.

    :param reg_out: (N, M + 3); first M entries are rotation parameters, last 3 the translation
    :param point_set: (N, num_of_correspondence, 3) points to transform
    :param M: number of rotation parameters (9 = flattened 3x3 matrix, 3 = lie algebra)
    :param use_lie: True to decode reg_out[:, :3] through the exponential map,
        False to reshape reg_out[:, :9] into a rotation matrix directly
    :return: (transformed points (N, num_of_correspondence, 3),
              rotation (N, 3, 3), translation (N, 3))
    """
    if use_lie:
        assert M == 3, "M should be 3 when use_lie is True"
        rotation_mat = lie_to_rot_mat(reg_out, M)  # (N, 3, 3)
    else:
        assert M == 9, "M should be 9"
        batch = reg_out.size()[0]
        rotation_mat = reg_out[:, :M].view((batch, point_set.size()[2], -1))  # (N, 3, 3)
    trans_mat = reg_out[:, M:]  # (N, 3)
    # Rotate: R @ P^T, back to (N, C, 3), then broadcast-add the translation.
    rotated = t.bmm(rotation_mat, point_set.permute(dims=[0, 2, 1])).permute(dims=[0, 2, 1])
    return rotated + trans_mat.unsqueeze(1), rotation_mat, trans_mat
def lie_to_rot_mat(reg_out, M):
    """
    Convert axis-angle (lie algebra so(3)) parameters to rotation matrices.

    Uses Rodrigues' formula: R = cos(theta) * I + (1 - cos(theta)) * u u^T + sin(theta) * [u]_x,
    where theta is the norm of the parameter vector and u its unit direction.

    Fix: the original divided by the raw norm, so a zero parameter vector produced
    NaN matrices; the norm used for the axis is now clamped, and a zero vector
    correctly yields the identity (the theta -> 0 limit of the formula).

    :param reg_out: (N, M + 3) regression output; reg_out[:, :M] are the axis-angle params
    :param M: number of rotation parameters (3 for axis-angle)
    :return: (N, 3, 3) rotation matrices
    """
    rotation_param = reg_out[:, :M]  # (N, 3)
    theta = t.norm(rotation_param, dim=1, keepdim=True)  # (N, 1)
    # Clamp only the divisor: a zero vector gives u = 0, and with cos(0) = 1,
    # sin(0) = 0 the formula collapses to the identity instead of NaN.
    unit = (rotation_param / theta.clamp(min=1e-8)).unsqueeze(-1)  # (N, 3, 1)
    u = unit.view((unit.size()[0], 3))
    ux, uy, uz = u[:, 0], u[:, 1], u[:, 2]
    zero = t.zeros_like(ux)
    # Skew-symmetric cross-product matrix [u]_x, built explicitly row by row.
    unit_hat = t.stack([zero, -uz, uy,
                        uz, zero, -ux,
                        -uy, ux, zero], dim=1).view((u.size()[0], 3, 3))  # (N, 3, 3)
    eye = t.eye(3, dtype=rotation_param.dtype, device=rotation_param.device).unsqueeze(0)
    cos_t = t.cos(theta).unsqueeze(-1)  # (N, 1, 1), broadcasts over the 3x3 terms
    sin_t = t.sin(theta).unsqueeze(-1)
    R = cos_t * eye + (1 - cos_t) * t.bmm(unit, t.transpose(unit, 1, 2)) + sin_t * unit_hat
    return R
if __name__ == "__main__":
    # Smoke test: 2 sets of 512 correspondences, 3 cascaded stages,
    # lie-algebra rotation parameterization (M = 3).
    sample = t.randn(2, 512, 6)
    net = RefineNet(3, [4, 5, 6], 512, 3, True)
    rotation_mats, trans_mats, cls_outs, reg_outs, use_for_cls_losses, points_preds = net(sample)