diff --git a/README.md b/README.md index c29bb4f..934b385 100644 --- a/README.md +++ b/README.md @@ -103,12 +103,12 @@ codes = tokenizer(videos, return_codes = True) ```bibtex @article{Shleifer2021NormFormerIT, - title = {NormFormer: Improved Transformer Pretraining with Extra Normalization}, - author = {Sam Shleifer and Jason Weston and Myle Ott}, - journal = {ArXiv}, - year = {2021}, - volume = {abs/2110.09456}, - url = {https://api.semanticscholar.org/CorpusID:239016890} + title = {NormFormer: Improved Transformer Pretraining with Extra Normalization}, + author = {Sam Shleifer and Jason Weston and Myle Ott}, + journal = {ArXiv}, + year = {2021}, + volume = {abs/2110.09456}, + url = {https://api.semanticscholar.org/CorpusID:239016890} } ``` diff --git a/magvit2_pytorch/magvit2_pytorch.py b/magvit2_pytorch/magvit2_pytorch.py index 7784718..c498bef 100644 --- a/magvit2_pytorch/magvit2_pytorch.py +++ b/magvit2_pytorch/magvit2_pytorch.py @@ -27,6 +27,8 @@ from kornia.filters import filter3d +import pickle + # helper def exists(v): @@ -850,6 +852,7 @@ class VideoTokenizer(Module): @beartype def __init__( self, + *, image_size, layers: Tuple[Union[str, Tuple[str, int]], ...] = ( 'residual', @@ -885,6 +888,15 @@ def __init__( ): super().__init__() + # for autosaving the config + + _locals = locals() + _locals.pop('self', None) + _locals.pop('__class__', None) + self._configs = pickle.dumps(_locals) + + # image size + self.image_size = image_size # encoder @@ -1066,6 +1078,19 @@ def __init__( self.has_gan = use_gan and adversarial_loss_weight > 0. + @classmethod + def init_and_load_from(cls, path, strict = True): + path = Path(path) + assert path.exists() + pkg = torch.load(str(path), map_location = 'cpu') + + assert 'config' in pkg, 'model configs were not found in this saved checkpoint' + + config = pickle.loads(pkg['config']) + tokenizer = cls(**config) + tokenizer.load(path, strict = strict) + return tokenizer + def parameters(self): return [ *self.conv_in.parameters(), @@ -1104,12 +1129,13 @@ def save(self, path, overwrite = True): pkg = dict( model_state_dict = self.state_dict(), - version = __version__ + version = __version__, + config = self._configs ) torch.save(pkg, str(path)) - def load(self, path): + def load(self, path, strict = True): path = Path(path) assert path.exists() @@ -1122,7 +1148,7 @@ def load(self, path): if exists(version): print(f'loading checkpointed tokenizer from version {version}') - self.load_state_dict(state_dict) + self.load_state_dict(state_dict, strict = strict) @beartype def encode( diff --git a/magvit2_pytorch/version.py b/magvit2_pytorch/version.py index 138343c..d9edf17 100644 --- a/magvit2_pytorch/version.py +++ b/magvit2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.0.31' +__version__ = '0.0.33'