From 86c1104a4f8cd717c99b20652efb81e9812f14ec Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 29 Dec 2024 06:50:03 -0800 Subject: [PATCH] Incorrect indexing in GroupLinear --- timm/layers/ml_decoder.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 71460889f..2c57a86d5 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -321,12 +321,15 @@ def __init__( dim, num_classes, num_groups, + shared: bool = False, ): super().__init__() + # 1 group for all queries (shared_fc, use with class_embed) or 1 group for each query (default, used in paper) self.num_classes = num_classes duplicate_factor = int(num_classes / num_groups + 0.999) + num_groups = 1 if shared else num_groups self.weight = nn.Parameter(torch.Tensor(num_groups, dim, duplicate_factor)) - self.bias = nn.Parameter(torch.Tensor(num_classes)) + self.bias = nn.Parameter(torch.Tensor(self.num_classes)) nn.init.xavier_normal_(self.weight) nn.init.constant_(self.bias, 0) @@ -424,11 +427,7 @@ def __init__( drop=proj_drop, ) - # 1 group for all queries (shared_fc, use with class_embed) or 1 group for each query (default, used in paper) - duplicate_factor = int(num_classes / num_groups + 0.999) - num_fc_classes = duplicate_factor if self.shared_fc else num_classes - num_fc_groups = 1 if self.shared_fc else num_groups - self.fc = GroupLinear(dim, num_fc_classes, num_fc_groups) + self.fc = GroupLinear(dim, num_classes, num_groups, shared=self.shared_fc) def _resolve_query(self, q): if q is not None: