You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
In the paper, you mention: "We find that adopting the context-aware classifier to calculate the cosine similarities yields better results than the commonly used dot product...". However, does the code only use the dot product, as in the get_pred function below? I can't find where the cosine similarity is computed.
def get_pred(self, x, proto, scale=15):
    """Score every spatial feature in ``x`` against each class prototype.

    Both the pixel features and the prototypes are L2-normalised along the
    channel dimension before the matrix product, so each raw score is the
    cosine similarity between a pixel feature and a prototype (in [-1, 1]).
    The similarities are then multiplied by ``scale``.

    Args:
        x: Feature map of shape ``[b, c, h, w]``.
        proto: Prototypes, either shared across the batch with shape
            ``[cls, c]`` or per-sample with shape ``[b, cls, c]``.
        scale: Multiplier applied to the cosine similarities. Defaults to 15,
            the value that was previously hard-coded.

    Returns:
        Tensor of shape ``[b, cls, h, w]`` equal to ``scale * cos(x, proto)``.

    Raises:
        ValueError: If ``proto`` is neither 2-D nor 3-D.
    """
    b, c, h, w = x.size()
    # Unit-normalise each pixel feature over the channel axis; the epsilon
    # guards against division by zero for all-zero features.
    x = x / (torch.norm(x, 2, 1, True) + 1e-12)
    x = x.reshape(b, c, h * w)  # b, c, hw  (original code used undefined `hw`)

    if proto.dim() == 3:
        # proto: [b, cls, c] -- a separate set of prototypes for each sample.
        cls_num = proto.size(1)
        proto = proto / (torch.norm(proto, 2, -1, True) + 1e-12)  # b, cls, c
    elif proto.dim() == 2:
        # proto: [cls, c] -- one set of prototypes shared by the whole batch.
        cls_num = proto.size(0)
        proto = proto / (torch.norm(proto, 2, 1, True) + 1e-12)
        proto = proto.unsqueeze(0)  # 1, cls, c -- broadcasts over the batch
    else:
        raise ValueError(
            "proto must be 2-D [cls, c] or 3-D [b, cls, c], "
            f"got shape {tuple(proto.shape)}"
        )

    # Dot product of unit vectors == cosine similarity.
    pred = proto @ x  # b, cls, hw
    pred = pred.reshape(b, cls_num, h, w)
    return pred * scale
The text was updated successfully, but these errors were encountered:
In the paper, you mention: "We find that adopting the context-aware classifier to calculate the cosine similarities yields better results than the commonly used dot product...". However, does the code only use the dot product, as in the get_pred function below? I can't find where the cosine similarity is computed.
def get_pred(self, x, proto, scale=15):
    """Score every spatial feature in ``x`` against each class prototype.

    Both the pixel features and the prototypes are L2-normalised along the
    channel dimension before the matrix product, so each raw score is the
    cosine similarity between a pixel feature and a prototype (in [-1, 1]).
    The similarities are then multiplied by ``scale``.

    Args:
        x: Feature map of shape ``[b, c, h, w]``.
        proto: Prototypes, either shared across the batch with shape
            ``[cls, c]`` or per-sample with shape ``[b, cls, c]``.
        scale: Multiplier applied to the cosine similarities. Defaults to 15,
            the value that was previously hard-coded.

    Returns:
        Tensor of shape ``[b, cls, h, w]`` equal to ``scale * cos(x, proto)``.

    Raises:
        ValueError: If ``proto`` is neither 2-D nor 3-D.
    """
    b, c, h, w = x.size()
    # Unit-normalise each pixel feature over the channel axis; the epsilon
    # guards against division by zero for all-zero features.
    x = x / (torch.norm(x, 2, 1, True) + 1e-12)
    x = x.reshape(b, c, h * w)  # b, c, hw  (original code used undefined `hw`)

    if proto.dim() == 3:
        # proto: [b, cls, c] -- a separate set of prototypes for each sample.
        cls_num = proto.size(1)
        proto = proto / (torch.norm(proto, 2, -1, True) + 1e-12)  # b, cls, c
    elif proto.dim() == 2:
        # proto: [cls, c] -- one set of prototypes shared by the whole batch.
        cls_num = proto.size(0)
        proto = proto / (torch.norm(proto, 2, 1, True) + 1e-12)
        proto = proto.unsqueeze(0)  # 1, cls, c -- broadcasts over the batch
    else:
        raise ValueError(
            "proto must be 2-D [cls, c] or 3-D [b, cls, c], "
            f"got shape {tuple(proto.shape)}"
        )

    # Dot product of unit vectors == cosine similarity.
    pred = proto @ x  # b, cls, hw
    pred = pred.reshape(b, cls_num, h, w)
    return pred * scale
The text was updated successfully, but these errors were encountered: