Skip to content

Commit

Permalink
Feature: allow custom query in forward
Browse files Browse the repository at this point in the history
  • Loading branch information
fffffgggg54 committed Dec 27, 2024
1 parent b927237 commit 2c75ebd
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions timm/layers/ml_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,13 +407,13 @@ def __init__(
)
self.fc = GroupLinear(dim, num_classes, num_groups)

def forward(self, x):
def forward(self, x, q=None):
# BCHW to BNC
if(len(x.shape) == 4):
x = x.flatten(2).transpose(1, 2)

x = self.act(self.proj(x))
q = torch.cat([x.weight for x in [self.query_embed, self.class_embed] if x is not None], dim=1)
q = q if q is not None else torch.cat([x.weight for x in [self.query_embed, self.class_embed] if x is not None], dim=1)
q = self.embed_norm(self.embed_drop(q))
x = self.attn(q, self.norm1(x))# + q.unsqueeze(1)
x = x + self.mlp(self.norm2(x))
Expand Down

0 comments on commit 2c75ebd

Please sign in to comment.