add linear space attention layer, using a specific one from meta ai
lucidrains committed Oct 24, 2023
1 parent 3cc44f7 commit 72a86e8
Showing 4 changed files with 96 additions and 1 deletion.
10 changes: 10 additions & 0 deletions README.md
@@ -111,3 +111,13 @@ codes = tokenizer(videos, return_codes = True)
url = {https://api.semanticscholar.org/CorpusID:239016890}
}
```

```bibtex
@inproceedings{ElNouby2021XCiTCI,
title = {XCiT: Cross-Covariance Image Transformers},
author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
booktitle = {Neural Information Processing Systems},
year = {2021},
url = {https://api.semanticscholar.org/CorpusID:235458262}
}
```
5 changes: 5 additions & 0 deletions magvit2_pytorch/attend.py
@@ -109,6 +109,11 @@ def flash_attn(
    ):
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # manage scale, since scale is not customizable in sdp, hack around it

        if exists(self.scale):
            q = q * self.scale / (q.shape[-1] ** -0.5)
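        # note: F.scaled_dot_product_attention always applies its own (q.shape[-1] ** -0.5) scaling,
        # so pre-multiplying q by self.scale / (q.shape[-1] ** -0.5) makes the effective scale equal self.scale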

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

80 changes: 80 additions & 0 deletions magvit2_pytorch/magvit2_pytorch.py
@@ -56,6 +56,9 @@ def cast_tuple(t, length = 1):

# tensor helpers

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
@@ -222,6 +225,64 @@ def forward(self, x, mask = None):
        out = self.attend(q, k, v, mask = mask)
        return self.to_out(out)

class LinearAttention(Module):
    """
    using the specific linear attention (cross-covariance attention) proposed in
    XCiT, https://arxiv.org/abs/2106.09681 - the attention map is formed across the
    per-head feature dimension rather than across tokens, so cost scales linearly
    with sequence length
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 32,
        heads = 8,
        flash = False,
        dropout = 0.
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.to_qkv = Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv = 3, h = heads)
        )

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(
            scale = 1.,
            causal = False,
            dropout = dropout,
            flash = flash
        )

        self.to_out = Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(dim_inner, dim)
        )

    def forward(self, x):
        q, k, v = self.to_qkv(x)
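        # following XCiT, queries and keys are l2-normalized (here along the token axis, the
        # last dimension of the 'b h d n' layout), with a learned per-head temperature in
        # place of a fixed attention scale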

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)

class LinearSpaceAttention(LinearAttention):
    def forward(self, x, *args, **kwargs):
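        # fold time into batch and flatten height x width into a single token axis,
        # so the linear attention above is applied purely spatially, frame by frame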
        x = rearrange(x, 'b c t h w -> b t h w c')
        x, batch_ps = pack_one(x, '* h w c')
        x, seq_ps = pack_one(x, 'b * c')

        x = super().forward(x, *args, **kwargs)

        x = unpack_one(x, seq_ps, 'b * c')
        x = unpack_one(x, batch_ps, '* h w c')
        return rearrange(x, 'b t h w c -> b c t h w')

class SpaceAttention(Attention):
    def forward(self, x, *args, **kwargs):
        x = rearrange(x, 'b c t h w -> b t h w c')
@@ -856,6 +917,25 @@ def __init__(
                    Residual(FeedForward(dim))
                )

            elif layer_type == 'linear_attend_space':
                attn_kwargs = dict(
                    dim = dim,
                    dim_head = attn_dim_head,
                    heads = attn_heads,
                    dropout = attn_dropout,
                    flash = flash_attn
                )

                encoder_layer = Sequential(
                    Residual(LinearSpaceAttention(**attn_kwargs)),
                    Residual(FeedForward(dim))
                )

                decoder_layer = Sequential(
                    Residual(LinearSpaceAttention(**attn_kwargs)),
                    Residual(FeedForward(dim))
                )

            elif layer_type == 'attend_time':
                attn_kwargs = dict(
                    dim = dim,
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'magvit2-pytorch',
  packages = find_packages(),
-  version = '0.0.26',
+  version = '0.0.27',
  license='MIT',
  description = 'MagViT2 - Pytorch',
  long_description_content_type = 'text/markdown',
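A minimal usage sketch (not from this commit), assuming `LinearSpaceAttention` is imported directly from `magvit2_pytorch/magvit2_pytorch.py`, where the class above is defined:

```python
import torch
from magvit2_pytorch.magvit2_pytorch import LinearSpaceAttention

# channels of the input must match dim
attn = LinearSpaceAttention(dim = 64, dim_head = 32, heads = 8)

# (batch, channels, time, height, width)
video_feats = torch.randn(1, 64, 4, 16, 16)

# attention runs per frame over the 16 x 16 spatial tokens, with cost linear in their number
out = attn(video_feats)
assert out.shape == video_feats.shape
```

The commit also registers a `'linear_attend_space'` layer type, which wires this same module (followed by a feedforward) into both the encoder and decoder of the tokenizer.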
