Skip to content

Commit

Permalink
assert that decoding from code indices works correctly and is the same as reconstructed video. ready for the language modeling part
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 24, 2023
1 parent da129e0 commit 14047c9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
21 changes: 14 additions & 7 deletions magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,31 +1112,34 @@ def encode(
@beartype
def decode_from_code_indices(
self,
indices: Tensor,
codes: Tensor,
cond: Optional[Tensor] = None
):
codes = self.quantizers.get_output_from_indices(indices)
return self.decode(codes = codes, cond = cond)
quantized = self.quantizers.indices_to_codes(codes)
out = self.decode(quantized, cond = cond)
return out[:, :, self.time_padding:]

@beartype
def decode(
self,
codes: Tensor,
quantized: Tensor,
cond: Optional[Tensor] = None
):
batch = quantized.shape[0]

# conditioning, if needed

assert (not self.has_cond) or exists(cond), '`cond` must be passed into tokenizer forward method since conditionable layers were specified'

if exists(cond):
assert cond.shape == (codes.shape[0], self.dim_cond)
assert cond.shape == (batch, self.dim_cond)

cond = self.decoder_cond_in(cond)
cond_kwargs = dict(cond = cond)

# decoder layers

x = codes
x = quantized

for fn in self.decoder_layers:

Expand All @@ -1157,6 +1160,7 @@ def forward(
cond: Optional[Tensor] = None,
return_loss = False,
return_codes = False,
return_recon = False,
return_discr_loss = False,
apply_gradient_penalty = True
):
Expand Down Expand Up @@ -1188,7 +1192,7 @@ def forward(

(quantized, codes, aux_losses), lfq_loss_breakdown = self.quantizers(x, return_loss_breakdown = True)

if return_codes:
if return_codes and not return_recon:
return codes

# decoder
Expand All @@ -1197,6 +1201,9 @@ def forward(

recon_video = padded_recon_video[:, :, self.time_padding:]

if return_codes:
return recon_video, codes

# reconstruction loss

if not (return_loss or return_discr_loss):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'magvit2-pytorch',
packages = find_packages(),
version = '0.0.27',
version = '0.0.28',
license='MIT',
description = 'MagViT2 - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit 14047c9

Please sign in to comment.