diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index 32811cb85..dc9e7f59d 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -407,13 +407,13 @@ def __init__( ) self.fc = GroupLinear(dim, num_classes, num_groups) - def forward(self, x): + def forward(self, x, q=None): # BCHW to BNC if(len(x.shape) == 4): x = x.flatten(2).transpose(1, 2) x = self.act(self.proj(x)) - q = torch.cat([x.weight for x in [self.query_embed, self.class_embed] if x is not None], dim=1) + q = q if q is not None else torch.cat([x.weight for x in [self.query_embed, self.class_embed] if x is not None], dim=1) q = self.embed_norm(self.embed_drop(q)) x = self.attn(q, self.norm1(x))# + q.unsqueeze(1) x = x + self.mlp(self.norm2(x))