Skip to content

Commit

Permalink
apply transform to each frame of video
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 26, 2023
1 parent edbddc2 commit 0fd2038
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions magvit2_pytorch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def video_tensor_to_gif(
# gif -> (channels, frame, height, width) tensor

def gif_to_tensor(
path,
path: str,
channels = 3,
transform = T.ToTensor()
):
Expand Down Expand Up @@ -228,8 +228,7 @@ def __init__(
self.transform = T.Compose([
T.Resize(image_size),
T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity),
T.CenterCrop(image_size),
T.ToTensor()
T.CenterCrop(image_size)
])

# functions to transform video path to tensor
Expand All @@ -245,14 +244,16 @@ def __len__(self):
def __getitem__(self, index):
path = self.paths[index]
ext = path.suffix
path_str = str(path)

if ext == '.gif':
tensor = self.gif_to_tensor(path)
tensor = self.gif_to_tensor(path_str)
elif ext == '.mp4':
tensor = self.mp4_to_tensor(str(path))
tensor = self.mp4_to_tensor(path_str)
else:
raise ValueError(f'unknown extension {ext}')

tensor = self.transform(tensor)
return self.cast_num_frames_fn(tensor)

# override dataloader to be able to collate strings
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.34'
__version__ = '0.0.35'

0 comments on commit 0fd2038

Please sign in to comment.