Skip to content

Commit

Permalink
add image spatial squeeze excites to all residual units
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 24, 2023
1 parent 041ccb6 commit b2f105b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
55 changes: 51 additions & 4 deletions magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,55 @@ def forward(self, x, **kwargs):
o = unpack_one(o, ps, '* n c')
return rearrange(o, 'b ... f c -> b c f ...')


class SqueezeExcite(Module):
    """Attention-pooled squeeze-excite gate for image or video feature maps.

    Variant of squeeze-excite following the global context network
    (https://arxiv.org/abs/2012.13375): instead of a plain spatial mean, the
    features are pooled with a softmax over spatial positions produced by a
    single-channel 1x1 convolution. The pooled vector is passed through a small
    1x1-conv bottleneck ending in a sigmoid, and the resulting per-channel
    gates multiply the original input.

    Args:
        dim: number of input channels.
        dim_out: number of gate channels; falls back to ``dim`` when not given.
            NOTE(review): the gates are multiplied directly with the input, so
            this presumably needs to be broadcast-compatible with ``dim`` - confirm.
        dim_hidden_min: lower bound on the bottleneck width, which is otherwise
            ``dim_out // 2``.
        init_bias: initial bias of the last convolution. Its weight starts at
            zero, so every gate initially equals ``sigmoid(init_bias)``.
    """

    def __init__(
        self,
        dim,
        *,
        dim_out = None,
        dim_hidden_min = 16,
        init_bias = -10
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        dim_hidden = max(dim_hidden_min, dim_out // 2)

        # single-channel projection whose spatial softmax gives the pooling weights
        self.to_k = nn.Conv2d(dim, 1, 1)

        # layers are built in the same order as they sit in the Sequential
        squeeze = nn.Conv2d(dim, dim_hidden, 1)
        excite = nn.Conv2d(dim_hidden, dim_out, 1)

        # zero weight + constant bias: the pre-sigmoid output is `init_bias` everywhere at init
        nn.init.zeros_(excite.weight)
        nn.init.constant_(excite.bias, init_bias)

        self.net = nn.Sequential(
            squeeze,
            nn.LeakyReLU(0.1),
            excite,
            nn.Sigmoid()
        )

    def forward(self, x):
        """Gate ``x`` channel-wise.

        Accepts a 4-D ``(b, c, h, w)`` tensor or a 5-D ``(b, c, f, h, w)``
        tensor; for the latter, frames are folded into the batch so that each
        frame is pooled and gated on its own. Returns ``gates * x`` using the
        input in its original layout.
        """
        residual = x
        num_batch = x.shape[0]
        video = x.ndim == 5

        if video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        # softmax over all spatial positions -> pooling weights that sum to 1
        attn = rearrange(self.to_k(x), 'b c h w -> b c (h w)').softmax(dim = -1)
        flat = rearrange(x, 'b c h w -> b c (h w)')

        # weighted spatial sum: (b, 1, n) x (b, c, n) -> (b, c, 1)
        pooled = einsum('b i n, b c n -> b c i', attn, flat)

        # append a trailing axis so the 1x1 convs see a (b, c, 1, 1) map
        gates = self.net(rearrange(pooled, '... -> ... 1'))

        if video:
            gates = rearrange(gates, '(b f) c h w -> b c f h w', b = num_batch)

        return gates * residual

# token shifting

class TokenShift(Module):
Expand Down Expand Up @@ -900,12 +949,10 @@ def ResidualUnit(
CausalConv3d(dim, dim, kernel_size, pad_mode = pad_mode),
nn.ELU(),
nn.Conv3d(dim, dim, 1),
nn.ELU()
nn.ELU(),
SqueezeExcite(dim)
)

nn.init.zeros_(net[-2].weight)
nn.init.zeros_(net[-2].bias)

return Residual(net)

@beartype
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.39'
__version__ = '0.1.40'

0 comments on commit b2f105b

Please sign in to comment.