Skip to content

Commit

Permalink
saves the video tokenizer config in the .pt file
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 25, 2023
1 parent b3d4644 commit 6dfc98c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ codes = tokenizer(videos, return_codes = True)

```bibtex
@article{Shleifer2021NormFormerIT,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Jason Weston and Myle Ott},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.09456},
url = {https://api.semanticscholar.org/CorpusID:239016890}
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Jason Weston and Myle Ott},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.09456},
url = {https://api.semanticscholar.org/CorpusID:239016890}
}
```

Expand Down
32 changes: 29 additions & 3 deletions magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

from kornia.filters import filter3d

import pickle

# helper

def exists(v):
Expand Down Expand Up @@ -850,6 +852,7 @@ class VideoTokenizer(Module):
@beartype
def __init__(
self,
*,
image_size,
layers: Tuple[Union[str, Tuple[str, int]], ...] = (
'residual',
Expand Down Expand Up @@ -885,6 +888,15 @@ def __init__(
):
super().__init__()

# for autosaving the config

_locals = locals()
_locals.pop('self', None)
_locals.pop('__class__', None)
self._configs = pickle.dumps(_locals)

# image size

self.image_size = image_size

# encoder
Expand Down Expand Up @@ -1066,6 +1078,19 @@ def __init__(

self.has_gan = use_gan and adversarial_loss_weight > 0.

@classmethod
def init_and_load_from(cls, path, strict = True):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path), map_location = 'cpu')

assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

config = pickle.loads(pkg['config'])
tokenizer = cls(**config)
tokenizer.load(path, strict = strict)
return tokenizer

def parameters(self):
return [
*self.conv_in.parameters(),
Expand Down Expand Up @@ -1104,12 +1129,13 @@ def save(self, path, overwrite = True):

pkg = dict(
model_state_dict = self.state_dict(),
version = __version__
version = __version__,
config = self._configs
)

torch.save(pkg, str(path))

def load(self, path):
def load(self, path, strict = True):
path = Path(path)
assert path.exists()

Expand All @@ -1122,7 +1148,7 @@ def load(self, path):
if exists(version):
print(f'loading checkpointed tokenizer from version {version}')

self.load_state_dict(state_dict)
self.load_state_dict(state_dict, strict = strict)

@beartype
def encode(
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.31'
__version__ = '0.0.33'

0 comments on commit 6dfc98c

Please sign in to comment.