Skip to content

Commit

Permalink
Incorrect indexing in GroupLinear
Browse files Browse the repository at this point in the history
  • Loading branch information
fffffgggg54 committed Dec 29, 2024
1 parent 566a843 commit 86c1104
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions timm/layers/ml_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,15 @@ def __init__(
dim,
num_classes,
num_groups,
shared: bool = False,
):
super().__init__()
# 1 group for all queries (shared_fc, use with class_embed) or 1 group for each query (default, used in paper)
self.num_classes = num_classes
duplicate_factor = int(num_classes / num_groups + 0.999)
num_groups = 1 if shared else num_groups
self.weight = nn.Parameter(torch.Tensor(num_groups, dim, duplicate_factor))
self.bias = nn.Parameter(torch.Tensor(num_classes))
self.bias = nn.Parameter(torch.Tensor(self.num_classes))
nn.init.xavier_normal_(self.weight)
nn.init.constant_(self.bias, 0)

Expand Down Expand Up @@ -424,11 +427,7 @@ def __init__(
drop=proj_drop,
)

# 1 group for all queries (shared_fc, use with class_embed) or 1 group for each query (default, used in paper)
duplicate_factor = int(num_classes / num_groups + 0.999)
num_fc_classes = duplicate_factor if self.shared_fc else num_classes
num_fc_groups = 1 if self.shared_fc else num_groups
self.fc = GroupLinear(dim, num_fc_classes, num_fc_groups)
self.fc = GroupLinear(dim, num_classes, num_groups, shared=self.shared_fc)

def _resolve_query(self, q):
if q is not None:
Expand Down

0 comments on commit 86c1104

Please sign in to comment.