diff --git a/long_net/attention.py b/long_net/attention.py
index 2b3057a..18efc24 100644
--- a/long_net/attention.py
+++ b/long_net/attention.py
@@ -77,8 +77,31 @@ def __init__(
         self.proj_k = nn.Linear(dim, dim)
         self.proj_v = nn.Linear(dim, dim)
 
-    def get_mask(self, i, j):
-        return torch.ones((i, j), device=self.device, dtype=torch.bool).triu(j - i + 2)
+    def get_mask(self, n, device):
+        if self.mask is not None and self.mask.shape[-1] >= n:
+            return self.mask[:n, :n]
+
+        if self.mask is None:
+            print('computing mask..')
+
+        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
+        k = 0
+        segment_lengths = [4, 8, 16]
+        dilation_rates = [1, 2, 4]
+        # segment_lengths = [2048, 4096, 8192, 16384, 32768]
+        # dilation_rates = [1, 2, 4, 6, 12]
+        for i in range(len(mask)):
+            for j in range(len(mask[0])):
+                will_mask = True
+                for segment_length, dilation_rate in zip(segment_lengths, dilation_rates):
+                    if np.floor(i/segment_length) == np.floor(j/segment_length) and i % dilation_rate == 0 and j % dilation_rate == 0:
+                        will_mask = False
+                if will_mask:
+                    mask[i][j] = True
+                    k += 1
+        self.register_buffer("mask", mask, persistent=False)
+        self.mask = mask
+        return mask
 
     def forward(self, x):
         batch_size, seq_len, _ = x.shape
@@ -121,7 +144,7 @@ def forward(self, x):
         # if causal create a mask and apply to the output
         if self.causal:
-            mask = self.get_mask(attn_output.size(1), attn_output.size(1))
+            mask = self.get_mask(n=attn_output.size(1), device='cuda:0')
             attn_output = attn_output.masked_fill(mask, float("-inf"))
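For readers who want to inspect the resulting attention pattern outside the module, below is a minimal standalone sketch of the same dilated causal mask construction, using the toy segment_lengths=[4, 8, 16] and dilation_rates=[1, 2, 4] from the patch. The function name build_dilated_causal_mask and the use of integer division in place of np.floor are my own choices for the sketch, not part of the patch.

import torch


def build_dilated_causal_mask(n, segment_lengths=(4, 8, 16), dilation_rates=(1, 2, 4)):
    """Boolean mask where True marks positions that must NOT be attended to."""
    # Start from a plain causal mask: everything above the diagonal is masked.
    mask = torch.ones((n, n), dtype=torch.bool).triu(1)
    for i in range(n):
        for j in range(n):
            keep = False
            # (i, j) stays attendable if, for some (segment, dilation) pair, both
            # tokens fall in the same segment and both lie on the dilation grid.
            for w, r in zip(segment_lengths, dilation_rates):
                if i // w == j // w and i % r == 0 and j % r == 0:
                    keep = True
                    break
            if not keep:
                mask[i, j] = True
    return mask


if __name__ == "__main__":
    m = build_dilated_causal_mask(16)
    print(m.int())  # 0 = attendable, 1 = masked

Note that the patched get_mask pays the O(n^2) Python loop only once: the computed mask is registered as a non-persistent buffer, and on later calls the early return slices the cached mask whenever it is already large enough for the current sequence length.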