Implementation of Similarity Head #60

Open · wants to merge 1 commit into master
33 changes: 32 additions & 1 deletion loss.py
@@ -66,4 +66,35 @@ def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
             self.opt.zero_grad()
         return train_loss.item()
 
-# TODO Implement a LossCompute class for similiraty tasks.
+class SimilarityLossCompute:
+    "A Loss compute and train function for similarity tasks."
+
+    def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
+        self.lm_criterion = lm_criterion
+        self.clf_criterion = clf_criterion
+        self.lm_coef = lm_coef
+        self.opt = opt
+
+    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
+        # Language modeling loss
+        if lm_logits is not None:
+            x_shifted = X[:, :, 1:, 0].contiguous().view(-1)
+            M = M.view(-1, M.size(2))
+            lm_losses = self.lm_criterion(lm_logits, x_shifted)
+            lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
+            lm_losses = lm_losses * M[:, 1:]
+            lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
+        # Classification loss
+        clf_losses = self.clf_criterion(clf_logits, Y)
+        if only_return_losses:
+            return (clf_losses, lm_losses) if lm_logits is not None else clf_losses
+
+        if self.lm_coef > 0 and lm_logits is not None:
+            train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
+        else:
+            train_loss = clf_losses.sum()
+        train_loss.backward()
+        if self.opt is not None:
+            self.opt.step()
+            self.opt.zero_grad()
+        return train_loss.item()
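For context, here is a minimal sketch of how the new class might be driven, in the style of the repo's other *LossCompute classes. The shapes, criteria, lm_coef value, and variable names below are illustrative assumptions, not something this PR specifies: X is taken as (batch, 2, seq, 2) because each sentence pair is fed in both orderings, M is the matching token mask, and Y holds one float similarity target per pair.

import torch
import torch.nn as nn
from loss import SimilarityLossCompute

# reduction='none' keeps per-token / per-pair losses so the class can mask
# and average them itself (the repo uses the older reduce=False spelling).
lm_criterion = nn.CrossEntropyLoss(reduction='none')
clf_criterion = nn.MSELoss(reduction='none')
compute_loss_fct = SimilarityLossCompute(lm_criterion, clf_criterion,
                                         lm_coef=0.5, opt=None)

n_batch, n_ctx, n_vocab = 4, 16, 40990       # hypothetical sizes
X = torch.randint(0, n_vocab, (n_batch, 2, n_ctx, 2))
Y = torch.rand(n_batch)                      # float similarity targets
M = torch.ones(n_batch, 2, n_ctx)
sim_logits = torch.randn(n_batch)            # stand-in for SimilarityHead output
lm_logits = torch.randn(n_batch * 2 * (n_ctx - 1), n_vocab)

clf_losses, lm_losses = compute_loss_fct(X, Y, M, sim_logits, lm_logits,
                                         only_return_losses=True)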
12 changes: 6 additions & 6 deletions model_pytorch.py
@@ -224,9 +224,8 @@ def forward(self, h, x):
 
 
 class ClfHead(nn.Module):
-    """Classification Head for the transformer
+    """Classification Head for the transformer"""
 
-    TODO: test this class."""
     def __init__(self, clf_token, cfg, n_class):
         super(ClfHead, self).__init__()
         self.n_embd = cfg.n_embd
@@ -247,9 +246,8 @@ def forward(self, h, x):
         return clf_logits
 
 class SimilarityHead(nn.Module):
-    """ Similarity Head for the transformer
+    """ Similarity Head for the transformer"""
 
-    TODO: test this class."""
     def __init__(self, clf_token, cfg):
         super(SimilarityHead, self).__init__()
         self.n_embd = cfg.n_embd
@@ -264,9 +262,11 @@ def forward(self, h, x):
         sim_h = h.view(-1, self.n_embd)
         flat = x[..., 0].contiguous().view(-1)
         sim_h = sim_h[flat == self.clf_token, :]
         sim_h = self.dropout(sim_h)
-        sim_h = sim_h.sum(dim = 1)
+        # Add the clf-token representations of the two orderings of each pair
+        sim_h = sim_h[[range(0, sim_h.shape[0], 2)], :] + sim_h[[range(1, sim_h.shape[0], 2)], :]
+        sim_h = torch.squeeze(sim_h, dim=0)
         sim_logits = self.linear(sim_h)
+        sim_logits = torch.squeeze(sim_logits, dim=1)
 
         return sim_logits
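The pairing step in the added lines assumes the two orderings of each sentence pair occupy adjacent rows of sim_h (rows 2k and 2k+1 for pair k). A small self-contained sanity check of that indexing, with hypothetical sizes:

import torch

# Rows 2k and 2k+1 of sim_h are assumed to hold the clf-token states of the
# two orderings of pair k, so summing even and odd rows merges each pair
# into a single order-invariant representation.
n_pairs, n_embd = 3, 8
sim_h = torch.randn(2 * n_pairs, n_embd)

paired = sim_h[[range(0, sim_h.shape[0], 2)], :] + sim_h[[range(1, sim_h.shape[0], 2)], :]
paired = torch.squeeze(paired, dim=0)  # the list-wrapped range index adds a leading dim of 1
assert paired.shape == (n_pairs, n_embd)

# Equivalent strided-slice spelling, with no extra dimension to squeeze:
assert torch.equal(paired, sim_h[0::2] + sim_h[1::2])

Summing the two orderings makes the pair representation symmetric in the order of the two sentences, which appears to match the similarity setup in the GPT paper; the strided-slice form at the end is equivalent and sidesteps the extra leading dimension that the list-wrapped range index introduces.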
