diff --git a/configs/co_detr/README.md b/configs/co_detr/README.md
new file mode 100644
index 0000000000..2dc370f3ae
--- /dev/null
+++ b/configs/co_detr/README.md
@@ -0,0 +1,39 @@
+# DETR
+
+## Introduction
+
+
+DETR is a transformer-based object detection model. We reproduce the model described in the paper.
+
+
+## Model Zoo
+
+| Backbone | Model | Images/GPU | Inf time (fps) | Box AP | Config | Download |
+|:------:|:--------:|:--------:|:--------------:|:------:|:------:|:--------:|
+| R-50 | DETR | 4 | --- | 42.3 | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/detr/detr_r50_1x_coco.yml) | [model](https://paddledet.bj.bcebos.com/models/detr_r50_1x_coco.pdparams) |
+
+**Notes:**
+
+- DETR is trained on the COCO train2017 dataset and evaluated on val2017, reporting `mAP(IoU=0.5:0.95)`.
+- DETR is trained for 500 epochs on 8 GPUs.
+
+Multi-GPU training:
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/detr/detr_r50_1x_coco.yml --fleet
+```
+
+## Citations
+```
+@inproceedings{detr,
+  author    = {Nicolas Carion and
+               Francisco Massa and
+               Gabriel Synnaeve and
+               Nicolas Usunier and
+               Alexander Kirillov and
+               Sergey Zagoruyko},
+  title     = {End-to-End Object Detection with Transformers},
+  booktitle = {ECCV},
+  year      = {2020}
+}
+```
diff --git a/configs/co_detr/_base_/co_detr_r50.yml b/configs/co_detr/_base_/co_detr_r50.yml
new file mode 100644
index 0000000000..d4a3721e40
--- /dev/null
+++ b/configs/co_detr/_base_/co_detr_r50.yml
@@ -0,0 +1,187 @@
+architecture: CO_DETR
+# pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vb_normal_pretrained.pdparams
+pretrain_weights: /home/aistudio/co_deformable_detr_r50_1x_coco.pdparams
+
+# model settings
+num_dec_layer: &num_dec_layer 6
+lambda_2: &lambda_2 2.0
+
+CO_DETR:
+  backbone: ResNet
+  neck: ChannelMapper
+  query_head: CoDeformDETRHead
+  rpn_head: RPNHead
+  roi_head: Co_RoiHead
+  bbox_head:
+    name: CoATSSHead
+    num_classes: 80
+    in_channels: 256
+    stacked_convs: 1
+    feat_channels: 256
+    anchor_generator:
+      name: CoAnchorGenerator
+      octave_base_scale: 8
+      scales_per_octave: 1
+      aspect_ratios: [1.0]
+      strides: [8, 16, 32, 64, 128]
+    assigner:
+      name: ATSSAssigner
+      topk: 9
+    loss_cls:
+      name: Weighted_FocalLoss
+      use_sigmoid: true
+      gamma: 2.0
+      alpha: 0.25
+    loss_bbox:
+      name: GIoULoss
+
+
+ResNet:
+  # index 0 stands for res2
+  depth: 50
+  norm_type: bn
+  freeze_at: 0
+  return_idx: [1,2,3]
+  lr_mult_list: [0.0, 0.1, 0.1, 0.1]
+  num_stages: 4
+
+ChannelMapper:
+  in_channels: [512, 1024, 2048]
+  kernel_size: 1
+  out_channels: 256
+  norm_type: "gn"
+  norm_groups: 32
+  act: None
+  num_outs: 4
+
+
+CoDeformDETRHead:
+  num_query: 300
+  num_classes: 80
+  in_channels: 2048
+  sync_cls_avg_factor: True
+  with_box_refine: True
+  as_two_stage: True
+  mixed_selection: True
+  transformer:
+    name: CoDeformableDetrTransformer
+    num_co_heads: 2
+    as_two_stage: True
+    mixed_selection: True
+    encoder:
+      name: CoTransformerEncoder
+      num_layers: *num_dec_layer
+      out_channel: 256
+      encoder_layer:
+        name: TransformerEncoderLayer
+        d_model: 256
+        attn:
+          name: MSDeformableAttention
+          embed_dim: 256
+          num_heads: 8
+          num_levels: 4
+          num_points: 4
+        dim_feedforward: 2048
+        dropout: 0.0
+    decoder:
+      name: CoDeformableDetrTransformerDecoder
+      num_layers: *num_dec_layer
+      return_intermediate: True
+      look_forward_twice: True
+      decoder_layer:
+        name: PETR_TransformerDecoderLayer
+        d_model: 256
+        dim_feedforward: 2048
+        dropout: 0.0
self_attn: + name: MultiHeadAttention + embed_dim: 256 + num_heads: 8 + dropout: 0.0 + cross_attn: + name: MSDeformableAttention + embed_dim: 256 + positional_encoding: + name: PositionEmbedding + num_pos_feats: 128 + normalize: true + offset: -0.5 + loss_cls: + name: Weighted_FocalLoss + use_sigmoid: true + gamma: 2.0 + alpha: 0.25 + loss_weight: 2.0 + loss_bbox: + name: L1Loss + loss_weight: 5.0 + loss_iou: + name: GIoULoss + loss_weight: 2.0 + assigner: + name: HungarianAssigner + cls_cost: + name: FocalLossCost + weight: 2.0 + reg_cost: + name: BBoxL1Cost + weight: 5.0 + box_format: xywh + iou_cost: + name: IoUCost + iou_mode: giou + weight: 2.0 + test_cfg: + max_per_img: 100 + score_thr: 0.0 + nms: false + nms: + name: MultiClassNMS + keep_top_k: 100 + score_threshold: 0.05 + nms_threshold: 0.6 + +RPNHead: + loss_rpn_bbox: L1Loss + in_channel: 256 + anchor_generator: + name: RetinaAnchorGenerator + octave_base_scale: 4 + scales_per_octave: 3 + aspect_ratios: [0.5, 1.0, 2.0] + strides: [8.0, 16.0, 32.0, 64.0, 128.0] + rpn_target_assign: + batch_size_per_im: 256 + fg_fraction: 0.5 + negative_overlap: 0.3 + positive_overlap: 0.7 + use_random: True + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 4000 + post_nms_top_n: 1000 + topk_after_collect: True + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 1000 + post_nms_top_n: 1000 + +Co_RoiHead: + in_channel: 256 + num_classes: 80 + head: TwoFCHead + roi_extractor: + resolution: 7 + sampling_ratio: 0 + aligned: True + bbox_assigner: + name: BBoxAssigner + batch_size_per_im: 512 + bg_thresh: 0.5 + fg_thresh: 0.5 + fg_fraction: 0.25 + use_random: True + bbox_loss: + name: GIoULoss diff --git a/configs/co_detr/_base_/co_detr_reader.yml b/configs/co_detr/_base_/co_detr_reader.yml new file mode 100644 index 0000000000..6f10ab454d --- /dev/null +++ b/configs/co_detr/_base_/co_detr_reader.yml @@ -0,0 +1,47 @@ +worker_num: 0 +TrainReader: + sample_transforms: + - Decode: {} + - RandomFlip: {prob: 0.5} + - RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ], + transforms2: [ + RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] }, + RandomSizeCrop: { min_size: 384, max_size: 600 }, + RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ] } + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + batch_transforms: + - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true} + batch_size: 2 + shuffle: false + drop_last: true + collate_batch: false + use_shared_memory: false + + +EvalReader: + sample_transforms: + - Decode: {} + # - PETR_Resize: {img_scale: [[800, 1333]], keep_ratio: True} + - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1} + - NormalizeImage: + mean: [0.485,0.456,0.406] + std: [0.229, 0.224,0.225] + is_scale: true + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false diff --git a/configs/co_detr/_base_/optimizer_1x.yml b/configs/co_detr/_base_/optimizer_1x.yml new file mode 100644 index 0000000000..13528c5eba --- /dev/null +++ 
b/configs/co_detr/_base_/optimizer_1x.yml @@ -0,0 +1,16 @@ +epoch: 500 + +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [400] + use_warmup: false + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + regularizer: false + optimizer: + type: AdamW + weight_decay: 0.0001 diff --git a/configs/co_detr/co_detr_r50_1x_coco.yml b/configs/co_detr/co_detr_r50_1x_coco.yml new file mode 100644 index 0000000000..be03708020 --- /dev/null +++ b/configs/co_detr/co_detr_r50_1x_coco.yml @@ -0,0 +1,9 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_1x.yml', + '_base_/co_detr_r50.yml', + '_base_/co_detr_reader.yml', +] +weights: /home/aistudio/co_deformable_detr_r50_1x_coco.pdparams +find_unused_parameters: True diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index d22df32d85..4803b97e42 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -45,6 +45,7 @@ from . import detr_ssod from . import multi_stream_detector from . import clrnet +from . import co_detr from .meta_arch import * from .faster_rcnn import * @@ -68,6 +69,7 @@ from .gfl import * from .picodet import * from .detr import * +from .co_detr import * from .sparse_rcnn import * from .tood import * from .retinanet import * diff --git a/ppdet/modeling/architectures/co_detr.py b/ppdet/modeling/architectures/co_detr.py new file mode 100644 index 0000000000..c007b73b4d --- /dev/null +++ b/ppdet/modeling/architectures/co_detr.py @@ -0,0 +1,218 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import numpy as np +from .meta_arch import BaseArch +from ppdet.core.workspace import register, create + +__all__ = ['CO_DETR'] +# Collaborative DETR, DINO use the same architecture as DETR + +def bbox2result(bboxes, labels, num_classes): + """Convert detection results to a list of numpy arrays. 
+ + Args: + bboxes (paddle.Tensor | np.ndarray): shape (n, 5) + labels (paddle.Tensor | np.ndarray): shape (n, ) + num_classes (int): class number, including background class + + Returns: + list(ndarray): bbox results of each class + """ + if bboxes.shape[0] == 0: + return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)] + else: + if isinstance(bboxes, paddle.Tensor): + bboxes = bboxes.detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + return [bboxes[labels == i, :] for i in range(num_classes)] + + +@register +class CO_DETR(BaseArch): + __category__ = 'architecture' + __inject__ = ['bbox_head'] + + def __init__(self, + backbone, + neck=None, + query_head=None, + rpn_head=None, + roi_head=None, + bbox_head=None, + with_pos_coord=True, + with_attn_mask=True, + ): + super(CO_DETR, self).__init__() + self.backbone = backbone + if neck is not None: + self.with_neck = True + self.neck = neck + self.query_head = query_head + self.rpn_head = rpn_head + self.roi_head = roi_head + self.bbox_head = bbox_head + self.deploy = False + self.with_pos_coord = with_pos_coord + self.with_attn_mask = with_attn_mask + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + backbone = create(cfg['backbone']) + kwargs = {'input_shape': backbone.out_shape} + neck = cfg['neck'] and create(cfg['neck'], **kwargs) + # out_shape = neck and neck.out_shape or backbone.out_shape + query_head = create(cfg['query_head']) + out_shape = query_head.transformer.encoder.out_shape + kwargs = {'input_shape': out_shape} + rpn_head = create(cfg['rpn_head'], **kwargs) + roi_head = create(cfg['roi_head'], **kwargs) + return { + 'backbone': backbone, + 'neck': neck, + 'query_head': query_head, + 'rpn_head': rpn_head, + 'roi_head':roi_head, + } + + def extract_feat(self, img): + """Directly extract features from the backbone+neck.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def get_inputs(self): + img_metas = [] + gt_bboxes = [] + gt_labels = [] + + for idx, im_shape in enumerate(self.inputs['im_shape']): + img_meta = { + 'img_shape': im_shape.astype("int32").tolist() + [1, ], + 'batch_input_shape': self.inputs['image'].shape[-2:], + 'pad_mask': self.inputs['pad_mask'][idx], + } + img_metas.append(img_meta) + + gt_labels.append(self.inputs['gt_class'][idx]) + gt_bboxes.append(self.inputs['gt_bbox'][idx]) + + return img_metas, gt_bboxes, gt_labels + + + def get_pred(self): + img = self.inputs['image'] + batch_size, _, height, width = img.shape + img_metas = [ + dict( + batch_input_shape=(height, width), + img_shape=(height, width, 3), + scale_factor=self.inputs['scale_factor'][i]) + for i in range(batch_size) + ] + + x = self.extract_feat(self.inputs) + # from reprod_log import ReprodLogger + # reprod_log_1 = ReprodLogger() + # reprod_log_1.add("demo_test_1", x[0].cpu().detach().numpy()) + # reprod_log_1.save("result_1_paddle.npy") + # breakpoint() + bbox = self.query_head.simple_test( + x, img_metas, rescale=True) + bbox_num=[] + for i in range(len(bbox)): + bbox_num.append(bbox[i].shape[0]) + bbox_num = paddle.to_tensor(bbox_num) + bbox = paddle.concat(bbox, axis=0) + + return {'bbox': bbox, 'bbox_num': bbox_num} + + def get_loss(self): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 
+ For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + gt_bboxes (list[Tensor]): Each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): Class indices corresponding to each box. + gt_areas (list[Tensor]): mask areas corresponding to each box. + gt_bboxes_ignore (None | list[Tensor]): Specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + img = self.inputs['image'] + batch_size, _, height, width = img.shape + img_metas, gt_bboxes, gt_labels = self.get_inputs() + gt_bboxes_ignore = getattr(self.inputs, 'gt_bboxes_ignore', None) + x = self.extract_feat(self.inputs) + losses = dict() + # DETR encoder and decoder forward + if self.query_head is not None: + bbox_losses, x = self.query_head.forward_train(x, img_metas, gt_bboxes, + gt_labels, gt_bboxes_ignore) + losses.update(bbox_losses) + + if self.rpn_head is not None: + rois, rois_num, rpn_loss = self.rpn_head(x, self.inputs) + losses.update(rpn_loss) + + positive_coords = [] + if self.roi_head is not None: + roi_losses, _ = self.roi_head(x, rois, rois_num, + self.inputs) + if self.with_pos_coord: + positive_coords.append(roi_losses.pop('pos_coords')) + else: + if 'pos_coords' in roi_losses.keys(): + tmp = roi_losses.pop('pos_coords') + losses.update(roi_losses) + + # if self.bbox_head is not None: + # bbox_losses = self.bbox_head.forward_train(x,img_metas,gt_bboxes,gt_labels,) + # if self.with_pos_coord: + # positive_coords.append(bbox_losses.pop('pos_coords')) + # else: + # if 'pos_coords' in bbox_losses.keys(): + # tmp = bbox_losses.pop('pos_coords') + # losses.update(bbox_losses) + + if self.with_pos_coord and len(positive_coords)>0: + for i in range(len(positive_coords)): + bbox_losses = self.query_head.forward_train_aux(x, img_metas, gt_bboxes, + gt_labels, gt_bboxes_ignore, positive_coords[i], i) + if bbox_losses is not None: + losses.update(bbox_losses) + loss = 0 + for k, v in losses.items(): + if isinstance(v, list): + loss += sum(v) + else: + loss += v + losses={} + losses['loss'] = loss + return losses + \ No newline at end of file diff --git a/ppdet/modeling/assigners/hungarian_assigner.py b/ppdet/modeling/assigners/hungarian_assigner.py index 154c27ce97..e08286f953 100644 --- a/ppdet/modeling/assigners/hungarian_assigner.py +++ b/ppdet/modeling/assigners/hungarian_assigner.py @@ -24,8 +24,9 @@ import paddle from ppdet.core.workspace import register +from ppdet.modeling.assigners.pose_utils import bbox_cxcywh_to_xyxy -__all__ = ['PoseHungarianAssigner', 'PseudoSampler'] +__all__ = ["PoseHungarianAssigner", "PseudoSampler", "HungarianAssigner"] class AssignResult: @@ -72,11 +73,11 @@ def get_extra_property(self, key): def info(self): """dict: a dictionary of info about the object""" basic_info = { - 'num_gts': self.num_gts, - 'num_preds': self.num_preds, - 'gt_inds': self.gt_inds, - 'max_overlaps': self.max_overlaps, - 'labels': self.labels, + "num_gts": self.num_gts, + "num_preds": self.num_preds, + "gt_inds": self.gt_inds, + "max_overlaps": self.max_overlaps, + "labels": self.labels, } basic_info.update(self._extra_properties) return basic_info @@ -105,24 +106,19 @@ class PoseHungarianAssigner: oks_weight (int | float, optional): The scale factor for regression oks cost. Default 1.0. 
""" - __inject__ = ['cls_cost', 'kpt_cost', 'oks_cost'] - def __init__(self, - cls_cost='ClassificationCost', - kpt_cost='KptL1Cost', - oks_cost='OksCost'): + __inject__ = ["cls_cost", "kpt_cost", "oks_cost"] + + def __init__( + self, cls_cost="ClassificationCost", kpt_cost="KptL1Cost", oks_cost="OksCost" + ): self.cls_cost = cls_cost self.kpt_cost = kpt_cost self.oks_cost = oks_cost - def assign(self, - cls_pred, - kpt_pred, - gt_labels, - gt_keypoints, - gt_areas, - img_meta, - eps=1e-7): + def assign( + self, cls_pred, kpt_pred, gt_labels, gt_keypoints, gt_areas, img_meta, eps=1e-7 + ): """Computes one-to-one matching based on the weighted costs. This method assign each query prediction to a ground truth or @@ -157,52 +153,50 @@ def assign(self, :obj:`AssignResult`: The assigned result. """ num_gts, num_kpts = gt_keypoints.shape[0], kpt_pred.shape[0] - if not gt_keypoints.astype('bool').any(): + if not gt_keypoints.astype("bool").any(): num_gts = 0 # 1. assign -1 by default - assigned_gt_inds = paddle.full((num_kpts, ), -1, dtype="int64") - assigned_labels = paddle.full((num_kpts, ), -1, dtype="int64") + assigned_gt_inds = paddle.full((num_kpts,), -1, dtype="int64") + assigned_labels = paddle.full((num_kpts,), -1, dtype="int64") if num_gts == 0 or num_kpts == 0: # No ground truth or keypoints, return empty assignment if num_gts == 0: # No ground truth, assign all to background assigned_gt_inds[:] = 0 - return AssignResult( - num_gts, assigned_gt_inds, None, labels=assigned_labels) - img_h, img_w, _ = img_meta['img_shape'] + return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels) + img_h, img_w, _ = img_meta["img_shape"] factor = paddle.to_tensor( - [img_w, img_h, img_w, img_h], dtype=gt_keypoints.dtype).reshape( - (1, -1)) + [img_w, img_h, img_w, img_h], dtype=gt_keypoints.dtype + ).reshape((1, -1)) # 2. compute the weighted costs # classification cost cls_cost = self.cls_cost(cls_pred, gt_labels) # keypoint regression L1 cost - gt_keypoints_reshape = gt_keypoints.reshape((gt_keypoints.shape[0], -1, - 3)) + gt_keypoints_reshape = gt_keypoints.reshape((gt_keypoints.shape[0], -1, 3)) valid_kpt_flag = gt_keypoints_reshape[..., -1] - kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1, - 2)) - normalize_gt_keypoints = gt_keypoints_reshape[ - ..., :2] / factor[:, :2].unsqueeze(0) - kpt_cost = self.kpt_cost(kpt_pred_tmp, normalize_gt_keypoints, - valid_kpt_flag) + kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1, 2)) + normalize_gt_keypoints = gt_keypoints_reshape[..., :2] / factor[ + :, :2 + ].unsqueeze(0) + kpt_cost = self.kpt_cost(kpt_pred_tmp, normalize_gt_keypoints, valid_kpt_flag) # keypoint OKS cost - kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1, - 2)) + kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1, 2)) kpt_pred_tmp = kpt_pred_tmp * factor[:, :2].unsqueeze(0) - oks_cost = self.oks_cost(kpt_pred_tmp, gt_keypoints_reshape[..., :2], - valid_kpt_flag, gt_areas) + oks_cost = self.oks_cost( + kpt_pred_tmp, gt_keypoints_reshape[..., :2], valid_kpt_flag, gt_areas + ) # weighted sum of above three costs cost = cls_cost + kpt_cost + oks_cost # 3. do Hungarian matching on CPU using linear_sum_assignment cost = cost.detach().cpu() if linear_sum_assignment is None: - raise ImportError('Please run "pip install scipy" ' - 'to install scipy first.') + raise ImportError( + 'Please run "pip install scipy" ' "to install scipy first." 
+ ) matched_row_inds, matched_col_inds = linear_sum_assignment(cost) matched_row_inds = paddle.to_tensor(matched_row_inds) matched_col_inds = paddle.to_tensor(matched_col_inds) @@ -212,20 +206,19 @@ def assign(self, assigned_gt_inds[:] = 0 # assign foregrounds based on matching results assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 - assigned_labels[matched_row_inds] = gt_labels[matched_col_inds][ - ..., 0].astype("int64") - return AssignResult( - num_gts, assigned_gt_inds, None, labels=assigned_labels) + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds][..., 0].astype( + "int64" + ) + return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels) class SamplingResult: - """Bbox sampling result. - """ + """Bbox sampling result.""" - def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, - gt_flags): + def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags): self.pos_inds = pos_inds self.neg_inds = neg_inds + if pos_inds.size > 0: self.pos_bboxes = bboxes[pos_inds] self.neg_bboxes = bboxes[neg_inds] @@ -238,15 +231,15 @@ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, # hack for index error case assert self.pos_assigned_gt_inds.numel() == 0 self.pos_gt_bboxes = paddle.zeros( - gt_bboxes.shape, dtype=gt_bboxes.dtype).reshape((-1, 4)) + gt_bboxes.shape, dtype=gt_bboxes.dtype + ).reshape((-1, 4)) else: if len(gt_bboxes.shape) < 2: gt_bboxes = gt_bboxes.reshape((-1, 4)) self.pos_gt_bboxes = paddle.index_select( - gt_bboxes, - self.pos_assigned_gt_inds.astype('int64'), - axis=0) + gt_bboxes, self.pos_assigned_gt_inds.astype("int64"), axis=0 + ) if assign_result.labels is not None: self.pos_gt_labels = assign_result.labels[pos_inds] @@ -260,23 +253,23 @@ def bboxes(self): def __nice__(self): data = self.info.copy() - data['pos_bboxes'] = data.pop('pos_bboxes').shape - data['neg_bboxes'] = data.pop('neg_bboxes').shape + data["pos_bboxes"] = data.pop("pos_bboxes").shape + data["neg_bboxes"] = data.pop("neg_bboxes").shape parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] - body = ' ' + ',\n '.join(parts) - return '{\n' + body + '\n}' + body = " " + ",\n ".join(parts) + return "{\n" + body + "\n}" @property def info(self): """Returns a dictionary of info about the object.""" return { - 'pos_inds': self.pos_inds, - 'neg_inds': self.neg_inds, - 'pos_bboxes': self.pos_bboxes, - 'neg_bboxes': self.neg_bboxes, - 'pos_is_gt': self.pos_is_gt, - 'num_gts': self.num_gts, - 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_bboxes": self.pos_bboxes, + "neg_bboxes": self.neg_bboxes, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, } @@ -306,11 +299,146 @@ def sample(self, assign_result, bboxes, gt_bboxes, *args, **kwargs): Returns: :obj:`SamplingResult`: sampler results """ - pos_inds = paddle.nonzero( - assign_result.gt_inds > 0, as_tuple=False).squeeze(-1) - neg_inds = paddle.nonzero( - assign_result.gt_inds == 0, as_tuple=False).squeeze(-1) - gt_flags = paddle.zeros([bboxes.shape[0]], dtype='int32') - sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, - assign_result, gt_flags) + + pos_inds = paddle.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1) + neg_inds = paddle.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze( + -1 + ) + gt_flags = paddle.zeros([bboxes.shape[0]], dtype="int32") + sampling_result = SamplingResult( + 
pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags + ) return sampling_result + + +@register +class HungarianAssigner: + """Computes one-to-one matching between predictions and ground truth. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of three components: + classification cost, regression L1 cost and regression iou cost. The + targets don't include the no_object, so generally there are more + predictions than targets. After the one-to-one matching, the un-matched + are treated as backgrounds. Thus each query prediction will be assigned + with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + cls_weight (int | float, optional): The scale factor for classification + cost. Default 1.0. + bbox_weight (int | float, optional): The scale factor for regression + L1 cost. Default 1.0. + iou_weight (int | float, optional): The scale factor for regression + iou cost. Default 1.0. + iou_calculator (dict | optional): The config for the iou calculation. + Default type `BboxOverlaps2D`. + iou_mode (str | optional): "iou" (intersection over union), "iof" + (intersection over foreground), or "giou" (generalized + intersection over union). Default "giou". + """ + + __inject__ = ["cls_cost", "reg_cost", "iou_cost"] + + def __init__( + self, cls_cost="ClassificationCost", reg_cost="BBoxL1Cost", iou_cost="IoUCost" + ): + self.cls_cost = cls_cost + self.reg_cost = reg_cost + self.iou_cost = iou_cost + + def assign( + self, + bbox_pred, + cls_pred, + gt_bboxes, + gt_labels, + img_meta, + gt_bboxes_ignore=None, + eps=1e-7, + ): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + + Args: + bbox_pred (Tensor): Predicted boxes with normalized coordinates + (cx, cy, w, h), which are all in range [0, 1]. Shape + [num_query, 4]. + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_bboxes (Tensor): Ground truth boxes with unnormalized + coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + img_meta (dict): Meta information for current image. + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + + Returns: + :obj:`AssignResult`: The assigned result. + """ + assert ( + gt_bboxes_ignore is None + ), "Only case when gt_bboxes_ignore is None is supported." + num_gts, num_bboxes = gt_bboxes.shape[0], bbox_pred.shape[0] + + # 1. 
assign -1 by default + assigned_gt_inds = paddle.full((num_bboxes,), -1, dtype="int64") + assigned_labels = paddle.full((num_bboxes,), -1, dtype="int64") + if num_gts == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels) + img_h, img_w, _ = img_meta["img_shape"] + factor = paddle.to_tensor( + [img_w, img_h, img_w, img_h], dtype=gt_bboxes.dtype + ).unsqueeze(0) + + # 2. compute the weighted costs + # classification and bboxcost. + cls_cost = self.cls_cost(cls_pred, gt_labels) + # regression L1 cost + normalize_gt_bboxes = gt_bboxes / factor + reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes) + # regression iou cost, defaultly giou is used in official DETR. + bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor + iou_cost = self.iou_cost(bboxes, gt_bboxes) + # weighted sum of above three costs + cost = cls_cost + reg_cost + iou_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' + 'to install scipy first.') + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + matched_row_inds = paddle.to_tensor(matched_row_inds) + matched_col_inds = paddle.to_tensor(matched_col_inds) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds][ + ..., 0].astype("int64") + + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) diff --git a/ppdet/modeling/assigners/pose_utils.py b/ppdet/modeling/assigners/pose_utils.py index 313215a4dd..422b92eb3a 100644 --- a/ppdet/modeling/assigners/pose_utils.py +++ b/ppdet/modeling/assigners/pose_utils.py @@ -21,8 +21,10 @@ import paddle.nn.functional as F from ppdet.core.workspace import register +from ppdet.data.transform.atss_assigner import bbox_overlaps +from ppdet.modeling.transformers.utils import bbox_xyxy_to_cxcywh -__all__ = ['KptL1Cost', 'OksCost', 'ClassificationCost'] +__all__ = ["KptL1Cost", "OksCost", "ClassificationCost", "BBoxL1Cost", "IoUCost"] def masked_fill(x, mask, value): @@ -63,17 +65,18 @@ def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag): kpt_cost.append(kpt_pred.sum() * 0) kpt_pred_tmp = kpt_pred.clone() valid_flag = valid_kpt_flag[i] > 0 - valid_flag_expand = valid_flag.unsqueeze(0).unsqueeze(-1).expand_as( - kpt_pred_tmp) + valid_flag_expand = ( + valid_flag.unsqueeze(0).unsqueeze(-1).expand_as(kpt_pred_tmp) + ) if not valid_flag_expand.all(): kpt_pred_tmp = masked_fill(kpt_pred_tmp, ~valid_flag_expand, 0) cost = F.pairwise_distance( kpt_pred_tmp.reshape((kpt_pred_tmp.shape[0], -1)), - gt_keypoints[i].reshape((-1, )).unsqueeze(0), + gt_keypoints[i].reshape((-1,)).unsqueeze(0), p=1, - keepdim=True) - avg_factor = paddle.clip( - valid_flag.astype('float32').sum() * 2, 1.0) + keepdim=True, + ) + avg_factor = paddle.clip(valid_flag.astype("float32").sum() * 2, 1.0) cost = cost / avg_factor kpt_cost.append(cost) kpt_cost = paddle.concat(kpt_cost, axis=1) @@ -94,21 +97,56 @@ class OksCost(object): def __init__(self, num_keypoints=17, weight=1.0): self.weight = weight if num_keypoints == 17: - self.sigmas = np.array( - [ - .26, .25, .25, .35, .35, 
.79, .79, .72, .72, .62, .62, 1.07, - 1.07, .87, .87, .89, .89 - ], - dtype=np.float32) / 10.0 + self.sigmas = ( + np.array( + [ + 0.26, + 0.25, + 0.25, + 0.35, + 0.35, + 0.79, + 0.79, + 0.72, + 0.72, + 0.62, + 0.62, + 1.07, + 1.07, + 0.87, + 0.87, + 0.89, + 0.89, + ], + dtype=np.float32, + ) + / 10.0 + ) elif num_keypoints == 14: - self.sigmas = np.array( - [ - .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, - .89, .79, .79 - ], - dtype=np.float32) / 10.0 + self.sigmas = ( + np.array( + [ + 0.79, + 0.79, + 0.72, + 0.72, + 0.62, + 0.62, + 1.07, + 1.07, + 0.87, + 0.87, + 0.89, + 0.89, + 0.79, + 0.79, + ], + dtype=np.float32, + ) + / 10.0 + ) else: - raise ValueError(f'Unsupported keypoints number {num_keypoints}') + raise ValueError(f"Unsupported keypoints number {num_keypoints}") def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas): """ @@ -125,17 +163,17 @@ def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas): paddle.Tensor: oks_cost value with weight. """ sigmas = paddle.to_tensor(self.sigmas) - variances = (sigmas * 2)**2 + variances = (sigmas * 2) ** 2 oks_cost = [] assert len(gt_keypoints) == len(gt_areas) for i in range(len(gt_keypoints)): if gt_keypoints[i].size == 0: oks_cost.append(kpt_pred.sum() * 0) - squared_distance = \ - (kpt_pred[:, :, 0] - gt_keypoints[i, :, 0].unsqueeze(0)) ** 2 + \ - (kpt_pred[:, :, 1] - gt_keypoints[i, :, 1].unsqueeze(0)) ** 2 - vis_flag = (valid_kpt_flag[i] > 0).astype('int') + squared_distance = ( + kpt_pred[:, :, 0] - gt_keypoints[i, :, 0].unsqueeze(0) + ) ** 2 + (kpt_pred[:, :, 1] - gt_keypoints[i, :, 1].unsqueeze(0)) ** 2 + vis_flag = (valid_kpt_flag[i] > 0).astype("int") vis_ind = vis_flag.nonzero(as_tuple=False)[:, 0] num_vis_kpt = vis_ind.shape[0] # assert num_vis_kpt > 0 @@ -145,10 +183,8 @@ def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas): area = gt_areas[i] squared_distance0 = squared_distance / (area * variances * 2) - squared_distance0 = paddle.index_select( - squared_distance0, vis_ind, axis=1) - squared_distance1 = paddle.exp(-squared_distance0).sum(axis=1, - keepdim=True) + squared_distance0 = paddle.index_select(squared_distance0, vis_ind, axis=1) + squared_distance1 = paddle.exp(-squared_distance0).sum(axis=1, keepdim=True) oks = squared_distance1 / num_vis_kpt # The 1 is a constant that doesn't change the matching, so omitted. oks_cost.append(-oks) @@ -160,11 +196,11 @@ def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas): class ClassificationCost: """ClsSoftmaxCost. - Args: - weight (int | float, optional): loss_weight + Args: + weight (int | float, optional): loss_weight """ - def __init__(self, weight=1.): + def __init__(self, weight=1.0): self.weight = weight def __call__(self, cls_pred, gt_labels): @@ -190,21 +226,16 @@ def __call__(self, cls_pred, gt_labels): class FocalLossCost: """FocalLossCost. - Args: - weight (int | float, optional): loss_weight - alpha (int | float, optional): focal_loss alpha - gamma (int | float, optional): focal_loss gamma - eps (float, optional): default 1e-12 - binary_input (bool, optional): Whether the input is binary, - default False. + Args: + weight (int | float, optional): loss_weight + alpha (int | float, optional): focal_loss alpha + gamma (int | float, optional): focal_loss gamma + eps (float, optional): default 1e-12 + binary_input (bool, optional): Whether the input is binary, + default False. 
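For orientation, the focal-loss matching cost described above maps every (query, ground-truth) pair to a scalar. A minimal usage sketch (the shapes, class count and labels here are illustrative assumptions, not values taken from this patch):

```python
import paddle

cost_fn = FocalLossCost(weight=2.0)
cls_pred = paddle.rand([100, 80])       # [num_query, num_classes] raw logits
gt_labels = paddle.to_tensor([3, 7])    # [num_gt] class indices
cost = cost_fn(cls_pred, gt_labels)     # [num_query, num_gt] cost consumed by the matcher
```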
""" - def __init__(self, - weight=1., - alpha=0.25, - gamma=2, - eps=1e-12, - binary_input=False): + def __init__(self, weight=1.0, alpha=0.25, gamma=2, eps=1e-12, binary_input=False): self.weight = weight self.alpha = alpha self.gamma = gamma @@ -224,14 +255,18 @@ def _focal_loss_cost(self, cls_pred, gt_labels): if gt_labels.size == 0: return cls_pred.sum() * 0 cls_pred = F.sigmoid(cls_pred) - neg_cost = -(1 - cls_pred + self.eps).log() * ( - 1 - self.alpha) * cls_pred.pow(self.gamma) - pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( - 1 - cls_pred).pow(self.gamma) + neg_cost = ( + -(1 - cls_pred + self.eps).log() + * (1 - self.alpha) + * cls_pred.pow(self.gamma) + ) + pos_cost = ( + -(cls_pred + self.eps).log() * self.alpha * (1 - cls_pred).pow(self.gamma) + ) cls_cost = paddle.index_select( - pos_cost, gt_labels, axis=1) - paddle.index_select( - neg_cost, gt_labels, axis=1) + pos_cost, gt_labels, axis=1 + ) - paddle.index_select(neg_cost, gt_labels, axis=1) return cls_cost * self.weight def _mask_focal_loss_cost(self, cls_pred, gt_labels): @@ -250,13 +285,18 @@ def _mask_focal_loss_cost(self, cls_pred, gt_labels): gt_labels = gt_labels.flatten(1).float() n = cls_pred.shape[1] cls_pred = F.sigmoid(cls_pred) - neg_cost = -(1 - cls_pred + self.eps).log() * ( - 1 - self.alpha) * cls_pred.pow(self.gamma) - pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( - 1 - cls_pred).pow(self.gamma) - - cls_cost = paddle.einsum('nc,mc->nm', pos_cost, gt_labels) + \ - paddle.einsum('nc,mc->nm', neg_cost, (1 - gt_labels)) + neg_cost = ( + -(1 - cls_pred + self.eps).log() + * (1 - self.alpha) + * cls_pred.pow(self.gamma) + ) + pos_cost = ( + -(cls_pred + self.eps).log() * self.alpha * (1 - cls_pred).pow(self.gamma) + ) + + cls_cost = paddle.einsum("nc,mc->nm", pos_cost, gt_labels) + paddle.einsum( + "nc,mc->nm", neg_cost, (1 - gt_labels) + ) return cls_cost / n * self.weight def __call__(self, cls_pred, gt_labels): @@ -273,3 +313,86 @@ def __call__(self, cls_pred, gt_labels): return self._mask_focal_loss_cost(cls_pred, gt_labels) else: return self._focal_loss_cost(cls_pred, gt_labels) + + +def bbox_cxcywh_to_xyxy(bbox): + """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2). + + Args: + bbox (Tensor): Shape (n, 4) for bboxes. + + Returns: + Tensor: Converted bboxes. + """ + cx, cy, w, h = paddle.split(bbox, (1, 1, 1, 1), axis=-1) + bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)] + return paddle.concat(bbox_new, axis=-1) + + +@register +class BBoxL1Cost: + """BBoxL1Cost. + + Args: + weight (int | float, optional): loss_weight + box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN + + """ + + def __init__(self, weight=1.0, box_format="xyxy"): + self.weight = weight + assert box_format in ["xyxy", "xywh"] + self.box_format = box_format + + def __call__(self, bbox_pred, gt_bboxes): + """ + Args: + bbox_pred (Tensor): Predicted boxes with normalized coordinates + (cx, cy, w, h), which are all in range [0, 1]. Shape + (num_query, 4). + gt_bboxes (Tensor): Ground truth boxes with normalized + coordinates (x1, y1, x2, y2). Shape (num_gt, 4). + + Returns: + Tensor: bbox_cost value with weight + """ + if self.box_format == "xywh": + gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes) + elif self.box_format == "xyxy": + bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred) + bbox_cost = paddle.cdist(bbox_pred, gt_bboxes, p=1) + return bbox_cost * self.weight + + +@register +class IoUCost: + """IoUCost. 
+ + Args: + iou_mode (str, optional): iou mode such as 'iou' | 'giou' + weight (int | float, optional): loss weight + + """ + + def __init__(self, iou_mode="giou", weight=1.0): + self.weight = weight + self.iou_mode = iou_mode + + def __call__(self, bboxes, gt_bboxes): + """ + Args: + bboxes (Tensor): Predicted boxes with unnormalized coordinates + (x1, y1, x2, y2). Shape (num_query, 4). + gt_bboxes (Tensor): Ground truth boxes with unnormalized + coordinates (x1, y1, x2, y2). Shape (num_gt, 4). + + Returns: + Tensor: iou_cost value with weight + """ + # overlaps: [num_bboxes, num_gt] + overlaps = bbox_overlaps( + bboxes.detach().numpy(), gt_bboxes.detach().numpy(), mode=self.iou_mode, is_aligned=False + ) + # The 1 is a constant that doesn't change the matching, so omitted. + iou_cost = -overlaps + return iou_cost * self.weight diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index e2b2dc2da0..23c0136ca2 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -41,6 +41,9 @@ from . import sparse_roi_head from . import vitpose_head from . import clrnet_head +from . import co_deformable_detr_head +from . import co_roi_head +from . import co_atss_head from .bbox_head import * from .mask_head import * @@ -72,3 +75,6 @@ from .petr_head import * from .vitpose_head import * from .clrnet_head import * +from .co_deformable_detr_head import * +from .co_roi_head import * +from .co_atss_head import * diff --git a/ppdet/modeling/heads/co_atss_head.py b/ppdet/modeling/heads/co_atss_head.py new file mode 100644 index 0000000000..afe35af034 --- /dev/null +++ b/ppdet/modeling/heads/co_atss_head.py @@ -0,0 +1,665 @@ +from functools import partial +import paddle +import paddle.nn as nn +from ppdet.modeling import bbox_utils +from ppdet.core.workspace import register +from ppdet.modeling.assigners import hungarian_assigner +from ppdet.data.transform.atss_assigner import ATSSAssigner + +__all__ = ['CoATSSHead'] + +class Scale(nn.Layer): + """A learnable scale parameter. + + This layer scales the input by a learnable factor. It multiplies a + learnable scale parameter of shape (1,) with input of any shape. + + Args: + scale (float): Initial value of scale factor. Default: 1.0 + """ + + def __init__(self, scale: float = 1.0): + super().__init__() + self.scale = paddle.create_parameter(paddle.to_tensor(scale, dtype='float32').shape,dtype='float32') + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return x * self.scale + +def reduce_mean(tensor): + world_size = paddle.distributed.get_world_size() + if world_size == 1: + return tensor + paddle.distributed.all_reduce(tensor) + return tensor / world_size + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. + + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + res = tuple(map(list, zip(*map_results))) + return res + +@register +class CoATSSHead(nn.Layer): + """Bridging the Gap Between Anchor-based and Anchor-free Detection via + Adaptive Training Sample Selection. 
+ + ATSS head structure is similar with FCOS, however ATSS use anchor boxes + and assign label by Adaptive Training Sample Selection instead max-iou. + + https://arxiv.org/abs/1912.02424 + """ + __inject__ = ['anchor_generator','loss_cls', 'loss_bbox','sampler'] + def __init__(self, + num_classes, + in_channels, + stacked_convs=4, + feat_channels=256, + anchor_generator=None, + assigner='ATSSAssigner', + sampler='PseudoSampler', + loss_cls=None, + loss_bbox=None, + reg_decoded_bbox=True, + pos_weight=-1 + ): + super().__init__() + self.num_classes=num_classes + self.in_channels=in_channels + self.stacked_convs=stacked_convs + self.feat_channels=feat_channels + self.anchor_generator=anchor_generator + self.num_levels=len(self.anchor_generator.strides) + self.num_anchors = self.anchor_generator.num_anchors + self.use_sigmoid_cls = True + self.loss_cls=loss_cls + self.loss_bbox=loss_bbox + self.loss_centerness=nn.CrossEntropyLoss() + self.assigner=ATSSAssigner() + self.sampler = sampler + self.reg_decoded_bbox=reg_decoded_bbox + self.pos_weight=pos_weight + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + + if self.cls_out_channels <= 0: + raise ValueError(f'num_classes={num_classes} is too small') + self._init_layers() + + def _init_layers(self): + """Initialize layers of the head.""" + self.cls_convs = nn.LayerList() + self.reg_convs = nn.LayerList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + nn.Sequential( + nn.Conv2D(chn, self.feat_channels, 3, padding=1), + nn.GroupNorm(32,self.feat_channels), + nn.ReLU())) + + self.reg_convs.append( + nn.Sequential( + nn.Conv2D(chn, self.feat_channels, 3, padding=1), + nn.GroupNorm(32,self.feat_channels), + nn.ReLU())) + self.atss_cls = nn.Conv2D( + self.feat_channels, + self.num_anchors * self.cls_out_channels, + 3, + padding=1) + self.atss_reg = nn.Conv2D( + self.feat_channels, self.num_anchors * 4, 3, padding=1) + self.atss_centerness = nn.Conv2D( + self.feat_channels, self.num_anchors * 1, 3, padding=1) + self.scales = nn.LayerList( + [Scale(1.0) for _ in self.anchor_generator.strides]) + + def forward(self, feats): + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_anchors * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_anchors * 4. + """ + return multi_apply(self.forward_single, feats, self.scales) + + def forward_single(self, x, scale): + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + scale : Learnable scale module to resize the bbox prediction. + + Returns: + tuple: + cls_score (Tensor): Cls scores for a single scale level + the channels number is num_anchors * num_classes. + bbox_pred (Tensor): Box energies / deltas for a single scale + level, the channels number is num_anchors * 4. + centerness (Tensor): Centerness for a single scale level, the + channel number is (N, num_anchors * 1, H, W). 
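A shape-level sketch of a single-level call (assuming an already-constructed `CoATSSHead` instance named `head` and an invented 256-channel feature map; this is an illustration, not part of the patch):

```python
import paddle

feat = paddle.rand([2, 256, 100, 100])   # one FPN level, batch of 2
cls_score, bbox_pred, centerness = head.forward_single(feat, head.scales[0])
# cls_score:  [2, num_anchors * num_classes, 100, 100]
# bbox_pred:  [2, num_anchors * 4,           100, 100]
# centerness: [2, num_anchors * 1,           100, 100]
```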
+ """ + cls_feat = x + reg_feat = x + + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + cls_score = self.atss_cls(cls_feat) + # we just follow atss, not apply exp in bbox_pred + bbox_pred = scale(self.atss_reg(reg_feat)).astype('float32') + centerness = self.atss_centerness(reg_feat) + return cls_score, bbox_pred, centerness + + def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels, + label_weights, bbox_targets, img_metas, num_total_samples): + """Compute loss of a single scale level. + + Args: + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + labels (Tensor): Labels of each anchors with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors) + bbox_targets (Tensor): BBox regression targets of each anchor + weight shape (N, num_total_anchors, 4). + num_total_samples (int): Number os positive samples that is + reduced over all GPUs. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + anchors = anchors.reshape((-1, 4)) + cls_score = cls_score.transpose((0, 2, 3, 1)).reshape( + (-1, self.cls_out_channels)) + bbox_pred = bbox_pred.transpose((0, 2, 3, 1)).reshape((-1, 4)) + centerness = centerness.transpose((0, 2, 3, 1)).reshape([-1]) + bbox_targets = bbox_targets.reshape((-1, 4)) + labels = labels.reshape([-1]).astype(paddle.int32) + label_weights = label_weights.reshape([-1]) + # classification loss + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=num_total_samples) + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = paddle.nonzero( + paddle.logical_and((labels >= 0), (labels < bg_class_ind)), + as_tuple=False).squeeze(1) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + pos_centerness = centerness[pos_inds] + + centerness_targets = self.centerness_target( + pos_anchors, pos_bbox_targets) + pos_decode_bbox_pred = bbox_utils.delta2bbox( + pos_anchors, pos_bbox_pred) + + # regression loss + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_bbox_targets, + weight=centerness_targets, + avg_factor=1.0) + + # centerness loss + loss_centerness = self.loss_centerness( + pos_centerness, + centerness_targets, + avg_factor=num_total_samples) + else: + loss_bbox = bbox_pred.sum() * 0 + loss_centerness = centerness.sum() * 0 + centerness_targets = paddle.to_tensor(0., dtype=bbox_targets.dtype) + + return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum() + + def centerness_target(self, anchors, gts): + # only calculate pos centerness targets, otherwise there may be nan + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + l_ = anchors_cx - gts[:, 0] + t_ = anchors_cy - gts[:, 1] + r_ = gts[:, 2] - anchors_cx + b_ = gts[:, 3] - anchors_cy + + left_right = paddle.stack([l_, r_], axis=1) + top_bottom = paddle.stack([t_, b_], axis=1) + centerness = paddle.sqrt( + (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * + (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])) + assert not paddle.isnan(centerness).any() + 
return centerness + + def get_anchors(self, featmap_sizes, img_metas, device='cuda'): + """Get anchors according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. + device (torch.device | str): Device for returned tensors + + Returns: + tuple: + anchor_list (list[Tensor]): Anchors of each image. + valid_flag_list (list[Tensor]): Valid flags of each image. + """ + num_imgs = len(img_metas) + + # since feature map sizes of all images are the same, we only compute + # anchors for one time + multi_level_anchors = self.anchor_generator( + featmap_sizes) + anchor_list = [multi_level_anchors for _ in range(num_imgs)] + + # for each image, we compute valid flags of multi level anchors + valid_flag_list = [] + for img_id, img_meta in enumerate(img_metas): + multi_level_flags = self.anchor_generator.valid_flags( + featmap_sizes, img_meta['img_shape']) + valid_flag_list.append(multi_level_flags) + + return anchor_list, valid_flag_list + + def loss(self, + cls_scores, + bbox_preds, + centernesses, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + centernesses (list[Tensor]): Centerness for each scale + level with shape (N, num_anchors * 1, H, W) + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (list[Tensor] | None): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + featmap_sizes = [featmap.shape[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.num_levels + anchor_list, valid_flag_list = self.get_anchors( + cls_scores, img_metas) + + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + gt_bboxes, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + label_channels=label_channels) + if cls_reg_targets is None: + return None + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, num_total_pos, num_total_neg, + ori_anchors, ori_labels, ori_bbox_targets) = cls_reg_targets + num_total_samples = reduce_mean( + paddle.to_tensor(num_total_pos, dtype=paddle.float32)).item() + num_total_samples = max(num_total_samples, 1.0) + new_img_metas = [img_metas for _ in range(len(anchor_list))] + losses_cls, losses_bbox, loss_centerness,\ + bbox_avg_factor = multi_apply( + self.loss_single, + anchor_list, + cls_scores, + bbox_preds, + centernesses, + labels_list, + label_weights_list, + bbox_targets_list, + new_img_metas, + num_total_samples=num_total_samples) + + bbox_avg_factor = sum(bbox_avg_factor) + bbox_avg_factor = reduce_mean(bbox_avg_factor).clip_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + + pos_coords = (ori_anchors, ori_labels, ori_bbox_targets, 'atss') + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + loss_centerness=loss_centerness, + pos_coords=pos_coords) + + + def images_to_levels(self, target, num_level_anchors): + """ + Convert targets by image to targets by feature level. + """ + target = paddle.stack(target, 0) + level_targets = [] + start = 0 + for n in num_level_anchors: + end = start + n + level_targets.append(target[:, start:end].squeeze(0)) + start = end + return level_targets + + + def get_targets(self, + anchor_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=1, + unmap_outputs=True): + """Get targets for ATSS head. + + This method is almost the same as `AnchorHead.get_targets()`. Besides + returning the targets as the parent method does, it also returns the + anchors as the first element of the returned tuple. 
+ """ + num_imgs = len(img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.shape[0] for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = paddle.concat(anchor_list[i]) + valid_flag_list[i] = paddle.concat(valid_flag_list[i]) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply( + self._get_target_single, + anchor_list, + valid_flag_list, + num_level_anchors_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + # sampled anchors of all images + num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) + num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) + # split targets to a list w.r.t. multiple levels + ori_anchors = all_anchors + ori_labels = all_labels + ori_bbox_targets = all_bbox_targets + anchors_list = self.images_to_levels(all_anchors, num_level_anchors) + labels_list = self.images_to_levels(all_labels, num_level_anchors) + label_weights_list = self.images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = self.images_to_levels(all_bbox_targets, + num_level_anchors) + bbox_weights_list = self.images_to_levels(all_bbox_weights, + num_level_anchors) + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, bbox_weights_list, num_total_pos, + num_total_neg, ori_anchors, ori_labels, ori_bbox_targets) + + def anchor_inside_flags(self,flat_anchors, + valid_flags, + img_shape, + allowed_border=0): + """Check whether the anchors are inside the border. + + Args: + flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4). + valid_flags (torch.Tensor): An existing valid flags of anchors. + img_shape (tuple(int)): Shape of current image. + allowed_border (int, optional): The border to allow the valid anchor. + Defaults to 0. + + Returns: + torch.Tensor: Flags indicating whether the anchors are inside a \ + valid range. 
+ """ + img_h, img_w = img_shape[:2] + if allowed_border >= 0: + inside_flags = valid_flags & \ + (flat_anchors[:, 0] >= -allowed_border) & \ + (flat_anchors[:, 1] >= -allowed_border) & \ + (flat_anchors[:, 2] < img_w + allowed_border) & \ + (flat_anchors[:, 3] < img_h + allowed_border) + else: + inside_flags = valid_flags + return inside_flags + + + def unmap(self,data, count, inds, fill=0): + """Unmap a subset of item (data) back to the original set of items (of size + count)""" + if data.dim() == 1: + ret = paddle.full((count,1), fill) + data=data.unsqueeze(0).transpose((1,0)) + ret[inds,:] = data + ret=ret.transpose((1,0)).squeeze() + else: + new_size = (count, ) + tuple(data.shape[1:]) + ret = paddle.full(new_size, fill) + ret[inds.astype(paddle.bool), :] = data + return ret + + + def get_num_level_anchors_inside(self, num_level_anchors, inside_flags): + split_inside_flags = paddle.split(inside_flags, num_level_anchors) + num_level_anchors_inside = [ + int(flags.sum()) for flags in split_inside_flags + ] + return num_level_anchors_inside + + def _get_target_single(self, + flat_anchors, + valid_flags, + num_level_anchors, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + img_meta, + label_channels=1, + unmap_outputs=True): + """Compute regression, classification targets for anchors in a single + image. + + Args: + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors ,4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + num_level_anchors Tensor): Number of anchors of each scale level. + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + img_meta (dict): Meta info of the image. + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: N is the number of total anchors in the image. + labels (Tensor): Labels of all anchors in the image with shape + (N,). + label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + bbox_weights (Tensor): BBox weights of all anchors in the + image with shape (N, 4) + pos_inds (Tensor): Indices of positive anchor with shape + (num_pos,). + neg_inds (Tensor): Indices of negative anchor with shape + (num_neg,). 
+ """ + inside_flags = self.anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + -1).astype(paddle.bool) + if not inside_flags.any(): + return (None, ) * 7 + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + + num_level_anchors_inside = self.get_num_level_anchors_inside( + num_level_anchors, inside_flags) + + # pad_gt_mask = ( + # gt_bboxes.sum(axis=-1, keepdim=True) > 0).astype(gt_bboxes.dtype) + + assigned_gt_inds, max_overlaps = self.assigner(anchors.cpu().detach().numpy(), num_level_anchors_inside, + gt_labels=gt_labels, gt_bboxes=gt_bboxes.cpu().detach().numpy(), + ) + assigned_gt_inds = paddle.to_tensor(assigned_gt_inds) + max_overlaps = paddle.to_tensor(max_overlaps) + if gt_labels is not None: + assigned_labels = paddle.full((anchors.shape[0], ),-1, dtype=assigned_gt_inds.dtype) + pos_inds = paddle.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[ + assigned_gt_inds[pos_inds] - 1] + else: + assigned_labels = None + + assign_result=hungarian_assigner.AssignResult( + gt_bboxes.shape[0], assigned_gt_inds, max_overlaps, labels=assigned_labels) + + sampling_result = self.sampler.sample(assign_result, anchors, + gt_bboxes) + + num_valid_anchors = anchors.shape[0] + bbox_targets = paddle.zeros_like(anchors) + bbox_weights = paddle.zeros_like(anchors) + labels = paddle.full((num_valid_anchors, ),self.num_classes,dtype=paddle.int64) + + label_weights = paddle.zeros((num_valid_anchors, ), dtype=paddle.float32) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + if self.reg_decoded_bbox: + pos_bbox_targets = sampling_result.pos_gt_bboxes + else: + pos_bbox_targets = bbox_utils.bbox2delta( + sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes) + + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + if gt_labels is None: + # Only rpn gives gt_labels as None + # Foreground is the first class since v2.5.0 + labels[pos_inds] = 0 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if self.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.shape[0] + anchors = self.unmap(anchors, num_total_anchors, inside_flags) + labels = self.unmap( + labels, num_total_anchors, inside_flags, fill=self.num_classes) + + label_weights = self.unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = self.unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = self.unmap(bbox_weights, num_total_anchors, inside_flags) + + return (anchors, labels, label_weights, bbox_targets, bbox_weights, + pos_inds, neg_inds) + + def forward_train(self, + x, + img_metas, + gt_bboxes, + gt_labels=None, + gt_bboxes_ignore=None, + **kwargs): + """ + Args: + x (list[Tensor]): Features from FPN. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + + Returns: + tuple: + losses: (dict[str, Tensor]): A dictionary of loss components. 
+ """ + outs = self(x) + if gt_labels is None: + loss_inputs = outs + (gt_bboxes, img_metas) + else: + loss_inputs = outs + (gt_bboxes, gt_labels, img_metas) + losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + + return losses diff --git a/ppdet/modeling/heads/co_deformable_detr_head.py b/ppdet/modeling/heads/co_deformable_detr_head.py new file mode 100644 index 0000000000..8f6918743b --- /dev/null +++ b/ppdet/modeling/heads/co_deformable_detr_head.py @@ -0,0 +1,1300 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +this code is base on https://github.com/Sense-X/Co-DETR/blob/main/projects/models/co_deformable_detr_head.py +""" +import copy +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register +import paddle.distributed as dist + +from ..transformers.petr_transformer import inverse_sigmoid, masked_fill +from ..initializer import constant_, normal_ +from ppdet.modeling.transformers.utils import bbox_cxcywh_to_xyxy + +__all__ = ["CoDeformDETRHead"] + +from functools import partial + + +def bias_init_with_prob(prior_prob: float) -> float: + """initialize conv/fc bias value according to a given probability value.""" + bias_init = float(-np.log((1 - prior_prob) / prior_prob)) + return bias_init + + +def constant_init(module, val, bias=0): + if hasattr(module, "weight") and module.weight is not None: + constant_(module.weight, val) + if hasattr(module, "bias") and module.bias is not None: + constant_(module.bias, bias) + + +def reduce_mean(tensor): + """ "Obtain the mean of tensor on different GPUs.""" + if not (dist.get_world_size() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce( + tensor.divide(paddle.to_tensor(dist.get_world_size(), dtype="float32")), + op=dist.ReduceOp.SUM, + ) + return tensor + + +def bbox_xyxy_to_cxcywh(bbox): + """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h). + + Args: + bbox (Tensor): Shape (n, 4) for bboxes. + + Returns: + Tensor: Converted bboxes. + """ + x1, y1, x2, y2 = paddle.split(bbox, (1, 1, 1, 1), axis=-1) + bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)] + return paddle.concat(bbox_new, axis=-1) + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. 
+ + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + res = tuple(map(list, zip(*map_results))) + return res + + +@register +class CoDeformDETRHead(nn.Layer): + __inject__ = [ + "transformer", + "positional_encoding", + "loss_cls", + "loss_bbox", + "loss_iou", + "nms", + "assigner", + "sampler" + ] + + def __init__( + self, + num_classes, + in_channels, + num_query=300, + sync_cls_avg_factor=True, + with_box_refine=False, + as_two_stage=False, + mixed_selection=False, + max_pos_coords=300, + lambda_1=1, + num_reg_fcs=2, + transformer=None, + positional_encoding="SinePositionalEncoding", + loss_cls="FocalLoss", + loss_bbox="L1Loss", + loss_iou="GIoULoss", + assigner="HungarianAssigner", + sampler="PseudoSampler", + test_cfg=dict(max_per_img=100), + nms=None, + use_zero_padding=False, + ): + super().__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.assigner = assigner + self.sampler = sampler + self.bg_cls_weight = 0 + self.num_query = num_query + self.sync_cls_avg_factor = sync_cls_avg_factor + self.with_box_refine = with_box_refine + self.as_two_stage = as_two_stage + self.mixed_selection = mixed_selection + self.max_pos_coords = max_pos_coords + self.lambda_1 = lambda_1 + self.use_zero_padding = use_zero_padding + self.test_cfg = test_cfg + self.nms = nms + self.transformer = transformer + self.num_reg_fcs = num_reg_fcs + + self.positional_encoding = positional_encoding + self.loss_cls = loss_cls + self.loss_bbox = loss_bbox + self.loss_iou = loss_iou + if self.loss_cls.use_sigmoid: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + self.embed_dims = self.transformer.embed_dims + + num_feats = positional_encoding.num_pos_feats + assert num_feats * 2 == self.embed_dims, ( + "embed_dims should" + f" be exactly 2 times of num_feats. Found {self.embed_dims}" + f" and {num_feats}." + ) + self._init_layers() + self.init_weights() + + def _init_layers(self): + """Initialize classification branch and regression branch of head.""" + self.downsample = nn.Sequential( + nn.Conv2D( + self.embed_dims, self.embed_dims, kernel_size=3, stride=2, padding=1 + ), + nn.GroupNorm(32, self.embed_dims), + ) + + fc_cls = nn.Linear(self.embed_dims, self.cls_out_channels) + + reg_branch = [] + for _ in range(self.num_reg_fcs): + reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(nn.Linear(self.embed_dims, 4)) + reg_branch = nn.Sequential(*reg_branch) + + def _get_clones(module, N): + return nn.LayerList([copy.deepcopy(module) for i in range(N)]) + + # last reg_branch is used to generate proposal from + # encode feature map when as_two_stage is True. 
+ num_pred = ( + (self.transformer.decoder.num_layers + 1) + if self.as_two_stage + else self.transformer.decoder.num_layers + ) + + if self.with_box_refine: + self.cls_branches = _get_clones(fc_cls, num_pred) + self.reg_branches = _get_clones(reg_branch, num_pred) + else: + self.cls_branches = nn.LayerList([fc_cls for _ in range(num_pred)]) + self.reg_branches = nn.LayerList([reg_branch for _ in range(num_pred)]) + + if not self.as_two_stage: + self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2) + elif self.mixed_selection: + self.query_embedding = nn.Embedding(self.num_query, self.embed_dims) + + def init_weights(self): + """Initialize weights of the DeformDETR head.""" + self.transformer.init_weights() + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + for m in self.cls_branches: + constant_(m.bias, bias_init) + for m in self.reg_branches: + constant_init(m[-1], 0, bias=0) + constant_(self.reg_branches[0][-1].bias.data[2:], -2.0) + if self.as_two_stage: + for m in self.reg_branches: + constant_(m[-1].bias.data[2:], 0.0) + + def forward(self, mlvl_feats, img_metas): + """Forward function. + + Args: + mlvl_feats (tuple[Tensor]): Features from the upstream + network, each is a 4D-tensor with shape + (N, C, H, W). + img_metas (list[dict]): List of image information. + + Returns: + all_cls_scores (Tensor): Outputs from the classification head, \ + shape [nb_dec, bs, num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + all_bbox_preds (Tensor): Sigmoid outputs from the regression \ + head with normalized coordinate format (cx, cy, w, h). \ + Shape [nb_dec, bs, num_query, 4]. + enc_outputs_class (Tensor): The score of each point on encode \ + feature map, has shape (N, h*w, num_class). Only when \ + as_two_stage is True it would be returned, otherwise \ + `None` would be returned. + enc_outputs_coord (Tensor): The proposal generate from the \ + encode feature map, has shape (N, h*w, 4). Only when \ + as_two_stage is True it would be returned, otherwise \ + `None` would be returned. 
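+
+        Example (shape sketch; the variable names and batch size are assumed, and
+        the sigmoid classification head is used so cls_out_channels == num_classes)::
+
+            >>> cls_scores, bbox_preds, enc_cls, enc_coord, enc_feats = head(
+            ...     mlvl_feats, img_metas)
+            >>> # cls_scores: [num_dec_layers, bs, num_query, num_classes]
+            >>> # bbox_preds: [num_dec_layers, bs, num_query, 4] in normalized cxcywh
+            >>> # enc_cls / enc_coord are non-None only when as_two_stage=True
+            >>> # enc_feats: per-level encoder feature maps plus one extra downsampled level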
+ """ + batch_size = mlvl_feats[0].shape[0] + input_img_h, input_img_w = img_metas[0]["batch_input_shape"] + img_masks = paddle.zeros((batch_size, input_img_h, input_img_w),mlvl_feats[0].dtype) + for img_id in range(batch_size): + img_h, img_w, _ = img_metas[img_id]["img_shape"] + img_masks[img_id, :img_h, :img_w] = 0 + + mlvl_masks = [] + mlvl_positional_encodings = [] + for feat in mlvl_feats: + mlvl_masks.append( + F.interpolate( + img_masks[None], size=feat.shape[-2:]).squeeze(0)) + mlvl_positional_encodings.append(self.positional_encoding(paddle.logical_not(mlvl_masks[-1]).astype('float32')).transpose((0,3,1,2))) + + query_embeds = None + if not self.as_two_stage or self.mixed_selection: + query_embeds = self.query_embedding.weight + + ( + hs, + init_reference, + inter_references, + enc_outputs_class, + enc_outputs_coord, + enc_outputs, + ) = self.transformer( + mlvl_feats, + mlvl_masks, + query_embeds, + mlvl_positional_encodings, + reg_branches=( + self.reg_branches if self.with_box_refine else None + ), # noqa:E501 + cls_branches=self.cls_branches if self.as_two_stage else None, # noqa:E501 + return_encoder_output=True, + ) + + outs = [] + num_level = len(mlvl_feats) + start = 0 + enc_outputs = enc_outputs.transpose((1,0,2)) + for lvl in range(num_level): + bs, c, h, w = mlvl_feats[lvl].shape + end = start + h * w + feat = enc_outputs[start:end].transpose((1, 2, 0)) + start = end + outs.append(feat.reshape((bs, c, h, w))) + outs.append(self.downsample(outs[-1])) + + outputs_classes = [] + outputs_coords = [] + + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.cls_branches[lvl](hs[lvl]) + tmp = self.reg_branches[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = F.sigmoid(tmp) + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + outputs_classes = paddle.stack(outputs_classes) + outputs_coords = paddle.stack(outputs_coords) + if self.as_two_stage: + return ( + outputs_classes, + outputs_coords, + enc_outputs_class, + F.sigmoid(enc_outputs_coord), + outs, + ) + else: + return outputs_classes, outputs_coords, None, None, outs + + def forward_aux(self, mlvl_feats, img_metas, aux_targets, head_idx): + """Forward function. + + Args: + mlvl_feats (tuple[Tensor]): Features from the upstream + network, each is a 4D-tensor with shape + (N, C, H, W). + img_metas (list[dict]): List of image information. + + Returns: + all_cls_scores (Tensor): Outputs from the classification head, \ + shape [nb_dec, bs, num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + all_bbox_preds (Tensor): Sigmoid outputs from the regression \ + head with normalized coordinate format (cx, cy, w, h). \ + Shape [nb_dec, bs, num_query, 4]. + enc_outputs_class (Tensor): The score of each point on encode \ + feature map, has shape (N, h*w, num_class). Only when \ + as_two_stage is True it would be returned, otherwise \ + `None` would be returned. + enc_outputs_coord (Tensor): The proposal generate from the \ + encode feature map, has shape (N, h*w, 4). Only when \ + as_two_stage is True it would be returned, otherwise \ + `None` would be returned. 
+ """ + ( + aux_coords, + aux_labels, + aux_targets, + aux_label_weights, + aux_bbox_weights, + aux_feats, + attn_masks, + ) = aux_targets + batch_size = mlvl_feats[0].shape[0] + input_img_h, input_img_w = img_metas[0]["batch_input_shape"] + img_masks = paddle.zeros((batch_size, input_img_h, input_img_w),mlvl_feats[0].dtype) + for img_id in range(batch_size): + img_h, img_w, _ = img_metas[img_id]["img_shape"] + img_masks[img_id, :img_h, :img_w] = 0 + + mlvl_masks = [] + mlvl_positional_encodings = [] + for feat in mlvl_feats: + mlvl_masks.append( + F.interpolate(img_masks[None], size=feat.shape[-2:]) + .astype(paddle.bool) + .squeeze(0) + ) + mlvl_positional_encodings.append(self.positional_encoding(paddle.logical_not(mlvl_masks[-1]).astype('float32')).transpose((0,3,1,2))) + + query_embeds = None + hs, init_reference, inter_references = self.transformer.forward_aux( + mlvl_feats, + mlvl_masks, + query_embeds, + mlvl_positional_encodings, + aux_coords, + pos_feats=aux_feats, + reg_branches=( + self.reg_branches if self.with_box_refine else None + ), # noqa:E501 + cls_branches=self.cls_branches if self.as_two_stage else None, # noqa:E501 + return_encoder_output=True, + attn_masks=attn_masks, + head_idx=head_idx, + ) + if hs is None: + return None, None, None, None + outputs_classes = [] + outputs_coords = [] + + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.cls_branches[lvl](hs[lvl]) + tmp = self.reg_branches[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = F.sigmoid(tmp) + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + outputs_classes = paddle.stack(outputs_classes) + outputs_coords = paddle.stack(outputs_coords) + + return outputs_classes, outputs_coords, None, None + + def loss_single_aux( + self, + cls_scores, + bbox_preds, + labels, + label_weights, + bbox_targets, + bbox_weights, + img_metas, + gt_bboxes_ignore_list=None, + ): + """ "Loss function for outputs from a single decoder layer of a single + feature level. + + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images. Shape [bs, num_query, cls_out_channels]. + bbox_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (cx, cy, w, h) and + shape [bs, num_query, 4]. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + + Returns: + dict[str, Tensor]: A dictionary of loss components for outputs from + a single decoder layer. 
+ """ + num_imgs = cls_scores.shape[0] + num_q = cls_scores.shape[1] + try: + labels = labels.reshape((num_imgs * num_q)) + label_weights = label_weights.reshape((num_imgs * num_q)) + bbox_targets = bbox_targets.reshape((num_imgs * num_q, 4)) + bbox_weights = bbox_weights.reshape((num_imgs * num_q, 4)) + except: + return cls_scores.mean() * 0, cls_scores.mean() * 0, cls_scores.mean() * 0 + + bg_class_ind = self.num_classes + num_total_pos = len( + ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1) + ) + num_total_neg = num_imgs * num_q - num_total_pos + + # classification loss + cls_scores = cls_scores.reshape((-1, self.cls_out_channels)) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + paddle.to_tensor([cls_avg_factor], dtype=cls_scores.dtype) + ) + cls_avg_factor = max(cls_avg_factor, 1) + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor + ) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(img_metas, bbox_preds): + img_h, img_w, _ = img_meta["img_shape"] + factor = ( + paddle.to_tensor([img_w, img_h, img_w, img_h], dtype=bbox_pred.dtype) + .unsqueeze(0) + .tile((bbox_pred.shape[0], 1)) + ) + factors.append(factor) + factors = paddle.concat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. 
So here + # we need to re-scale them for calculating IoU loss + bbox_preds = bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos + ) + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos + ) + return ( + loss_cls * self.lambda_1, + loss_bbox * self.lambda_1, + loss_iou * self.lambda_1, + ) + + def get_aux_targets(self, pos_coords, img_metas, mlvl_feats, head_idx): + coords, labels, targets = pos_coords[:3] + head_name = pos_coords[-1] + bs, c = len(coords), mlvl_feats[0].shape[1] + max_num_coords = 0 + all_feats = [] + for i in range(bs): + label = labels[i] + feats = [feat[i].reshape((c, -1)).transpose((1, 0)) for feat in mlvl_feats] + feats = paddle.concat(feats, axis=0) + bg_class_ind = self.num_classes + pos_inds = paddle.logical_and((label >= 0), + (label < bg_class_ind)).nonzero().squeeze(1) + max_num_coords = max(max_num_coords, len(pos_inds)) + all_feats.append(feats) + max_num_coords = min(self.max_pos_coords, max_num_coords) + max_num_coords = max(9, max_num_coords) + + if self.use_zero_padding: + attn_masks = [] + label_weights = paddle.zeros([bs, max_num_coords], coords[0].dtype) + else: + attn_masks = None + label_weights = paddle.zeros([bs, max_num_coords], coords[0].dtype) + bbox_weights = paddle.zeros([bs, max_num_coords, 4], coords[0].dtype) + + aux_coords, aux_labels, aux_targets, aux_feats = [], [], [], [] + for i in range(bs): + coord, label, target = coords[i], labels[i], targets[i] + feats = all_feats[i] + if "rcnn" in head_name: + feats = pos_coords[-2][i] + num_coords_per_point = 1 + else: + num_coords_per_point = coord.shape[0] // feats.shape[0] + feats = feats.unsqueeze(1).tile((1, num_coords_per_point, 1)) + feats = feats.reshape( + (feats.shape[0] * num_coords_per_point, feats.shape[-1]) + ) + img_meta = img_metas[i] + img_h, img_w, _ = img_meta["img_shape"] + factor = ( + paddle.to_tensor([img_w, img_h, img_w, img_h], dtype="float32") + .unsqueeze(0) + # .tile((self.num_query, 1)) + ) + bg_class_ind = self.num_classes + pos_inds = paddle.logical_and((label >= 0), + (label < bg_class_ind)).nonzero().squeeze(1) + neg_inds = ((label == bg_class_ind)).nonzero().squeeze(1) + if pos_inds.shape[0] > max_num_coords: + indices = paddle.randperm(pos_inds.shape[0])[:max_num_coords] + pos_inds = pos_inds[indices] + + if pos_inds.shape[0] == 0: + return None, None,None,None,None, None,None + + coord = bbox_xyxy_to_cxcywh(coord[pos_inds] / factor) + label = label[pos_inds] + target = bbox_xyxy_to_cxcywh(target[pos_inds] / factor) + feat = feats[pos_inds] + + if self.use_zero_padding: + label_weights[i][: len(label)] = 1 + bbox_weights[i][: len(label)] = 1 + attn_mask = paddle.zeros( + [ + max_num_coords, + max_num_coords, + ] + ).astype(paddle.bool()) + else: + bbox_weights[i][: len(label)] = 1 + + if coord.shape[0] < max_num_coords: + padding_shape = max_num_coords - coord.shape[0] + if self.use_zero_padding: + padding_coord = paddle.zeros([padding_shape, 4]) + padding_label = paddle.zeros([padding_shape]) * self.num_classes + padding_target = paddle.zeros([padding_shape, 4]) + padding_feat = paddle.zeros([padding_shape, c]) + attn_mask[ + coord.shape[0] :, + 0 : coord.shape[0], + ] = True + attn_mask[ + :, + coord.shape[0] :, + ] = True + else: + indices = 
paddle.randperm(neg_inds.shape[0])[:padding_shape] + neg_inds = neg_inds[indices] + padding_coord = bbox_xyxy_to_cxcywh(coords[i][neg_inds] / factor) + padding_label = labels[i][neg_inds] + padding_target = bbox_xyxy_to_cxcywh(targets[i][neg_inds] / factor) + padding_feat = feats[neg_inds] + coord = paddle.concat((coord, padding_coord), axis=0) + label = paddle.concat((label, padding_label), axis=0) + target = paddle.concat((target, padding_target), axis=0) + feat = paddle.concat((feat, padding_feat), axis=0) + if self.use_zero_padding: + attn_masks.append(attn_mask.unsqueeze(0)) + aux_coords.append(coord.unsqueeze(0)) + aux_labels.append(label.unsqueeze(0)) + aux_targets.append(target.unsqueeze(0)) + aux_feats.append(feat.unsqueeze(0)) + + if self.use_zero_padding: + attn_masks = ( + paddle.concat(attn_masks, axis=0).unsqueeze(1).tile((1, 8, 1, 1)) + ) + attn_masks = attn_masks.reshape((bs * 8, max_num_coords, max_num_coords)) + else: + attn_mask = None + + aux_coords = paddle.concat(aux_coords, axis=0) + aux_labels = paddle.concat(aux_labels, axis=0) + aux_targets = paddle.concat(aux_targets, axis=0) + aux_feats = paddle.concat(aux_feats, axis=0) + aux_label_weights = label_weights + aux_bbox_weights = bbox_weights + return ( + aux_coords, + aux_labels, + aux_targets, + aux_label_weights, + aux_bbox_weights, + aux_feats, + attn_masks, + ) + + # over-write because img_metas are needed as inputs for bbox_head. + def forward_train_aux( + self, + x, + img_metas, + gt_bboxes, + gt_labels=None, + gt_bboxes_ignore=None, + pos_coords=None, + head_idx=0, + **kwargs, + ): + """Forward function for training mode. + + Args: + x (list[Tensor]): Features from backbone. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + proposal_cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + aux_targets = self.get_aux_targets(pos_coords, img_metas, x, head_idx) + if aux_targets[0] is None: + return None + + outs = self.forward_aux(x[:-1], img_metas, aux_targets, head_idx) + outs = outs + aux_targets + if gt_labels is None: + loss_inputs = outs + (gt_bboxes, img_metas) + else: + loss_inputs = outs + (gt_bboxes, gt_labels, img_metas) + losses = self.loss_aux(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + return losses + + def loss_aux( + self, + all_cls_scores, + all_bbox_preds, + enc_cls_scores, + enc_bbox_preds, + aux_coords, + aux_labels, + aux_targets, + aux_label_weights, + aux_bbox_weights, + aux_feats, + attn_masks, + gt_bboxes_list, + gt_labels_list, + img_metas, + gt_bboxes_ignore=None, + ): + """ "Loss function. + + Args: + all_cls_scores (Tensor): Classification score of all + decoder layers, has shape + [nb_dec, bs, num_query, cls_out_channels]. + all_bbox_preds (Tensor): Sigmoid regression + outputs of all decode layers. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and shape + [nb_dec, bs, num_query, 4]. + enc_cls_scores (Tensor): Classification scores of + points on encode feature map , has shape + (N, h*w, num_classes). Only be passed when as_two_stage is + True, otherwise is None. 
+ enc_bbox_preds (Tensor): Regression results of each points + on the encode feature map, has shape (N, h*w, 4). Only be + passed when as_two_stage is True, otherwise is None. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + gt_bboxes_ignore (list[Tensor], optional): Bounding boxes + which can be ignored for each image. Default None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_dec_layers = len(all_cls_scores) + all_labels = [aux_labels for _ in range(num_dec_layers)] + all_label_weights = [aux_label_weights for _ in range(num_dec_layers)] + all_bbox_targets = [aux_targets for _ in range(num_dec_layers)] + all_bbox_weights = [aux_bbox_weights for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)] + + losses_cls, losses_bbox, losses_iou = multi_apply( + self.loss_single_aux, + all_cls_scores, + all_bbox_preds, + all_labels, + all_label_weights, + all_bbox_targets, + all_bbox_weights, + img_metas_list, + all_gt_bboxes_ignore_list, + ) + + loss_dict = dict() + # loss of proposal generated from encode feature map. + # loss from the last decoder layer + loss_dict["loss_cls_aux"] = losses_cls[-1] + loss_dict["loss_bbox_aux"] = losses_bbox[-1] + loss_dict["loss_iou_aux"] = losses_iou[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_bbox_i, loss_iou_i in zip( + losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1] + ): + loss_dict[f"d{num_dec_layer}.loss_cls_aux"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.loss_bbox_aux"] = loss_bbox_i + loss_dict[f"d{num_dec_layer}.loss_iou_aux"] = loss_iou_i + num_dec_layer += 1 + return loss_dict + + # over-write because img_metas are needed as inputs for bbox_head. + def forward_train( + self, + x, + img_metas, + gt_bboxes, + gt_labels=None, + gt_bboxes_ignore=None, + proposal_cfg=None, + **kwargs, + ): + """Forward function for training mode. + + Args: + x (list[Tensor]): Features from backbone. + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + proposal_cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert proposal_cfg is None, '"proposal_cfg" must be None' + outs = self(x, img_metas) + if gt_labels is None: + loss_inputs = outs + (gt_bboxes, img_metas) + else: + loss_inputs = outs + (gt_bboxes, gt_labels, img_metas) + losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + enc_outputs = outs[-1] + return losses, enc_outputs + + def loss( + self, + all_cls_scores, + all_bbox_preds, + enc_cls_scores, + enc_bbox_preds, + enc_outputs, + gt_bboxes_list, + gt_labels_list, + img_metas, + gt_bboxes_ignore=None, + ): + """ "Loss function. + + Args: + all_cls_scores (Tensor): Classification score of all + decoder layers, has shape + [nb_dec, bs, num_query, cls_out_channels]. 
+ all_bbox_preds (Tensor): Sigmoid regression + outputs of all decode layers. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and shape + [nb_dec, bs, num_query, 4]. + enc_cls_scores (Tensor): Classification scores of + points on encode feature map , has shape + (N, h*w, num_classes). Only be passed when as_two_stage is + True, otherwise is None. + enc_bbox_preds (Tensor): Regression results of each points + on the encode feature map, has shape (N, h*w, 4). Only be + passed when as_two_stage is True, otherwise is None. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + gt_bboxes_ignore (list[Tensor], optional): Bounding boxes + which can be ignored for each image. Default None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + num_dec_layers = len(all_cls_scores) + all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)] + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + + losses_cls, losses_bbox, losses_iou = multi_apply( + self.loss_single, + all_cls_scores, + all_bbox_preds, + all_gt_bboxes_list, + all_gt_labels_list, + img_metas_list, + all_gt_bboxes_ignore_list, + ) + + loss_dict = dict() + # loss of proposal generated from encode feature map. + if enc_cls_scores is not None: + binary_labels_list = [ + paddle.zeros_like(gt_labels_list[i]) for i in range(len(img_metas)) + ] + enc_loss_cls, enc_losses_bbox, enc_losses_iou = self.loss_single( + enc_cls_scores, + enc_bbox_preds, + gt_bboxes_list, + binary_labels_list, + img_metas, + gt_bboxes_ignore, + ) + loss_dict["enc_loss_cls"] = enc_loss_cls + loss_dict["enc_loss_bbox"] = enc_losses_bbox + loss_dict["enc_loss_iou"] = enc_losses_iou + + # loss from the last decoder layer + loss_dict["loss_cls"] = losses_cls[-1] + loss_dict["loss_bbox"] = losses_bbox[-1] + loss_dict["loss_iou"] = losses_iou[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_bbox_i, loss_iou_i in zip( + losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1] + ): + loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.loss_bbox"] = loss_bbox_i + loss_dict[f"d{num_dec_layer}.loss_iou"] = loss_iou_i + num_dec_layer += 1 + return loss_dict + + def get_bboxes( + self, + all_cls_scores, + all_bbox_preds, + enc_cls_scores, + enc_bbox_preds, + enc_outputs, + img_metas, + rescale=False, + ): + """Transform network outputs for a batch into bbox predictions. + + Args: + all_cls_scores (Tensor): Classification score of all + decoder layers, has shape + [nb_dec, bs, num_query, cls_out_channels]. + all_bbox_preds (Tensor): Sigmoid regression + outputs of all decode layers. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and shape + [nb_dec, bs, num_query, 4]. + enc_cls_scores (Tensor): Classification scores of + points on encode feature map , has shape + (N, h*w, num_classes). Only be passed when as_two_stage is + True, otherwise is None. + enc_bbox_preds (Tensor): Regression results of each points + on the encode feature map, has shape (N, h*w, 4). Only be + passed when as_two_stage is True, otherwise is None. 
+ img_metas (list[dict]): Meta information of each image. + rescale (bool, optional): If True, return boxes in original + image space. Default False. + + Returns: + list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \ + The first item is an (n, 5) tensor, where the first 4 columns \ + are bounding box positions (tl_x, tl_y, br_x, br_y) and the \ + 5-th column is a score between 0 and 1. The second item is a \ + (n,) tensor where each item is the predicted class label of \ + the corresponding box. + """ + cls_scores = all_cls_scores[-1] + bbox_preds = all_bbox_preds[-1] + result_list = [] + for img_id in range(len(img_metas)): + cls_score = cls_scores[img_id] + bbox_pred = bbox_preds[img_id] + img_shape = img_metas[img_id]["img_shape"] + scale_factor = img_metas[img_id]["scale_factor"] + proposals = self._get_bboxes_single( + cls_score, bbox_pred, img_shape, scale_factor, rescale + ) + result_list.append(proposals) + + return result_list + + def _get_bboxes_single(self, + cls_score, + bbox_pred, + img_shape, + scale_factor, + rescale=False, + ): + """Transform outputs from the last decoder layer into bbox predictions + for each image. + + Args: + cls_score (Tensor): Box score logits from the last decoder layer + for each image. Shape [num_query, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from the last decoder layer + for each image, with coordinate format (cx, cy, w, h) and + shape [num_query, 4]. + img_shape (tuple[int]): Shape of input image, (height, width, 3). + scale_factor (ndarray, optional): Scale factor of the image arange + as (w_scale, h_scale, w_scale, h_scale). + rescale (bool, optional): If True, return boxes in original image + space. Default False. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. + + - det_bboxes: Predicted bboxes with shape [num_query, 5], \ + where the first 4 columns are bounding box positions \ + (tl_x, tl_y, br_x, br_y) and the 5-th column are scores \ + between 0 and 1. + - det_labels: Predicted labels of the corresponding box with \ + shape [num_query]. 
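+
+        Example (sketch of the value actually returned by this implementation;
+        max_per_img and score_thr are read from test_cfg)::
+
+            >>> proposals = head._get_bboxes_single(cls_score, bbox_pred,
+            ...                                     img_shape, scale_factor)
+            >>> # proposals: [k, 6] with k <= max_per_img and columns
+            >>> #   (class label, score, x1, y1, x2, y2); boxes are in input-image
+            >>> #   pixels, or in the original image space when rescale=True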
+ """ + assert len(cls_score) == len(bbox_pred) + max_per_img = self.test_cfg.get('max_per_img', self.num_query) + score_thr = self.test_cfg.get('score_thr', 0) + + # exclude background + if self.loss_cls.use_sigmoid: + cls_score = F.sigmoid(cls_score) + scores, indexes = cls_score.reshape([-1]).topk(max_per_img) + det_labels = indexes % self.num_classes + bbox_index = indexes // self.num_classes + bbox_pred = bbox_pred[bbox_index] + else: + scores, det_labels = F.softmax(cls_score, axis=-1)[..., :-1].max(-1) + scores, bbox_index = scores.topk(max_per_img) + bbox_pred = bbox_pred[bbox_index] + det_labels = det_labels[bbox_index] + + valid_mask = scores > score_thr + scores = scores[valid_mask] + bbox_pred = bbox_pred[valid_mask] + det_labels = det_labels[valid_mask] + + det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred) + det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] + det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0] + det_bboxes[:, 0::2].clip(min=0, max=img_shape[1]) + det_bboxes[:, 1::2].clip(min=0, max=img_shape[0]) + + if rescale: + det_bboxes /=paddle.concat([scale_factor[::-1], scale_factor[::-1]]) + det_bboxes = paddle.concat((scores.unsqueeze(1),det_bboxes.astype('float32')), -1) + proposals = paddle.concat((det_labels.unsqueeze(1).astype('float32'),det_bboxes),-1) + return proposals + + def loss_single( + self, + cls_scores, + bbox_preds, + gt_bboxes_list, + gt_labels_list, + img_metas, + gt_bboxes_ignore_list=None, + ): + """ "Loss function for outputs from a single decoder layer of a single + feature level. + + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images. Shape [bs, num_query, cls_out_channels]. + bbox_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (cx, cy, w, h) and + shape [bs, num_query, 4]. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + + Returns: + dict[str, Tensor]: A dictionary of loss components for outputs from + a single decoder layer. 
+ """ + num_imgs = cls_scores.shape[0] + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + cls_reg_targets = self.get_targets( + cls_scores_list, + bbox_preds_list, + gt_bboxes_list, + gt_labels_list, + img_metas, + gt_bboxes_ignore_list, + ) + ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + num_total_pos, + num_total_neg, + ) = cls_reg_targets + labels = paddle.concat(labels_list, 0) + label_weights = paddle.concat(label_weights_list, 0) + bbox_targets = paddle.concat(bbox_targets_list, 0) + bbox_weights = paddle.concat(bbox_weights_list, 0) + + # classification loss + cls_scores = cls_scores.reshape((-1, self.cls_out_channels)) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + paddle.to_tensor([cls_avg_factor], dtype=cls_scores.dtype) + ) + cls_avg_factor = max(cls_avg_factor, 1) + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor + ) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = paddle.to_tensor([num_total_pos], dtype=loss_cls.dtype) + num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(img_metas, bbox_preds): + img_h, img_w, _ = img_meta["img_shape"] + factor = ( + paddle.to_tensor([img_w, img_h, img_w, img_h], dtype=bbox_pred.dtype) + .unsqueeze(0) + .tile((bbox_pred.shape[0], 1)) + ) + factors.append(factor) + factors = paddle.concat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = bbox_preds.reshape((-1, 4)) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights + ).mean() + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos + ) + + return loss_cls, loss_bbox, loss_iou + + def get_targets( + self, + cls_scores_list, + bbox_preds_list, + gt_bboxes_list, + gt_labels_list, + img_metas, + gt_bboxes_ignore_list=None, + ): + """"Compute regression and classification targets for a batch image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image with shape [num_query, + cls_out_channels]. + bbox_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_query, 4]. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels for all images. 
+ - label_weights_list (list[Tensor]): Label weights for all \ + images. + - bbox_targets_list (list[Tensor]): BBox targets for all \ + images. + - bbox_weights_list (list[Tensor]): BBox weights for all \ + images. + - num_total_pos (int): Number of positive samples in all \ + images. + - num_total_neg (int): Number of negative samples in all \ + images. + """ + num_imgs = len(cls_scores_list) + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)] + + ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + pos_inds_list, + neg_inds_list, + ) = multi_apply( + self._get_target_single, + cls_scores_list, + bbox_preds_list, + gt_bboxes_list, + gt_labels_list, + img_metas, + gt_bboxes_ignore_list, + ) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + num_total_pos, + num_total_neg, + ) + + def _get_target_single( + self, + cls_score, + bbox_pred, + gt_bboxes, + gt_labels, + img_meta, + gt_bboxes_ignore=None, + ): + """ "Compute regression and classification targets for one image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_query, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (cx, cy, w, h) and + shape [num_query, 4]. + gt_bboxes (Tensor): Ground truth bboxes for one image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (Tensor): Ground truth class indices for one image + with shape (num_gts, ). + img_meta (dict): Meta information for one image. + gt_bboxes_ignore (Tensor, optional): Bounding boxes + which can be ignored. Default None. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + num_bboxes = bbox_pred.shape[0] + # assigner and sampler + assign_result = self.assigner.assign( + bbox_pred, cls_score, gt_bboxes, gt_labels, img_meta, gt_bboxes_ignore + ) + sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + # label targets + labels = paddle.full((num_bboxes,), self.num_classes, dtype="int64") + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds][..., 0].astype("int64") + label_weights = paddle.ones((num_bboxes,), dtype=gt_bboxes.dtype) + # bbox targets + bbox_targets = paddle.zeros_like(bbox_pred) + bbox_weights = paddle.zeros_like(bbox_pred) + bbox_weights[pos_inds] = 1.0 + img_h, img_w, _ = img_meta["img_shape"] + + + # DETR regress the relative position of boxes (cxcywh) in the image. + # Thus the learning target should be normalized by the image size, also + # the box format should be converted from defaultly x1y1x2y2 to cxcywh. 
+ factor = paddle.to_tensor( + [img_w, img_h, img_w, img_h], dtype=bbox_pred.dtype + ).unsqueeze(0) + pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor + pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized) + bbox_targets[pos_inds] = pos_gt_bboxes_targets + + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds) + + def simple_test(self, feats, img_metas, rescale=False): + """Test det bboxes without test-time augmentation. + + Args: + feats (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is ``bboxes`` with shape (n, 5), + where 5 represent (tl_x, tl_y, br_x, br_y, score). + The shape of the second tensor in the tuple is ``labels`` + with shape (n,) + """ + # forward of this head requires img_metas + outs = self.forward(feats, img_metas) + result_list = self.get_bboxes(*outs, img_metas, rescale=rescale) + return result_list + \ No newline at end of file diff --git a/ppdet/modeling/heads/co_roi_head.py b/ppdet/modeling/heads/co_roi_head.py new file mode 100644 index 0000000000..73c6059f64 --- /dev/null +++ b/ppdet/modeling/heads/co_roi_head.py @@ -0,0 +1,129 @@ + +import paddle +import paddle.nn.functional as F + +from ppdet.core.workspace import register +from ppdet.modeling.heads.bbox_head import BBoxHead +from .roi_extractor import RoIAlign +from ..cls_utils import _get_class_default_kwargs + +__all__ = ['Co_RoiHead'] + +@register +class Co_RoiHead(BBoxHead): + __shared__ = ['num_classes', 'use_cot'] + __inject__ = ['bbox_assigner', 'bbox_loss', 'loss_cot'] + """ + RCNN bbox head + + Args: + head (nn.Layer): Extract feature in bbox head + in_channel (int): Input channel after RoI extractor + roi_extractor (object): The module of RoI Extractor + bbox_assigner (object): The module of Box Assigner, label and sample the + box. + with_pool (bool): Whether to use pooling for the RoI feature. 
+ num_classes (int): The number of classes + bbox_weight (List[float]): The weight to get the decode box + cot_classes (int): The number of base classes + loss_cot (object): The module of Label-cotuning + use_cot(bool): whether to use Label-cotuning + """ + + def __init__(self, + head, + in_channel, + roi_extractor=_get_class_default_kwargs(RoIAlign), + bbox_assigner='BboxAssigner', + with_pool=False, + num_classes=80, + bbox_weight=[10., 10., 5., 5.], + bbox_loss=None, + loss_normalize_pos=False, + cot_classes=None, + loss_cot='COTLoss', + use_cot=False): + super(Co_RoiHead, self).__init__( + head=head, + in_channel=in_channel, + roi_extractor=roi_extractor, + bbox_assigner=bbox_assigner, + with_pool=with_pool, + num_classes=num_classes, + bbox_weight=bbox_weight, + bbox_loss =bbox_loss, + loss_normalize_pos=loss_normalize_pos, + cot_classes=cot_classes, + loss_cot=loss_cot, + use_cot=use_cot + ) + self.head=head + + def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None, cot=False): + """ + body_feats (list[Tensor]): Feature maps from backbone + rois (list[Tensor]): RoIs generated from RPN module + rois_num (Tensor): The number of RoIs in each image + inputs (dict{Tensor}): The ground-truth of image + """ + if self.training: + rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs) + self.assigned_rois = (rois, rois_num) + self.assigned_targets = targets + + rois_feat = self.roi_extractor(body_feats, rois, rois_num) + bbox_feat = self.head(rois_feat) + if self.with_pool: + feat = F.adaptive_avg_pool2d(bbox_feat, output_size=1) + feat = paddle.squeeze(feat, axis=[2, 3]) + else: + feat = bbox_feat + if self.use_cot: + scores = self.cot_bbox_score(feat) + cot_scores = self.bbox_score(feat) + else: + scores = self.bbox_score(feat) + deltas = self.bbox_delta(feat) + + if self.training: + loss = self.get_loss( + scores, + deltas, + targets, + rois, + self.bbox_weight, + loss_normalize_pos=self.loss_normalize_pos) + + if self.cot_relation is not None: + loss_cot = self.loss_cot(cot_scores, targets, self.cot_relation) + loss.update(loss_cot) + + target_labels,target_bboxs,_ = targets + max_proposal = target_labels[0].shape[0] + # get pos_coords + ori_proposals, ori_labels, ori_bbox_targets, ori_bbox_feats = [], [], [], [] + for i in range(len(rois)): + ori_proposal = rois[i].unsqueeze(0) + ori_label = target_labels[i].unsqueeze(0) + ori_bbox_target = target_bboxs[i].unsqueeze(0) + + ori_bbox_feat = rois_feat[i*max_proposal:(i+1)*max_proposal].mean(-1).mean(-1) + ori_bbox_feat = ori_bbox_feat.unsqueeze(0) + ori_proposals.append(ori_proposal) + ori_labels.append(ori_label) + ori_bbox_targets.append(ori_bbox_target) + ori_bbox_feats.append(ori_bbox_feat) + + ori_coords = paddle.concat(ori_proposals, axis=0) + ori_labels = paddle.concat(ori_labels, axis=0) + ori_bbox_targets = paddle.concat(ori_bbox_targets, axis=0) + ori_bbox_feats = paddle.concat(ori_bbox_feats, axis=0) + pos_coords = (ori_coords, ori_labels, ori_bbox_targets, ori_bbox_feats, 'rcnn') + loss.update(pos_coords=pos_coords) + return loss, bbox_feat + else: + if cot: + pred = self.get_prediction(cot_scores, deltas) + else: + pred = self.get_prediction(scores, deltas) + return pred, self.head diff --git a/ppdet/modeling/heads/petr_head.py b/ppdet/modeling/heads/petr_head.py index 90760c6651..5888d77a5e 100644 --- a/ppdet/modeling/heads/petr_head.py +++ b/ppdet/modeling/heads/petr_head.py @@ -186,7 +186,7 @@ def __init__(self, loss_oks='OKSLoss', loss_hm='CenterFocalLoss', with_kpt_refine=True, - 
assigner='PoseHungarianAssigner', + assigner='HungarianAssigner', sampler='PseudoSampler', loss_kpt_rpn='L1Loss', loss_kpt_refine='L1Loss', diff --git a/ppdet/modeling/proposal_generator/anchor_generator.py b/ppdet/modeling/proposal_generator/anchor_generator.py index d189f784a2..f1ae6ff9e7 100644 --- a/ppdet/modeling/proposal_generator/anchor_generator.py +++ b/ppdet/modeling/proposal_generator/anchor_generator.py @@ -23,7 +23,7 @@ from ppdet.core.workspace import register -__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator', 'S2ANetAnchorGenerator'] +__all__ = ['AnchorGenerator', 'RetinaAnchorGenerator', 'S2ANetAnchorGenerator','CoAnchorGenerator'] @register @@ -51,15 +51,17 @@ def __init__(self, aspect_ratios=[0.5, 1.0, 2.0], strides=[16.0], variance=[1.0, 1.0, 1.0, 1.0], - offset=0.): + offset=0., + ): super(AnchorGenerator, self).__init__() self.anchor_sizes = anchor_sizes self.aspect_ratios = aspect_ratios self.strides = strides + self.num_levels = len(self.strides) self.variance = variance self.cell_anchors = self._calculate_anchors(len(strides)) self.offset = offset - + def _broadcast_params(self, params, num_features): if not isinstance(params[0], (list, tuple)): # list[float] return [params] * num_features @@ -121,6 +123,7 @@ def forward(self, input): anchors_over_all_feature_maps = self._grid_anchors(grid_sizes) return anchors_over_all_feature_maps + @property def num_anchors(self): """ @@ -155,7 +158,100 @@ def __init__(self, variance=variance, offset=offset) + +@register +class CoAnchorGenerator(AnchorGenerator): + def __init__(self, + octave_base_scale=4, + scales_per_octave=3, + aspect_ratios=[0.5, 1.0, 2.0], + strides=[8.0, 16.0, 32.0, 64.0, 128.0], + variance=[1.0, 1.0, 1.0, 1.0], + offset=0.0): + anchor_sizes = [] + for s in strides: + anchor_sizes.append([ + s * octave_base_scale * 2**(i/scales_per_octave) \ + for i in range(scales_per_octave)]) + super(CoAnchorGenerator, self).__init__( + anchor_sizes=anchor_sizes, + aspect_ratios=aspect_ratios, + strides=strides, + variance=variance, + offset=offset) + + def _meshgrid(self, x, y, row_major=True): + yy, xx = paddle.meshgrid(y, x) + yy = yy.reshape([-1]) + xx = xx.reshape([-1]) + if row_major: + return xx, yy + else: + return yy, xx + + def valid_flags(self, featmap_sizes, pad_shape): + """Generate valid flags of anchors in multiple feature levels. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in + multiple feature levels. + pad_shape (tuple): The padded shape of the image. + device (str): Device where the anchors will be put on. + Return: + list(torch.Tensor): Valid flags of anchors in multiple levels. + """ + featmap_sizes = [feature_map.shape[-2:] for feature_map in featmap_sizes] + num_base_anchors = [base_anchors.shape[0] for base_anchors in self.cell_anchors] + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + anchor_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h) + valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w) + flags = self.single_level_valid_flags((feat_h, feat_w), + (valid_feat_h, valid_feat_w), + num_base_anchors[i], + ) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, + featmap_size, + valid_size, + num_base_anchors, + ): + """Generate the valid flags of anchor in a single feature map. + + Args: + featmap_size (tuple[int]): The size of feature maps, arrange + as (h, w). 
+ valid_size (tuple[int]): The valid size of the feature maps. + num_base_anchors (int): The number of base anchors. + device (str, optional): Device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: The valid flags of each anchor in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = paddle.zeros(feat_w, dtype='int32') + valid_y = paddle.zeros(feat_h, dtype='int32') + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + valid = paddle.reshape(valid, [-1, 1]) + valid = paddle.expand(valid, [-1, num_base_anchors]).reshape([-1]) + + return valid + + @register class S2ANetAnchorGenerator(nn.Layer): """ @@ -202,15 +298,6 @@ def gen_base_anchors(self): base_anchors = paddle.round(base_anchors) return base_anchors - def _meshgrid(self, x, y, row_major=True): - yy, xx = paddle.meshgrid(y, x) - yy = yy.reshape([-1]) - xx = xx.reshape([-1]) - if row_major: - return xx, yy - else: - return yy, xx - def forward(self, featmap_size, stride=16): # featmap_size*stride project it to original area @@ -227,6 +314,14 @@ def forward(self, featmap_size, stride=16): all_anchors = self.rect2rbox(all_anchors) return all_anchors + def _meshgrid(self, x, y, row_major=True): + yy, xx = paddle.meshgrid(y, x) + yy = yy.reshape([-1]) + xx = xx.reshape([-1]) + if row_major: + return xx, yy + else: + return yy, xx def valid_flags(self, featmap_size, valid_size): feat_h, feat_w = featmap_size valid_h, valid_w = valid_size diff --git a/ppdet/modeling/transformers/__init__.py b/ppdet/modeling/transformers/__init__.py index 5eac4f110d..27148e09ab 100644 --- a/ppdet/modeling/transformers/__init__.py +++ b/ppdet/modeling/transformers/__init__.py @@ -23,6 +23,7 @@ from . import rtdetr_transformer from . import hybrid_encoder from . import mask_rtdetr_transformer +from . import co_deformable_detr_transformer from .detr_transformer import * from .utils import * @@ -36,3 +37,4 @@ from .rtdetr_transformer import * from .hybrid_encoder import * from .mask_rtdetr_transformer import * +from .co_deformable_detr_transformer import * diff --git a/ppdet/modeling/transformers/co_deformable_detr_transformer.py b/ppdet/modeling/transformers/co_deformable_detr_transformer.py new file mode 100644 index 0000000000..b9a546b0a6 --- /dev/null +++ b/ppdet/modeling/transformers/co_deformable_detr_transformer.py @@ -0,0 +1,639 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +this code is base on https://github.com/Sense-X/Co-DETR/blob/main/projects/models/transformer.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr + +from ppdet.core.workspace import register +from ..layers import MultiHeadAttention, _convert_attention_mask +from .utils import _get_clones +from ..initializer import linear_init_, normal_, constant_, xavier_uniform_ +from ..shape_spec import ShapeSpec + +from .petr_transformer import ( + PETR_TransformerDecoder, + MSDeformableAttention, + TransformerEncoder, + inverse_sigmoid, +) + +__all__ = [ + "CoDeformableDetrTransformerDecoder", + "CoDeformableDetrTransformer", + "CoTransformerEncoder", + +] + +@register +class CoTransformerEncoder(TransformerEncoder): + def __init__(self, encoder_layer, num_layers, norm=None,out_channel=256,spatial_scales=[1/8,1/16,1/32,1/64,1/128]): + super().__init__(encoder_layer, num_layers, norm) + self.out_channel=out_channel + self.spatial_scales=spatial_scales + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=self.out_channel, stride=1. / s) + for s in self.spatial_scales + ] + +@register +class CoDeformableDetrTransformerDecoder(PETR_TransformerDecoder): + __inject__ = ["decoder_layer"] + + def __init__( + self, + decoder_layer, + num_layers, + norm=None, + return_intermediate=False, + look_forward_twice=False, + **kwargs + ): + super().__init__(decoder_layer, num_layers, norm, return_intermediate, **kwargs) + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + self.look_forward_twice = look_forward_twice + + def forward( + self, + query, + *args, + reference_points=None, + valid_ratios=None, + reg_branches=None, + **kwargs + ): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape (num_query, bs, embed_dims). + reference_points (Tensor): The reference points of offset, + has shape (bs, num_query, K*2). + valid_ratios (Tensor): The radios of valid points on the feature + map, has shape (bs, num_levels, 2). + reg_branch: (obj:`nn.ModuleList`): Used for refining the regression results. + Only would be passed when with_box_refine is True,otherwise would be + passed a `None`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. 
+ """ + + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] + * paddle.concat([valid_ratios, valid_ratios], -1)[:, None] + ) + else: + assert reference_points.shape[-1] == 2 + reference_points_input = ( + reference_points[:, :, None] * valid_ratios[:, None] + ) + output = layer( + output, *args, reference_points=reference_points_input, **kwargs + ) + + if reg_branches is not None: + tmp = reg_branches[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = F.sigmoid(new_reference_points) + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid( + reference_points + ) + new_reference_points = F.sigmoid(new_reference_points) + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append( + new_reference_points + if self.look_forward_twice + else reference_points + ) + + if self.return_intermediate: + return paddle.stack(intermediate), paddle.stack( + intermediate_reference_points + ) + + return output, reference_points + + +@register +class CoDeformableDetrTransformer(nn.Layer): + __inject__ = ["encoder", "decoder"] + + def __init__( + self, + encoder="", + decoder="", + mixed_selection=True, + with_pos_coord=True, + with_coord_feat=True, + num_co_heads=1, + as_two_stage=False, + two_stage_num_proposals=300, + num_feature_levels=4, + **kwargs + ): + super(CoDeformableDetrTransformer, self).__init__(**kwargs) + + self.as_two_stage = as_two_stage + self.two_stage_num_proposals = two_stage_num_proposals + self.encoder = encoder + self.decoder = decoder + self.embed_dims = self.encoder.embed_dims + self.mixed_selection = mixed_selection + self.with_pos_coord = with_pos_coord + self.with_coord_feat = with_coord_feat + self.num_co_heads = num_co_heads + self.num_feature_levels = num_feature_levels + self.init_layers() + + def init_layers(self): + """Initialize layers of the DeformableDetrTransformer.""" + if self.with_pos_coord: + if self.num_co_heads > 0: + # bug: this code should be 'self.head_pos_embed = nn.Embedding(self.num_co_heads, self.embed_dims)', we keep this bug for reproducing our results with ResNet-50. + # You can fix this bug when reproducing results with swin transformer. 
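+                # Note: with the positional arguments kept below, Paddle's
+                # nn.Embedding is constructed with embedding_dim=1, so
+                # head_pos_embed.weight has shape [num_co_heads, 1] rather than
+                # [num_co_heads, embed_dims]; forward_aux later broadcasts
+                # head_pos_embed.weight[head_idx] onto query_pos, matching the
+                # released ResNet-50 weights mentioned above.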
+                self.head_pos_embed = nn.Embedding(
+                    self.num_co_heads, 1, 1, self.embed_dims
+                )
+                self.aux_pos_trans = nn.LayerList()
+                self.aux_pos_trans_norm = nn.LayerList()
+                self.pos_feats_trans = nn.LayerList()
+                self.pos_feats_norm = nn.LayerList()
+                for i in range(self.num_co_heads):
+                    self.aux_pos_trans.append(
+                        nn.Linear(self.embed_dims * 2, self.embed_dims * 2)
+                    )
+                    self.aux_pos_trans_norm.append(nn.LayerNorm(self.embed_dims * 2))
+                    if self.with_coord_feat:
+                        self.pos_feats_trans.append(
+                            nn.Linear(self.embed_dims, self.embed_dims)
+                        )
+                        self.pos_feats_norm.append(nn.LayerNorm(self.embed_dims))
+
+        self.level_embeds = paddle.create_parameter(
+            (self.num_feature_levels, self.embed_dims), dtype="float32"
+        )
+
+        if self.as_two_stage:
+            self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
+            self.enc_output_norm = nn.LayerNorm(self.embed_dims)
+            self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2)
+            self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
+        else:
+            self.reference_points = nn.Linear(self.embed_dims, 2)
+
+    def init_weights(self):
+        """Initialize the transformer weights."""
+        for p in self.parameters():
+            if p.rank() > 1:
+                xavier_uniform_(p)
+                if hasattr(p, "bias") and p.bias is not None:
+                    constant_(p.bias)
+        for m in self.sublayers():
+            if isinstance(m, MSDeformableAttention):
+                m._reset_parameters()
+        if not self.as_two_stage:
+            xavier_uniform_(self.reference_points.weight)
+            constant_(self.reference_points.bias)
+        normal_(self.level_embeds)
+
+    def get_valid_ratio(self, mask):
+        """Get the valid ratios of feature maps of all levels."""
+        _, H, W = mask.shape
+        valid_H = paddle.sum(paddle.logical_not(mask[:, :, 0]).astype("float"), 1)
+        valid_W = paddle.sum(paddle.logical_not(mask[:, 0, :]).astype("float"), 1)
+        valid_ratio_h = valid_H.astype("float") / H
+        valid_ratio_w = valid_W.astype("float") / W
+        valid_ratio = paddle.stack([valid_ratio_w, valid_ratio_h], -1)
+        return valid_ratio
+
+    def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000):
+        """Get the position embedding of proposal."""
+        scale = 2 * math.pi
+        dim_t = paddle.arange(num_pos_feats, dtype="float32")
+        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        # N, L, 4
+        proposals = F.sigmoid(proposals) * scale
+        # N, L, 4, 128
+        pos = proposals[:, :, :, None] / dim_t
+        # N, L, 4, 64, 2
+        pos = paddle.stack(
+            (pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), axis=4
+        ).flatten(2)
+        return pos
+
+    @staticmethod
+    def get_reference_points(spatial_shapes, valid_ratios):
+        """Get the reference points used in decoder.
+
+        Args:
+            spatial_shapes (Tensor): The shape of all feature maps,
+                has shape (num_level, 2).
+            valid_ratios (Tensor): The ratios of valid points on the
+                feature map, has shape (bs, num_levels, 2).
+
+        Returns:
+            Tensor: reference points used in decoder, has \
+                shape (bs, num_keys, num_levels, 2).
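+
+        Example (illustrative sketch, not part of the original code; assumes
+        two 2x2 feature levels and fully valid feature maps)::
+
+            spatial_shapes = paddle.to_tensor([[2, 2], [2, 2]], dtype="int64")
+            valid_ratios = paddle.ones([1, 2, 2])
+            ref = CoDeformableDetrTransformer.get_reference_points(
+                spatial_shapes, valid_ratios)
+            # ref.shape == [1, 8, 2, 2]  (bs, num_keys, num_levels, 2)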
+ """ + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = paddle.meshgrid( + paddle.linspace(0.5, H - 0.5, H, dtype="float32"), + paddle.linspace(0.5, W - 0.5, W, dtype="float32"), + ) + ref_y = ref_y.reshape((-1,))[None] / (valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape((-1,))[None] / (valid_ratios[:, None, lvl, 0] * W) + ref = paddle.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = paddle.concat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + """Generate proposals from encoded memory. + + Args: + memory (Tensor): The output of encoder, has shape + (bs, num_key, embed_dim). num_key is equal the number of points + on feature map from all level. + memory_padding_mask (Tensor): Padding mask for memory. + has shape (bs, num_key). + spatial_shapes (Tensor): The shape of all feature maps. + has shape (num_level, 2). + + Returns: + tuple: A tuple of feature map and bbox prediction. + + - output_memory (Tensor): The input of decoder, has shape + (bs, num_key, embed_dim). num_key is equal the number of + points on feature map from all levels. + - output_proposals (Tensor): The normalized proposal + after a inverse sigmoid, has shape (bs, num_keys, 4). + """ + + N, S, C = memory.shape + proposals = [] + _cur = 0 + + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].reshape( + [N, H, W, 1] + ) + + valid_H = paddle.sum(paddle.logical_not(mask_flatten_[:, :, 0, 0]).astype("float"), 1) + valid_W = paddle.sum(paddle.logical_not(mask_flatten_[:, 0, :, 0]).astype("float"), 1) + + grid_y, grid_x = paddle.meshgrid( + paddle.linspace(0, H - 1, H, dtype="float32"), + paddle.linspace(0, W - 1, W, dtype="float32"), + ) + grid = paddle.concat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = paddle.concat( + [valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1 + ).reshape([N, 1, 1, 2]) + grid = (grid.unsqueeze(0).expand((N, -1, -1, -1)) + 0.5) / scale + wh = paddle.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = paddle.concat((grid, wh), -1).reshape([N, -1, 4]) + proposals.append(proposal) + _cur += H * W + output_proposals = paddle.concat(proposals, 1) + output_proposals_valid = ( + ((output_proposals > 0.01) & (output_proposals < 0.99)) + .all(-1, keepdim=True) + ) + output_proposals = paddle.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill( + memory_padding_mask.unsqueeze(-1), + float("inf"), + ) + output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float("inf") + ) + + output_memory = memory + output_memory = output_memory.masked_fill( + memory_padding_mask.unsqueeze(-1), float(0) + ) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def forward( + self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + reg_branches=None, + cls_branches=None, + return_encoder_output=False, + attn_masks=None, + **kwargs + ): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from + different level. Each element has shape + [bs, embed_dims, h, w]. 
+ mlvl_masks (list(Tensor)): The key_padding_mask from + different level used for encoder and decoder, + each element has shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + reg_branches (obj:`nn.ModuleList`): Regression heads for + feature maps from each decoder layer. Only would + be passed when + `with_box_refine` is True. Default to None. + cls_branches (obj:`nn.ModuleList`): Classification heads + for feature maps from each decoder layer. Only would + be passed when `as_two_stage` + is True. Default to None. + + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + return_intermediate_dec is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of \ + proposals generated from \ + encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_coord_unact: The regression results \ + generated from encoder's feature maps., has shape \ + (batch, h*w, 4). Only would \ + be returned when `as_two_stage` is True, \ + otherwise None. + """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds) + ): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose((0, 2, 1)) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose((0, 2, 1)) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].reshape((1, 1, -1)) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + + feat_flatten = paddle.concat(feat_flatten, 1) + mask_flatten = paddle.concat(mask_flatten, 1) + lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1) + + spatial_shapes = paddle.to_tensor(spatial_shapes, dtype="int64") + # [l], 每一个level的起始index + level_start_index = paddle.concat( + [paddle.zeros([1], dtype="int64"), spatial_shapes.prod(1).cumsum(0)[:-1]] + ) + + valid_ratios = paddle.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) + reference_points = self.get_reference_points(spatial_shapes, valid_ratios) + + memory = self.encoder( + src=feat_flatten, + pos_embed=lvl_pos_embed_flatten, + src_mask=mask_flatten, + value_spatial_shapes=spatial_shapes, + reference_points=reference_points, + value_level_start_index=level_start_index, + valid_ratios=valid_ratios, + ) + + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes + ) + enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = ( + reg_branches[self.decoder.num_layers](output_memory) + output_proposals + ) + topk = self.two_stage_num_proposals + # We only use the first channel in enc_outputs_class as foreground, + # 
the other (num_classes - 1) channels are actually not used. + # Its targets are set to be 0s, which indicates the first + # class (foreground) because we use [0, num_classes - 1] to + # indicate class labels, background class is indicated by + # num_classes (similar convention in RPN). + topk_proposals = paddle.topk(enc_outputs_class[..., 0], topk, axis=1)[1] + # paddle.take_along_axis 对应torch.gather + topk_coords_unact = paddle.take_along_axis( + enc_outputs_coord_unact, topk_proposals.unsqueeze(-1).tile([1, 1, 4]),axis=1 + ) + topk_coords_unact = topk_coords_unact.detach() + reference_points = F.sigmoid(topk_coords_unact) + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm( + self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact.astype('float32'))) + ) + if not self.mixed_selection: + query_pos, query = paddle.split(pos_trans_out, pos_trans_out.shape[2]//c, axis=2) + else: + # query_embed here is the content embed for deformable DETR + query = query_embed.unsqueeze(0).expand([bs, -1, -1]) + query_pos, _ = paddle.split(pos_trans_out, pos_trans_out.shape[2]//c, axis=2) + else: + query_pos, query = paddle.split(query_embed, query_embed.shape[1]//c, axis=1) + query_pos = query_pos.unsqueeze(0).expand([bs, -1, -1]) + query = query.unsqueeze(0).expand([bs, -1, -1]) + reference_points = F.sigmoid(self.reference_points(query_pos)) + init_reference_out = reference_points + + # decoder + inter_states, inter_references = self.decoder( + query=query, + memory=memory, + query_pos_embed=query_pos, # error + memory_mask=mask_flatten, + reference_points=reference_points, # error + value_spatial_shapes=spatial_shapes, + value_level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + attn_masks=attn_masks, + **kwargs + ) + inter_references_out = inter_references + if self.as_two_stage: + if return_encoder_output: + return ( + inter_states, + init_reference_out, + inter_references_out, + enc_outputs_class, + enc_outputs_coord_unact, + memory, + ) + return ( + inter_states, + init_reference_out, + inter_references_out, + enc_outputs_class, + enc_outputs_coord_unact, + ) + if return_encoder_output: + return ( + inter_states, + init_reference_out, + inter_references_out, + None, + None, + memory, + ) + return inter_states, init_reference_out, inter_references_out, None, None + + def forward_aux( + self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + pos_anchors, + pos_feats=None, + reg_branches=None, + cls_branches=None, + return_encoder_output=False, + attn_masks=None, + head_idx=0, + **kwargs + ): + feat_flatten = [] + mask_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds) + ): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose((0,2,1)) + mask = mask.flatten(1) + feat_flatten.append(feat) + mask_flatten.append(mask) + + feat_flatten = paddle.concat(feat_flatten, 1) + mask_flatten = paddle.concat(mask_flatten, 1) + spatial_shapes = paddle.to_tensor(spatial_shapes,dtype=paddle.int64) + # [l], 每一个level的起始index + level_start_index = paddle.concat( + [paddle.zeros([1], dtype="int64"), spatial_shapes.prod(1).cumsum(0)[:-1]] + ) + valid_ratios = paddle.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) + + memory = feat_flatten + bs, _, c = memory.shape + topk = pos_anchors.shape[1] + topk_coords_unact = inverse_sigmoid((pos_anchors)) + reference_points = pos_anchors + 
init_reference_out = reference_points + if self.num_co_heads > 0: + pos_trans_out = self.aux_pos_trans_norm[head_idx]( + self.aux_pos_trans[head_idx]( + self.get_proposal_pos_embed(topk_coords_unact) + ) + ) + query_pos, query = paddle.split(pos_trans_out, pos_trans_out.shape[2]//c, axis=2) + if self.with_coord_feat: + query = query + self.pos_feats_norm[head_idx]( + self.pos_feats_trans[head_idx](pos_feats) + ) + query_pos = query_pos + self.head_pos_embed.weight[head_idx] + + # decoder + inter_states, inter_references = self.decoder( + query=query, + memory=memory, + query_pos_embed=query_pos, # error + memory_mask=mask_flatten, + reference_points=reference_points, # error + value_spatial_shapes=spatial_shapes, + value_level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + attn_masks=attn_masks, + **kwargs + ) + + inter_references_out = inter_references + return inter_states, init_reference_out, inter_references_out
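+
+    # Unlike forward(), forward_aux skips the encoder: the flattened
+    # multi-level features are used directly as decoder memory, and the
+    # decoder queries are derived from the positive anchors and features of
+    # the auxiliary heads (pos_anchors, pos_feats) rather than from learned
+    # query embeddings or two-stage proposals.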