From 61c5189ad689e3b5f06f433f390d2cee720d36de Mon Sep 17 00:00:00 2001 From: conditionedstimulus Date: Wed, 6 Nov 2024 16:18:02 +0100 Subject: [PATCH] moved back to the version when I asked the review --- .../models/dab_detr/modeling_dab_detr.py | 736 ++---------------- 1 file changed, 74 insertions(+), 662 deletions(-) diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index 2250ff116be..7819ad38f73 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -22,37 +22,20 @@ from torch import Tensor, nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_accelerate_available, - is_scipy_available, - is_vision_available, logging, replace_return_docstrings, - requires_backends, ) from ...utils.backbone_utils import load_backbone from .configuration_dab_detr import DabDetrConfig -if is_accelerate_available(): - from accelerate import PartialState - from accelerate.utils import reduce - -if is_scipy_available(): - from scipy.optimize import linear_sum_assignment - -if is_vision_available(): - from ...image_transforms import center_to_corners_format - - -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DabDetrConfig" @@ -346,11 +329,6 @@ def forward(self, pixel_values, pixel_mask): raise ValueError("No pixel mask provided") y_embed = pixel_mask.cumsum(1, dtype=torch.float32) x_embed = pixel_mask.cumsum(2, dtype=torch.float32) - # In place operations - # y_embed /= (y_embed[:, -1:, :] + 1e-6) - # y_embed *= self.scale - # x_embed /= (x_embed[:, :, -1:] + 1e-6) - # x_embed *= self.scale y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale @@ -432,33 +410,25 @@ class DetrAttention(nn.Module): def __init__( self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, + config: DabDetrConfig, bias: bool = True, ): super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - if self.head_dim * num_heads != self.embed_dim: + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.encoder_attention_heads + self.attention_dropout = config.attention_dropout + self.head_dim = self.hidden_size // self.num_heads + if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {num_heads})." + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." 
) self.scaling = self.head_dim**-0.5 - - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): - return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]): - return tensor if object_queries is None else tensor + object_queries + self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) + self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=bias) def forward( self, @@ -466,74 +436,46 @@ def forward( attention_mask: Optional[torch.Tensor] = None, object_queries: Optional[torch.Tensor] = None, key_value_states: Optional[torch.Tensor] = None, - spatial_position_embeddings: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - batch_size, target_len, embed_dim = hidden_states.size() + batch_size, q_len, embed_dim = hidden_states.size() # add position embeddings to the hidden states before projecting to queries and keys if object_queries is not None: hidden_states_original = hidden_states - hidden_states = self.with_pos_embed(hidden_states, object_queries) - - # add key-value position embeddings to the key value states - if spatial_position_embeddings is not None: - key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings) + hidden_states = hidden_states + object_queries - # get query proj query_states = self.q_proj(hidden_states) * self.scaling key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states_original) - # get key, value proj - query_states = ( - query_states.view(batch_size, target_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - ) - key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - proj_shape = (batch_size * self.num_heads, -1, self.head_dim) - query_states = query_states.view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - source_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) if attention_mask is not None: - attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask - attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a 
bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) - attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) - else: - attn_weights_reshaped = None + attn_weights = attn_weights + attention_mask - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.bmm(attn_probs, value_states) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) - attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, embed_dim) attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights # Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrAttention with ConditionalDetr->DABDETR,Conditional DETR->DabDetr @@ -551,7 +493,7 @@ def __init__(self, config: DabDetrConfig, bias: bool = True, is_cross: bool = Fa self.embed_dim = config.hidden_size * 2 if is_cross else config.hidden_size self.output_dim = config.hidden_size self.attention_heads = config.decoder_attention_heads - self.dropout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.attention_head_dim = self.embed_dim // self.attention_heads if self.attention_head_dim * self.attention_heads != self.embed_dim: raise ValueError( @@ -565,7 +507,6 @@ def __init__(self, config: DabDetrConfig, bias: bool = True, is_cross: bool = Fa f"output_dim must be divisible by attention_heads (got `output_dim`: {self.output_dim} and `attention_heads`: {self.attention_heads})." 
) self.scaling = self.attention_head_dim**-0.5 - self.output_proj = nn.Linear(self.output_dim, self.output_dim, bias=bias) def forward( @@ -578,77 +519,38 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - batch_size, target_len, _ = hidden_states.size() + batch_size, q_len, _ = hidden_states.size() # scaling query and refactor key-, value states query_states = hidden_states * self.scaling - key_states = ( - key_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2).contiguous() - ) - value_states = ( - value_states.view(batch_size, -1, self.attention_heads, self.values_head_dim).transpose(1, 2).contiguous() - ) - - # projection of query,key, value states - projected_shape = (batch_size * self.attention_heads, -1, self.attention_head_dim) - values_projected_shape = (batch_size * self.attention_heads, -1, self.values_head_dim) - query_states = ( - query_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim) - .transpose(1, 2) - .contiguous() - ) - query_states = query_states.view(*projected_shape) - key_states = key_states.view(*projected_shape) - value_states = value_states.view(*values_projected_shape) - - source_len = key_states.size(1) - - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + query_states = query_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, -1, self.attention_heads, self.attention_head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.attention_heads, self.values_head_dim).transpose(1, 2) - if attn_weights.size() != (batch_size * self.attention_heads, target_len, source_len): - raise ValueError( - f"Attention weights should be of size {(batch_size * self.attention_heads, target_len, source_len)}, but is" - f" {attn_weights.size()}" - ) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, target_len, source_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" - f" {attention_mask.size()}" - ) - attn_weights = attn_weights.view(batch_size, self.attention_heads, target_len, source_len) + attention_mask - attn_weights = attn_weights.view(batch_size * self.attention_heads, target_len, source_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. 
- # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(batch_size, self.attention_heads, target_len, source_len) - attn_weights = attn_weights_reshaped.view(batch_size * self.attention_heads, target_len, source_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_weights = attn_weights + attention_mask - attn_output = torch.bmm(attn_probs, value_states) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_probs = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_probs, value_states) - if attn_output.size() != (batch_size * self.attention_heads, target_len, self.values_head_dim): + if attn_output.size() != (batch_size, self.attention_heads, q_len, self.values_head_dim): raise ValueError( - f"`attn_output` should be of size {(batch_size, self.attention_heads, target_len, self.values_head_dim)}, but is" + f"`attn_output` should be of size {(batch_size, self.attention_heads, q_len, self.values_head_dim)}, but is" f" {attn_output.size()}" ) - attn_output = attn_output.view(batch_size, self.attention_heads, target_len, self.values_head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, target_len, self.output_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.output_dim) attn_output = self.output_proj(attn_output) - return attn_output, attn_weights_reshaped + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights class DabDetrDecoderLayerSelfAttention(nn.Module): @@ -808,18 +710,14 @@ def forward(self, hidden_states: torch.Tensor): class DabDetrEncoderLayer(nn.Module): def __init__(self, config: DabDetrConfig): super().__init__() - self.embed_dim = config.hidden_size - self.self_attn = DetrAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.hidden_size = config.hidden_size + self.self_attn = DetrAttention(config) + self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.hidden_size, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.hidden_size) + self.final_layer_norm = nn.LayerNorm(self.hidden_size) def forward( self, @@ -1018,7 +916,7 @@ def _init_weights(self, module): pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Padding will be ignored by default should you provide it. - Pixel values can be obtained using [`AutoImageProcessor`]. See [`DabDetrImageProcessor.__call__`] + Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] for details. 
pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): @@ -1188,19 +1086,20 @@ def __init__(self, config: DabDetrConfig): [DabDetrDecoderLayer(config, is_first=(layer_id == 0)) for layer_id in range(config.decoder_layers)] ) # in DAB-DETR, the decoder uses layernorm after the last decoder layer output - self.layernorm = nn.LayerNorm(config.hidden_size) - hidden_size = config.hidden_size + self.hidden_size = config.hidden_size + self.layernorm = nn.LayerNorm(self.hidden_size) # Default cond-elewise - self.query_scale = DabDetrMLP(hidden_size, hidden_size, hidden_size, 2) + self.query_scale = DabDetrMLP(self.hidden_size, self.hidden_size, self.hidden_size, 2) - self.ref_point_head = DabDetrMLP(config.query_dim // 2 * hidden_size, hidden_size, hidden_size, 2) + self.ref_point_head = DabDetrMLP( + config.query_dim // 2 * self.hidden_size, self.hidden_size, self.hidden_size, 2 + ) self.bbox_embed = None - self.hidden_size = hidden_size # Default decoder_modulate_hw_attn is True - self.ref_anchor_head = DabDetrMLP(hidden_size, hidden_size, 2, 2) + self.ref_anchor_head = DabDetrMLP(self.hidden_size, self.hidden_size, 2, 2) # Initialize weights and apply final processing self.post_init() @@ -1277,7 +1176,7 @@ def forward( pos_transformation = 1 if layer_id == 0 else self.query_scale(hidden_states) # apply transformation - query_sine_embed = query_sine_embed[..., : self.config.hidden_size] * pos_transformation + query_sine_embed = query_sine_embed[..., : self.hidden_size] * pos_transformation # modulated HW attentions refHW_cond = self.ref_anchor_head(hidden_states).sigmoid() # nq, bs, 2 @@ -1389,7 +1288,9 @@ def __init__(self, config: DabDetrConfig): self.query_refpoint_embeddings.weight.data[:, :2].requires_grad = False # Create projection layer - self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.hidden_size, kernel_size=1) + self.input_projection = nn.Conv2d( + self.backbone.intermediate_channel_sizes[-1], config.hidden_size, kernel_size=1 + ) self.backbone = DabDetrConvModel(self.backbone, object_queries) self.encoder = DabDetrEncoder(config) @@ -1404,7 +1305,7 @@ def __init__(self, config: DabDetrConfig): Warning("num_patterns should be int but {}".format(type(self.num_patterns))) self.num_patterns = 0 if self.num_patterns > 0: - self.patterns = nn.Embedding(self.num_patterns, config.hidden_size) + self.patterns = nn.Embedding(self.num_patterns, self.hidden_size) self.aux_loss = config.auxiliary_loss @@ -1630,427 +1531,6 @@ def forward(self, q, k, mask: Optional[Tensor] = None): return weights -# Copied from transformers.models.detr.modeling_detr.dice_loss -def dice_loss(inputs, targets, num_boxes): - """ - Compute the DICE loss, similar to generalized IOU for masks - - Args: - inputs: A float tensor of arbitrary shape. - The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs (0 for the negative class and 1 for the positive - class). - """ - inputs = inputs.sigmoid() - inputs = inputs.flatten(1) - numerator = 2 * (inputs * targets).sum(1) - denominator = inputs.sum(-1) + targets.sum(-1) - loss = 1 - (numerator + 1) / (denominator + 1) - return loss.sum() / num_boxes - - -# Copied from transformers.models.detr.modeling_detr.sigmoid_focal_loss -def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): - """ - Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 
- - Args: - inputs (`torch.FloatTensor` of arbitrary shape): - The predictions for each example. - targets (`torch.FloatTensor` with the same shape as `inputs`) - A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class - and 1 for the positive class). - alpha (`float`, *optional*, defaults to `0.25`): - Optional weighting factor in the range (0,1) to balance positive vs. negative examples. - gamma (`int`, *optional*, defaults to `2`): - Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. - - Returns: - Loss tensor - """ - prob = inputs.sigmoid() - ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none") - # add modulating factor - p_t = prob * targets + (1 - prob) * (1 - targets) - loss = ce_loss * ((1 - p_t) ** gamma) - - if alpha >= 0: - alpha_t = alpha * targets + (1 - alpha) * (1 - targets) - loss = alpha_t * loss - - return loss.mean(1).sum() / num_boxes - - -# Copied from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrLoss with ConditionalDetr->DabDetr -class DabDetrLoss(nn.Module): - """ - This class computes the losses for DabDetrForObjectDetection/DabDetrForSegmentation. The process - happens in two steps: 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 2) - we supervise each pair of matched ground-truth / prediction (supervise class and box). - - Args: - matcher (`DabDetrHungarianMatcher`): - Module able to compute a matching between targets and proposals. - num_classes (`int`): - Number of object categories, omitting the special no-object category. - focal_alpha (`float`): - Alpha parameter in focal loss. - losses (`List[str]`): - List of all the losses to be applied. See `get_loss` for a list of all available losses. 
- """ - - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.__init__ - def __init__(self, matcher, num_classes, focal_alpha, losses): - super().__init__() - self.matcher = matcher - self.num_classes = num_classes - self.focal_alpha = focal_alpha - self.losses = losses - - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_labels - def loss_labels(self, outputs, targets, indices, num_boxes): - """ - Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor - of dim [nb_target_boxes] - """ - if "logits" not in outputs: - raise KeyError("No logits were found in the outputs") - source_logits = outputs["logits"] - - idx = self._get_source_permutation_idx(indices) - target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)]) - target_classes = torch.full( - source_logits.shape[:2], self.num_classes, dtype=torch.int64, device=source_logits.device - ) - target_classes[idx] = target_classes_o - - target_classes_onehot = torch.zeros( - [source_logits.shape[0], source_logits.shape[1], source_logits.shape[2] + 1], - dtype=source_logits.dtype, - layout=source_logits.layout, - device=source_logits.device, - ) - target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) - - target_classes_onehot = target_classes_onehot[:, :, :-1] - loss_ce = ( - sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) - * source_logits.shape[1] - ) - losses = {"loss_ce": loss_ce} - - return losses - - @torch.no_grad() - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_cardinality - def loss_cardinality(self, outputs, targets, indices, num_boxes): - """ - Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. - - This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. - """ - logits = outputs["logits"] - device = logits.device - target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) - # Count the number of predictions that are NOT "no-object" (which is the last class) - card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1) - card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) - losses = {"cardinality_error": card_err} - return losses - - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss.loss_boxes - def loss_boxes(self, outputs, targets, indices, num_boxes): - """ - Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. - - Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes - are expected in format (center_x, center_y, w, h), normalized by the image size. 
- """ - if "pred_boxes" not in outputs: - raise KeyError("No predicted boxes found in outputs") - idx = self._get_source_permutation_idx(indices) - source_boxes = outputs["pred_boxes"][idx] - target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) - - loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none") - - losses = {} - losses["loss_bbox"] = loss_bbox.sum() / num_boxes - - loss_giou = 1 - torch.diag( - generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)) - ) - losses["loss_giou"] = loss_giou.sum() / num_boxes - return losses - - # Copied from transformers.models.detr.modeling_detr.DetrLoss.loss_masks - def loss_masks(self, outputs, targets, indices, num_boxes): - """ - Compute the losses related to the masks: the focal loss and the dice loss. - - Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]. - """ - if "pred_masks" not in outputs: - raise KeyError("No predicted masks found in outputs") - - source_idx = self._get_source_permutation_idx(indices) - target_idx = self._get_target_permutation_idx(indices) - source_masks = outputs["pred_masks"] - source_masks = source_masks[source_idx] - masks = [t["masks"] for t in targets] - # TODO use valid to mask invalid areas due to padding in loss - target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() - target_masks = target_masks.to(source_masks) - target_masks = target_masks[target_idx] - - # upsample predictions to the target size - source_masks = nn.functional.interpolate( - source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False - ) - source_masks = source_masks[:, 0].flatten(1) - - target_masks = target_masks.flatten(1) - target_masks = target_masks.view(source_masks.shape) - losses = { - "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes), - "loss_dice": dice_loss(source_masks, target_masks, num_boxes), - } - return losses - - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_source_permutation_idx - def _get_source_permutation_idx(self, indices): - # permute predictions following indices - batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)]) - source_idx = torch.cat([source for (source, _) in indices]) - return batch_idx, source_idx - - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrLoss._get_target_permutation_idx - def _get_target_permutation_idx(self, indices): - # permute targets following indices - batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)]) - target_idx = torch.cat([target for (_, target) in indices]) - return batch_idx, target_idx - - # Copied from transformers.models.detr.modeling_detr.DetrLoss.get_loss - def get_loss(self, loss, outputs, targets, indices, num_boxes): - loss_map = { - "labels": self.loss_labels, - "cardinality": self.loss_cardinality, - "boxes": self.loss_boxes, - "masks": self.loss_masks, - } - if loss not in loss_map: - raise ValueError(f"Loss {loss} not supported") - return loss_map[loss](outputs, targets, indices, num_boxes) - - # Copied from transformers.models.detr.modeling_detr.DetrLoss.forward - def forward(self, outputs, targets): - """ - This performs the loss computation. - - Args: - outputs (`dict`, *optional*): - Dictionary of tensors, see the output specification of the model for the format. 
- targets (`List[dict]`, *optional*): - List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the - losses applied, see each loss' doc. - """ - outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"} - - # Retrieve the matching between the outputs of the last layer and the targets - indices = self.matcher(outputs_without_aux, targets) - - # Compute the average number of target boxes across all nodes, for normalization purposes - num_boxes = sum(len(t["class_labels"]) for t in targets) - num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - - world_size = 1 - if is_accelerate_available(): - if PartialState._shared_state != {}: - num_boxes = reduce(num_boxes) - world_size = PartialState().num_processes - num_boxes = torch.clamp(num_boxes / world_size, min=1).item() - - # Compute all the requested losses - losses = {} - for loss in self.losses: - losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) - - # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. - if "auxiliary_outputs" in outputs: - for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): - indices = self.matcher(auxiliary_outputs, targets) - for loss in self.losses: - if loss == "masks": - # Intermediate masks losses are too costly to compute, we ignore them. - continue - l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes) - l_dict = {k + f"_{i}": v for k, v in l_dict.items()} - losses.update(l_dict) - - return losses - - -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->DabDetr -class DabDetrHungarianMatcher(nn.Module): - """ - This class computes an assignment between the targets and the predictions of the network. - - For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more - predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are - un-matched (and thus treated as non-objects). - - Args: - class_cost: - The relative weight of the classification error in the matching cost. - bbox_cost: - The relative weight of the L1 error of the bounding box coordinates in the matching cost. - giou_cost: - The relative weight of the giou loss of the bounding box in the matching cost. - """ - - def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1): - super().__init__() - requires_backends(self, ["scipy"]) - - self.class_cost = class_cost - self.bbox_cost = bbox_cost - self.giou_cost = giou_cost - if class_cost == 0 and bbox_cost == 0 and giou_cost == 0: - raise ValueError("All costs of the Matcher can't be 0") - - @torch.no_grad() - def forward(self, outputs, targets): - """ - Args: - outputs (`dict`): - A dictionary that contains at least these entries: - * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits - * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates. - targets (`List[dict]`): - A list of targets (len(targets) = batch_size), where each target is a dict containing: - * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of - ground-truth - objects in the target) containing the class labels - * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates. 
- - Returns: - `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where: - - index_i is the indices of the selected predictions (in order) - - index_j is the indices of the corresponding selected targets (in order) - For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes) - """ - batch_size, num_queries = outputs["logits"].shape[:2] - - # We flatten to compute the cost matrices in a batch - out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] - out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] - - # Also concat the target labels and boxes - target_ids = torch.cat([v["class_labels"] for v in targets]) - target_bbox = torch.cat([v["boxes"] for v in targets]) - - # Compute the classification cost. - alpha = 0.25 - gamma = 2.0 - neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) - pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) - class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids] - - # Compute the L1 cost between boxes - bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) - - # Compute the giou cost between boxes - giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) - - # Final cost matrix - cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost - cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() - - sizes = [len(v["boxes"]) for v in targets] - indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))] - return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] - - -# Copied from transformers.models.detr.modeling_detr._upcast -def _upcast(t: Tensor) -> Tensor: - # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if t.is_floating_point(): - return t if t.dtype in (torch.float32, torch.float64) else t.float() - else: - return t if t.dtype in (torch.int32, torch.int64) else t.int() - - -# Copied from transformers.models.detr.modeling_detr.box_area -def box_area(boxes: Tensor) -> Tensor: - """ - Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. - - Args: - boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): - Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 - < x2` and `0 <= y1 < y2`. - - Returns: - `torch.FloatTensor`: a tensor containing the area for each box. - """ - boxes = _upcast(boxes) - return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) - - -# Copied from transformers.models.detr.modeling_detr.box_iou -def box_iou(boxes1, boxes2): - area1 = box_area(boxes1) - area2 = box_area(boxes2) - - left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] - - width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] - inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] - - union = area1[:, None] + area2 - inter - - iou = inter / union - return iou, union - - -# Copied from transformers.models.detr.modeling_detr.generalized_box_iou -def generalized_box_iou(boxes1, boxes2): - """ - Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. 
- - Returns: - `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) - """ - # degenerate boxes gives inf / nan results - # so do an early check - if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): - raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") - if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): - raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") - iou, union = box_iou(boxes1, boxes2) - - top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - - width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] - area = width_height[:, :, 0] * width_height[:, :, 1] - - return iou - (area - union) / area - - -# Copied from transformers.models.detr.modeling_detr._max_by_axis -def _max_by_axis(the_list): - # type: (List[List[int]]) -> List[int] - maxes = the_list[0] - for sublist in the_list[1:]: - for index, item in enumerate(sublist): - maxes[index] = max(maxes[index], item) - return maxes - - @add_start_docstrings( """ DAB_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on @@ -2182,44 +1662,16 @@ def forward( tmp[..., : self.query_dim] += reference_before_sigmoid outputs_coord = tmp.sigmoid() - loss, loss_dict, auxiliary_outputs = None, None, None pred_boxes = outputs_coord[-1] + loss, loss_dict, auxiliary_outputs = None, None, None if labels is not None: - # First: create the matcher - matcher = DabDetrHungarianMatcher( - class_cost=self.config.class_cost, bbox_cost=self.config.bbox_cost, giou_cost=self.config.giou_cost - ) - # Second: create the criterion - losses = ["labels", "boxes", "cardinality"] - criterion = DabDetrLoss( - matcher=matcher, - num_classes=self.config.num_labels, - focal_alpha=self.config.focal_alpha, - losses=losses, - ) - criterion.to(self.device) - - # Third: compute the losses, based on outputs and labels - outputs_loss = {} - outputs_loss["logits"] = logits - outputs_loss["pred_boxes"] = pred_boxes - + outputs_class = None if self.config.auxiliary_loss: outputs_class = self.class_embed(intermediate_hidden_states) - auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord) - outputs_loss["auxiliary_outputs"] = auxiliary_outputs - - loss_dict = criterion(outputs_loss, labels) - # Fourth: compute total loss, as a weighted sum of the various losses - weight_dict = {"loss_ce": self.config.cls_loss_coefficient, "loss_bbox": self.config.bbox_loss_coefficient} - weight_dict["loss_giou"] = self.config.giou_loss_coefficient - if self.config.auxiliary_loss: - aux_weight_dict = {} - for i in range(self.config.decoder_layers - 1): - aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) - weight_dict.update(aux_weight_dict) - loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord + ) if not return_dict: if auxiliary_outputs is not None: @@ -2245,46 +1697,6 @@ def forward( ) -# Copied from transformers.models.detr.modeling_detr.NestedTensor -class NestedTensor: - def __init__(self, tensors, mask: Optional[Tensor]): - self.tensors = tensors - self.mask = mask - - def to(self, device): - cast_tensor = self.tensors.to(device) - mask = self.mask - if mask is not None: - cast_mask = mask.to(device) - else: - cast_mask = None - return 
NestedTensor(cast_tensor, cast_mask) - - def decompose(self): - return self.tensors, self.mask - - def __repr__(self): - return str(self.tensors) - - -# Copied from transformers.models.detr.modeling_detr.nested_tensor_from_tensor_list -def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): - if tensor_list[0].ndim == 3: - max_size = _max_by_axis([list(img.shape) for img in tensor_list]) - batch_shape = [len(tensor_list)] + max_size - batch_size, num_channels, height, width = batch_shape - dtype = tensor_list[0].dtype - device = tensor_list[0].device - tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask): - pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - m[: img.shape[1], : img.shape[2]] = False - else: - raise ValueError("Only 3-dimensional tensors are supported") - return NestedTensor(tensor, mask) - - __all__ = [ "DabDetrForObjectDetection", "DabDetrModel",
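
A minimal, self-contained sketch of the eager attention pattern this patch moves DetrAttention / DabDetrAttention to: per-head matmul scores on 4D (batch, num_heads, seq_len, head_dim) tensors, an additive mask such as the one produced by _prepare_4d_attention_mask, softmax computed in float32, and dropout applied to the attention weights. The helper name eager_attention and the example shapes are illustrative only and are not part of the patch.

    # Sketch only; assumes 4D inputs and an additive (batch, 1, q_len, kv_len) mask.
    import torch
    from torch import nn

    def eager_attention(query, key, value, attention_mask=None, dropout_p=0.0, training=False):
        # query/key/value: (batch, num_heads, seq_len, head_dim)
        scaling = query.size(-1) ** -0.5
        # per-head attention scores, same role as the bmm over flattened heads it replaces
        attn_weights = torch.matmul(query * scaling, key.transpose(2, 3))
        if attention_mask is not None:
            # additive mask with large negative values at padded key positions
            attn_weights = attn_weights + attention_mask
        # softmax in float32 for numerical stability, then cast back to the input dtype
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=dropout_p, training=training)
        attn_output = torch.matmul(attn_weights, value)  # (batch, num_heads, q_len, head_dim)
        return attn_output, attn_weights

    # Example usage with illustrative shapes: batch=2, 8 heads, 16 tokens, head_dim=32
    q = torch.randn(2, 8, 16, 32)
    k = torch.randn(2, 8, 16, 32)
    v = torch.randn(2, 8, 16, 32)
    out, weights = eager_attention(q, k, v)

In the patched modules the projected states are reshaped to this 4D layout before the matmul, and the output is transposed back and reshaped to (batch, q_len, hidden_size) before out_proj, so the sketch above mirrors the data flow rather than the exact module code.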