diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 71460889f..2c57a86d5 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -321,12 +321,15 @@ def __init__( dim, num_classes, num_groups, + shared: bool = False, ): super().__init__() + # 1 group for all queries (shared_fc, use with class_embed) or 1 group for each query (default, used in paper) self.num_classes = num_classes duplicate_factor = int(num_classes / num_groups + 0.999) + num_groups = 1 if shared else num_groups self.weight = nn.Parameter(torch.Tensor(num_groups, dim, duplicate_factor)) - self.bias = nn.Parameter(torch.Tensor(num_classes)) + self.bias = nn.Parameter(torch.Tensor(self.num_classes)) nn.init.xavier_normal_(self.weight) nn.init.constant_(self.bias, 0) @@ -424,11 +427,7 @@ def __init__( drop=proj_drop, ) - # 1 group for all queries (shared_fc, use with class_embed) or 1 group for each query (default, used in paper) - duplicate_factor = int(num_classes / num_groups + 0.999) - num_fc_classes = duplicate_factor if self.shared_fc else num_classes - num_fc_groups = 1 if self.shared_fc else num_groups - self.fc = GroupLinear(dim, num_fc_classes, num_fc_groups) + self.fc = GroupLinear(dim, num_classes, num_groups, shared=self.shared_fc) def _resolve_query(self, q): if q is not None: