diff --git a/magvit2_pytorch/magvit2_pytorch.py b/magvit2_pytorch/magvit2_pytorch.py index fafc171..7784718 100644 --- a/magvit2_pytorch/magvit2_pytorch.py +++ b/magvit2_pytorch/magvit2_pytorch.py @@ -1,4 +1,5 @@ import copy +from pathlib import Path from math import log2, ceil, sqrt from functools import wraps, partial @@ -22,6 +23,7 @@ from beartype.typing import Union, Tuple, Optional from magvit2_pytorch.attend import Attend +from magvit2_pytorch.version import __version__ from kornia.filters import filter3d @@ -1096,11 +1098,31 @@ def state_dict(self, *args, **kwargs): def load_state_dict(self, *args, **kwargs): return super().load_state_dict(*args, **kwargs) + def save(self, path, overwrite = True): + path = Path(path) + assert overwrite or not path.exists(), f'{str(path)} already exists' + + pkg = dict( + model_state_dict = self.state_dict(), + version = __version__ + ) + + torch.save(pkg, str(path)) + def load(self, path): path = Path(path) assert path.exists() - pt = torch.load(str(path)) - self.load_state_dict(pt) + + pkg = torch.load(str(path)) + state_dict = pkg.get('model_state_dict') + version = pkg.get('version') + + assert exists(state_dict) + + if exists(version): + print(f'loading checkpointed tokenizer from version {version}') + + self.load_state_dict(state_dict) @beartype def encode( diff --git a/magvit2_pytorch/version.py b/magvit2_pytorch/version.py new file mode 100644 index 0000000..138343c --- /dev/null +++ b/magvit2_pytorch/version.py @@ -0,0 +1 @@ +__version__ = '0.0.31' diff --git a/setup.py b/setup.py index 474e943..e98b937 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,11 @@ from setuptools import setup, find_packages +exec(open('magvit2_pytorch/version.py').read()) + setup( name = 'magvit2-pytorch', packages = find_packages(), - version = '0.0.30', + version = __version__, license='MIT', description = 'MagViT2 - Pytorch', long_description_content_type = 'text/markdown',