From 0fd20380abb8725e6877650847b796b4abcfddeb Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 26 Oct 2023 08:51:28 -0700 Subject: [PATCH] apply transform to each frame of video --- magvit2_pytorch/data.py | 11 ++++++----- magvit2_pytorch/version.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/magvit2_pytorch/data.py b/magvit2_pytorch/data.py index edbd82e..4a3bdad 100644 --- a/magvit2_pytorch/data.py +++ b/magvit2_pytorch/data.py @@ -125,7 +125,7 @@ def video_tensor_to_gif( # gif -> (channels, frame, height, width) tensor def gif_to_tensor( - path, + path: str, channels = 3, transform = T.ToTensor() ): @@ -228,8 +228,7 @@ def __init__( self.transform = T.Compose([ T.Resize(image_size), T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), - T.CenterCrop(image_size), - T.ToTensor() + T.CenterCrop(image_size) ]) # functions to transform video path to tensor @@ -245,14 +244,16 @@ def __len__(self): def __getitem__(self, index): path = self.paths[index] ext = path.suffix + path_str = str(path) if ext == '.gif': - tensor = self.gif_to_tensor(path) + tensor = self.gif_to_tensor(path_str) elif ext == '.mp4': - tensor = self.mp4_to_tensor(str(path)) + tensor = self.mp4_to_tensor(path_str) else: raise ValueError(f'unknown extension {ext}') + tensor = self.transform(tensor) return self.cast_num_frames_fn(tensor) # override dataloader to be able to collate strings diff --git a/magvit2_pytorch/version.py b/magvit2_pytorch/version.py index 9ae29b1..2cf0f75 100644 --- a/magvit2_pytorch/version.py +++ b/magvit2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.0.34' +__version__ = '0.0.35'