Commit
setup multiscale adversarial losses
lucidrains committed Oct 31, 2023
1 parent e12829a commit 3ac9ff1
Showing 3 changed files with 107 additions and 15 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -84,8 +84,11 @@ assert torch.allclose(
- [x] add conditioning for encoder decoder with residual modulatable conv 3d
- [x] `decode_from_codebook_indices` should be able to accept flattened ids and reshape to correct feature map dimensions and decode back to video
- [x] add trainer and manage discriminator training
- [x] completely generalize to multiple discriminators at different time scales (taking inspiration from the multi-resolution discriminators of soundstream) - see the sketch after this list
- [x] complete multiscale discriminator losses
- [ ] auto-manage multiscale discriminator optimizers
- [ ] helper functions for temporal discrimination (picking random consecutive frames)
- [ ] add adaptive rmsnorm
- [ ] add attention
- [ ] use axial rotary embeddings for spatial
- [ ] add an optional autoregressive loss at some penultimate layer of the decoder - check literature to see if anyone else has done this unification of transformer decoder + tokenizer in one architecture
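The new multiscale items land in magvit2_pytorch.py below; roughly, the caller-side surface looks like the following sketch. The argument names (multiscale_discrs, multiscale_adversarial_loss_weight) come from this commit's diff, while the discriminator modules here are placeholders and the other required VideoTokenizer arguments are omitted:

from torch import nn
from magvit2_pytorch import VideoTokenizer

tokenizer = VideoTokenizer(
    # ... other required tokenizer arguments omitted
    multiscale_discrs = (
        nn.Conv3d(3, 1, 3, padding = 1),                                   # placeholder discriminator at full resolution
        nn.Sequential(nn.AvgPool3d(2), nn.Conv3d(3, 1, 3, padding = 1)),   # placeholder at half resolution
    ),
    multiscale_adversarial_loss_weight = 1.
)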
115 changes: 102 additions & 13 deletions magvit2_pytorch/magvit2_pytorch.py
@@ -55,6 +55,10 @@ def unpack_one(t, ps, pattern):
def is_odd(n):
    return not divisible_by(n, 2)

def maybe_del_attr_(o, attr):
    if hasattr(o, attr):
        delattr(o, attr)

def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)
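maybe_del_attr_ is a small guard around delattr, used further down in copy_for_eval to strip training-only submodules. A quick illustration:

class Holder:
    pass

h = Holder()
h.discr = 'anything'

maybe_del_attr_(h, 'discr')  # attribute removed
maybe_del_attr_(h, 'discr')  # second call is a no-op rather than an AttributeError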

@@ -842,7 +846,9 @@ def forward(self, x):
    'lfq_commitment_loss',
    'perceptual_loss',
    'gen_loss',
    'adaptive_adversarial_weight',
    'multiscale_gen_losses',
    'multiscale_gen_adaptive_weights'
])

DiscrLossBreakdown = namedtuple('DiscrLossBreakdown', [
@@ -883,9 +889,11 @@ def __init__(
        perceptual_loss_weight = 1.,
        antialiased_downsample = True,
        discr_kwargs: Optional[dict] = None,
        multiscale_discrs: Optional[Tuple[Module, ...]] = None,
        use_gan = True,
        adversarial_loss_weight = 1.,
        grad_penalty_loss_weight = 10.,
        multiscale_adversarial_loss_weight = 1.,
        flash_attn = True
    ):
        super().__init__()
@@ -1095,6 +1103,13 @@ def __init__(

        self.has_gan = use_gan and adversarial_loss_weight > 0.

        # multi-scale discriminators
        # multiscale_discrs defaults to None, so fall back to an empty tuple before unpacking

        self.multiscale_discrs = ModuleList([*(multiscale_discrs or ())])

        self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
        self.has_multiscale_discrs = use_gan and multiscale_adversarial_loss_weight > 0. and len(self.multiscale_discrs) > 0

    @property
    def device(self):
        return self.zero.device
@@ -1129,11 +1144,9 @@ def copy_for_eval(self):
        device = self.device
        vae_copy = copy.deepcopy(self.cpu())

        maybe_del_attr_(vae_copy, 'discr')
        maybe_del_attr_(vae_copy, 'vgg')
        maybe_del_attr_(vae_copy, 'multiscale_discrs')

        vae_copy.eval()
        return vae_copy.to(device)
@@ -1335,6 +1348,8 @@ def forward(
            assert self.has_gan
            assert exists(self.discr)

            # pick one random frame per batch sample for the image discriminator
            # (topk over random noise yields a uniformly random index per row)

            frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices

            real = pick_video_frame(video, frame_indices)
@@ -1349,14 +1364,40 @@

            discr_loss = hinge_discr_loss(fake_logits, real_logits)

            # multiscale discriminators

            multiscale_discr_losses = []

            if self.has_multiscale_discrs:
                for discr in self.multiscale_discrs:
                    multiscale_real_logits = discr(video)
                    multiscale_fake_logits = discr(recon_video)

                    # use this scale's logits, not the frame discriminator's
                    multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)

                    multiscale_discr_losses.append(multiscale_discr_loss)
            else:
                multiscale_discr_losses.append(self.zero)

# gradient penalty

if apply_gradient_penalty:
gradient_penalty_loss = gradient_penalty(real, real_logits)
else:
gradient_penalty_loss = self.zero

total_loss = discr_loss + gradient_penalty_loss * self.grad_penalty_loss_weight
# total loss

return total_loss, DiscrLossBreakdown(discr_loss, gradient_penalty_loss)
total_loss = discr_loss + \
gradient_penalty_loss * self.grad_penalty_loss_weight + \
sum(multiscale_discr_losses)

discr_loss_breakdown = DiscrLossBreakdown(
discr_loss,
gradient_penalty_loss
)

return total_loss, discr_loss_breakdown
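The helpers called in this hunk - hinge_discr_loss, hinge_gen_loss, gradient_penalty, and pick_video_frame - are defined elsewhere in the file and not shown in this diff. A minimal sketch of standard definitions consistent with the call sites above (the repository's own versions may differ in detail):

import torch
import torch.nn.functional as F
from torch.autograd import grad as torch_grad
from einops import rearrange

def hinge_discr_loss(fake, real):
    # discriminator pushes real logits above 1 and fake logits below -1
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

def hinge_gen_loss(fake):
    # generator pushes the discriminator's score on its outputs upward
    return -fake.mean()

def gradient_penalty(images, output):
    # penalize the norm of the logits' gradient w.r.t. the real inputs
    # (images must have requires_grad set when the logits are computed)
    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return ((gradients.norm(2, dim = -1) - 1) ** 2).mean()

def pick_video_frame(video, frame_indices):
    # video: (b, c, f, h, w), frame_indices: (b, 1) -> one frame per sample, (b, c, h, w)
    batch = video.shape[0]
    video = rearrange(video, 'b c f ... -> b f c ...')
    batch_indices = rearrange(torch.arange(batch, device = video.device), 'b -> b 1')
    frames = video[batch_indices, frame_indices]
    return rearrange(frames, 'b 1 c ... -> b c ...')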

# perceptual loss

@@ -1373,33 +1414,81 @@
        else:
            perceptual_loss = self.zero

        # gradient of the perceptual loss with respect to the last decoder layer,
        # needed for the adaptive adversarial weights below

        last_dec_layer = self.conv_out.conv.weight

        if self.training and (self.has_gan or self.has_multiscale_discrs):
            norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
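grad_layer_wrt_loss is likewise defined outside this diff. The adaptive weighting follows the VQGAN recipe: the adversarial term is rescaled so that its gradient magnitude at the last decoder layer matches that of the perceptual loss, keeping the two objectives balanced. A sketch under that assumption:

def grad_layer_wrt_loss(loss, layer):
    # gradient of a scalar loss with respect to one parameter tensor,
    # detached so its norm can be used as a pure scaling statistic
    return torch.autograd.grad(
        outputs = loss,
        inputs = layer,
        grad_outputs = torch.ones_like(loss),
        retain_graph = True
    )[0].detach()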

        # per-frame image discriminator

        if self.has_gan:
            frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
            recon_video_frames = pick_video_frame(recon_video, frame_indices)

            fake_logits = self.discr(recon_video_frames)
            gen_loss = hinge_gen_loss(fake_logits)

            adaptive_weight = 1.

            if self.training:
                norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)

                adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min = 1e-5)
                adaptive_weight.clamp_(max = 1e4)
        else:
            gen_loss = self.zero
            adaptive_weight = 0.

        # generator losses from the multiscale discriminators

        multiscale_gen_losses = []
        multiscale_gen_adaptive_weights = []

        if self.has_multiscale_discrs:
            for discr in self.multiscale_discrs:
                fake_logits = discr(recon_video)
                multiscale_gen_loss = hinge_gen_loss(fake_logits)

                multiscale_gen_losses.append(multiscale_gen_loss)

                # separate name so the image discriminator's adaptive_weight is not clobbered
                multiscale_adaptive_weight = 1.

                if self.training:
                    norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_dec_layer).norm(p = 2)
                    multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min = 1e-5)
                    multiscale_adaptive_weight.clamp_(max = 1e4)

                multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)

        # calculate total loss

        total_loss = recon_loss \
            + aux_losses * self.lfq_aux_loss_weight \
            + perceptual_loss * self.perceptual_loss_weight \
            + gen_loss * adaptive_weight * self.adversarial_loss_weight

        if self.has_multiscale_discrs:
            weighted_multiscale_gen_losses = sum(loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights))

            total_loss = total_loss + weighted_multiscale_gen_losses * self.multiscale_adversarial_loss_weight

        # loss breakdown

        loss_breakdown = LossBreakdown(
            recon_loss,
            aux_losses,
            *lfq_loss_breakdown,
            perceptual_loss,
            gen_loss,
            adaptive_weight,
            multiscale_gen_losses,
            multiscale_gen_adaptive_weights
        )

        return total_loss, loss_breakdown

# main class

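A hypothetical end-to-end sketch of how the two loss paths in this commit might be exercised. The return_discr_loss and return_loss flags and the remaining VideoTokenizer constructor arguments are assumptions based on the surrounding code, not confirmed by this diff:

import torch
from magvit2_pytorch import VideoTokenizer

tokenizer = VideoTokenizer(
    # ... usual tokenizer arguments, plus the multiscale discriminators
    # sketched after the README section above
)

video = torch.randn(1, 3, 17, 64, 64)  # (batch, channels, frames, height, width)

# discriminator step: hinge discriminator losses at every scale, plus gradient penalty
discr_loss, discr_loss_breakdown = tokenizer(video, return_discr_loss = True, apply_gradient_penalty = True)
discr_loss.backward()

# generator / autoencoder step: reconstruction, perceptual, and adaptively
# weighted adversarial terms, including the new multiscale generator losses
total_loss, loss_breakdown = tokenizer(video, return_loss = True)
total_loss.backward()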
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
@@ -1 +1 @@
__version__ = '0.0.53'
