Skip to content

Commit

Permalink
ready for training
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 24, 2023
1 parent 72a86e8 commit da129e0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
13 changes: 13 additions & 0 deletions magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,19 @@ def __init__(

self.has_gan = use_gan and adversarial_loss_weight > 0.

def parameters(self):
return [
*self.conv_in.parameters(),
*self.conv_out.parameters(),
*self.encoder_layers.parameters(),
*self.decoder_layers.parameters(),
*self.encoder_cond_in.parameters(),
*self.decoder_cond_in.parameters(),
]

def discr_parameters(self):
return self.discr.parameters()

def copy_for_eval(self):
device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu())
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
'accelerate',
'beartype',
'einops>=0.7.0',
'ema-pytorch',
'kornia',
'vector-quantize-pytorch>=1.9.18',
'torch',
Expand Down

0 comments on commit da129e0

Please sign in to comment.