Skip to content

Commit

Permalink
save version
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 25, 2023
1 parent cce7d8f commit b3d4644
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
26 changes: 24 additions & 2 deletions magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from pathlib import Path
from math import log2, ceil, sqrt
from functools import wraps, partial

Expand All @@ -22,6 +23,7 @@
from beartype.typing import Union, Tuple, Optional

from magvit2_pytorch.attend import Attend
from magvit2_pytorch.version import __version__

from kornia.filters import filter3d

Expand Down Expand Up @@ -1096,11 +1098,31 @@ def state_dict(self, *args, **kwargs):
def load_state_dict(self, *args, **kwargs):
return super().load_state_dict(*args, **kwargs)

def save(self, path, overwrite = True):
path = Path(path)
assert overwrite or not path.exists(), f'{str(path)} already exists'

pkg = dict(
model_state_dict = self.state_dict(),
version = __version__
)

torch.save(pkg, str(path))

def load(self, path):
path = Path(path)
assert path.exists()
pt = torch.load(str(path))
self.load_state_dict(pt)

pkg = torch.load(str(path))
state_dict = pkg.get('model_state_dict')
version = pkg.get('version')

assert exists(state_dict)

if exists(version):
print(f'loading checkpointed tokenizer from version {version}')

self.load_state_dict(state_dict)

@beartype
def encode(
Expand Down
1 change: 1 addition & 0 deletions magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = '0.0.31'
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from setuptools import setup, find_packages

exec(open('magvit2_pytorch/version.py').read())

setup(
name = 'magvit2-pytorch',
packages = find_packages(),
version = '0.0.30',
version = __version__,
license='MIT',
description = 'MagViT2 - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit b3d4644

Please sign in to comment.