diff --git a/loss.py b/loss.py index 61023b8..6747026 100644 --- a/loss.py +++ b/loss.py @@ -66,4 +66,35 @@ def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False self.opt.zero_grad() return train_loss.item() -# TODO Implement a LossCompute class for similiraty tasks. +class SimilarityLossCompute: + "A Loss compute and train function for similarity tasks." + + def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None): + self.lm_criterion = lm_criterion + self.clf_criterion = clf_criterion + self.lm_coef = lm_coef + self.opt = opt + + def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False): + # Language modeling loss + if lm_logits is not None: + x_shifted = X[:, :, 1:, 0].contiguous().view(-1) + M = M.view(-1, M.size(2)) + lm_losses = self.lm_criterion(lm_logits, x_shifted) + lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1) + lm_losses = lm_losses * M[:, 1:] + lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1) + # Classification loss + clf_losses = self.clf_criterion(clf_logits, Y) + if only_return_losses: + return (clf_losses, lm_losses) if lm_logits is not None else clf_losses + + if self.lm_coef > 0 and lm_logits is not None: + train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum() + else: + train_loss = clf_losses.sum() + train_loss.backward() + if self.opt is not None: + self.opt.step() + self.opt.zero_grad() + return train_loss.item() diff --git a/model_pytorch.py b/model_pytorch.py index 37a7b66..bc3fae2 100644 --- a/model_pytorch.py +++ b/model_pytorch.py @@ -224,9 +224,8 @@ def forward(self, h, x): class ClfHead(nn.Module): - """Classification Head for the transformer + """Classification Head for the transformer""" - TODO: test this class.""" def __init__(self, clf_token, cfg, n_class): super(ClfHead, self).__init__() self.n_embd = cfg.n_embd @@ -247,9 +246,8 @@ def forward(self, h, x): return clf_logits class SimilarityHead(nn.Module): - """ Similarity Head for the transformer + """ Similarity Head for the transformer""" - TODO: test this class.""" def __init__(self, clf_token, cfg): super(SimilarityHead, self).__init__() self.n_embd = cfg.n_embd @@ -264,9 +262,11 @@ def forward(self, h, x): sim_h = h.view(-1, self.n_embd) flat = x[..., 0].contiguous().view(-1) sim_h = sim_h[flat == self.clf_token, :] - sim_h = self.dropout(sim_h) - sim_h = sim_h.sum(dim = 1) + #addition of the two different sequence representations + sim_h=sim_h[[range(0,sim_h.shape[0],2)],:]+sim_h[[range(1,sim_h.shape[0],2)],:] + sim_h=torch.squeeze(sim_h,dim=0) sim_logits = self.linear(sim_h) + sim_logits=torch.squeeze(sim_logits,dim=1) return sim_logits