-
Notifications
You must be signed in to change notification settings - Fork 2
/
single_stage.py
201 lines (164 loc) · 7.39 KB
/
single_stage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import torch
import torch.nn as nn
from mmdet.core import bbox2result
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector
@DETECTORS.register_module()
class SingleStageDetector(BaseDetector):
    """Base class for single-stage detectors.

    Single-stage detectors directly and densely predict bounding boxes on the
    output features of the backbone+neck.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build the backbone, optional neck and bbox head from configs.

        Args:
            backbone (dict): Config passed to ``build_backbone``.
            neck (dict, optional): Config passed to ``build_neck``. When
                None, no ``neck`` attribute is created here.
            bbox_head (dict): Config passed to ``build_head``. Although the
                default is None, a config is required because it is updated
                unconditionally. NOTE: the passed object is modified in
                place (``train_cfg`` / ``test_cfg`` keys are written to it).
            train_cfg: Training config, forwarded to the head and stored on
                the detector.
            test_cfg: Testing config, forwarded to the head and stored on
                the detector.
            pretrained (str, optional): Path to pre-trained weights, passed
                to :meth:`init_weights`.
        """
        super().__init__()
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        # The head receives the train/test settings through its own config.
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        """Initialize the weights in detector.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        super().init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                # A sequential neck is initialized module by module.
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        self.bbox_head.init_weights()

    def extract_feat(self, img):
        """Directly extract features from the backbone+neck."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/get_flops.py`
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Run the training forward pass and return the head's losses.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img)
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)
        return losses

    @staticmethod
    def _flatten_level_outputs(level_outputs, last_dim):
        """Flatten per-level head outputs into one ``(-1, last_dim)`` tensor.

        Args:
            level_outputs (Iterable[Tensor]): One 4-D tensor per feature
                level (presumably laid out as (N, C, H, W) -- the permute
                below moves axis 1 to the end).
            last_dim (int): Size of the trailing dimension of the result.

        Returns:
            Tensor: All levels concatenated along dim 1 after per-level
            flattening, then reshaped to ``(-1, last_dim)``.
        """
        # The axes must be permuted so that axis 1 becomes the innermost one
        # before flattening; ``contiguous`` is needed for ``view`` to work on
        # the permuted tensor.
        flattened = [
            out.permute(0, 2, 3, 1).contiguous().view(out.size(0), -1)
            for out in level_outputs
        ]
        return torch.cat(flattened, 1).view(-1, last_dim)

    def forward_trace_fcos(self, img):
        """Forward pass returning flattened raw head outputs (3 outputs).

        Uses the first three outputs of the bbox head, treated as per-level
        lists of class scores, box predictions and centerness maps.

        Args:
            img (Tensor): Input passed to :meth:`extract_feat`.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(cls_score, bbox_pred,
            centerness)`` reshaped to ``(-1, bbox_head.cls_out_channels)``,
            ``(-1, 4)`` and ``(-1, 1)`` respectively.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        cls_score = self._flatten_level_outputs(
            outs[0], self.bbox_head.cls_out_channels)
        bbox_pred = self._flatten_level_outputs(outs[1], 4)
        centerness = self._flatten_level_outputs(outs[2], 1)
        return cls_score, bbox_pred, centerness

    def forward_trace_ssd_retinanet(self, img):
        """Forward pass returning flattened raw head outputs (2 outputs).

        Uses the first two outputs of the bbox head, treated as per-level
        lists of class scores and box predictions.

        Args:
            img (Tensor): Input passed to :meth:`extract_feat`.

        Returns:
            tuple[Tensor, Tensor]: ``(cls_score, bbox_pred)`` reshaped to
            ``(-1, bbox_head.cls_out_channels)`` and ``(-1, 4)``.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        cls_score = self._flatten_level_outputs(
            outs[0], self.bbox_head.cls_out_channels)
        bbox_pred = self._flatten_level_outputs(outs[1], 4)
        return cls_score, bbox_pred

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Input images passed to :meth:`extract_feat`.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class. During ONNX export the raw
                output of ``bbox_head.get_bboxes`` is returned instead.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
        # skip post-processing when exporting to ONNX
        if torch.onnx.is_in_onnx_export():
            return bbox_list

        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            imgs (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        assert hasattr(self.bbox_head, 'aug_test'), \
            f'{self.bbox_head.__class__.__name__}' \
            ' does not support test-time augmentation'

        feats = self.extract_feats(imgs)
        # The head's result is wrapped in a single-element list.
        return [self.bbox_head.aug_test(feats, img_metas, rescale=rescale)]