Skip to content

Commit

Permalink
bring in dataset and dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 26, 2023
1 parent 6dfc98c commit edbddc2
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 1 deletion.
280 changes: 280 additions & 0 deletions magvit2_pytorch/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
from pathlib import Path
from functools import partial

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader as PytorchDataLoader

import cv2
from PIL import Image
from torchvision import transforms as T, utils

from beartype.typing import Tuple, List
from beartype.door import is_bearable

import numpy as np

from einops import rearrange

# helper functions

def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)

def identity(t, *args, **kwargs):
    """Pass `t` through unchanged; extra positional/keyword arguments are ignored."""
    return t

def pair(val):
    """Promote a single value to a 2-tuple; tuples pass through untouched."""
    if isinstance(val, tuple):
        return val
    return (val, val)

def pad_at_dim(t, pad, dim = -1, value = 0.):
    """Pad tensor `t` by `pad = (before, after)` along dimension `dim`.

    F.pad expects padding amounts listed from the last dimension inward,
    so (0, 0) pairs are prepended for every dimension to the right of `dim`.
    """
    dims_to_the_right = t.ndim - dim - 1 if dim >= 0 else -dim - 1
    padding = (0, 0) * dims_to_the_right + tuple(pad)
    return F.pad(t, padding, value = value)

def cast_num_frames(t, *, frames):
    """Force the frame dimension (dim -3) of `t` to exactly `frames`.

    Longer videos are truncated at the end; shorter ones are zero-padded
    at the end. The input is returned untouched when already the right length.
    """
    current = t.shape[-3]

    if current == frames:
        return t
    elif current > frames:
        return t[..., :frames, :, :]
    else:
        return pad_at_dim(t, (0, frames - current), dim = -3)

def convert_image_to_fn(img_type, image):
    """Convert a PIL image to mode `img_type`, unless no target mode was
    given (None) or the image is already in that mode."""
    if img_type is None or image.mode == img_type:
        return image

    return image.convert(img_type)

# image related helper functions and dataset

class ImageDataset(Dataset):
    """Recursively collects image files under `folder` and yields image tensors.

    Parameters:
        folder: root directory, searched recursively
        image_size: target side length for resize + center crop
        convert_image_to: optional PIL mode (e.g. 'RGB') to coerce every image to
        exts: file extensions to pick up
    """

    def __init__(
        self,
        folder,
        image_size,
        convert_image_to = None,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        folder = Path(folder)
        # fix: message previously said "videos" for the image dataset
        assert folder.is_dir(), f'{str(folder)} must be a folder containing images'
        self.folder = folder

        self.image_size = image_size
        self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(partial(convert_image_to_fn, convert_image_to)),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # fix: use a context manager so the file handle is closed promptly —
        # PIL opens lazily and otherwise keeps the descriptor alive until gc,
        # which can exhaust file descriptors over a long training run
        with Image.open(path) as img:
            return self.transform(img)

# handle reading and writing gif

def seek_all_images(img: 'Image.Image', channels = 3):
    """Yield every frame of a (possibly animated) PIL image, converted to the
    mode matching `channels` (1 -> 'L', 3 -> 'RGB', 4 -> 'RGBA').

    Args:
        img: an open PIL image (fix: was wrongly annotated as `Tensor`)
        channels: number of output channels per frame
    Raises:
        AssertionError: if `channels` is not 1, 3 or 4 (on first iteration,
        since this is a generator)
    """
    mode = {
        1 : 'L',
        3 : 'RGB',
        4 : 'RGBA'
    }.get(channels)

    # fail fast on unsupported channel counts
    assert mode is not None, f'channels {channels} invalid'

    frame_index = 0
    while True:
        try:
            # seek raises EOFError once past the last frame
            img.seek(frame_index)
            yield img.convert(mode)
        except EOFError:
            break
        frame_index += 1

# tensor of shape (channels, frames, height, width) -> gif

def video_tensor_to_gif(
    tensor: Tensor, # video tensor of shape (channels, frames, height, width); NOTE(review): ToPILImage expects float [0, 1] or uint8 — confirm with callers
    path: str,      # output path for the gif
    duration = 120, # per-frame duration in milliseconds
    loop = 0,       # 0 = loop forever
    optimize = True
):
    """Save a (channels, frames, height, width) video tensor as an animated gif.

    Returns the list of PIL frames. Fix: the previous version returned the
    `map` iterator after it had already been exhausted by the unpacking,
    so callers always received an empty iterator.
    """
    images = list(map(T.ToPILImage(), tensor.unbind(dim = 1)))
    first_img, *rest_imgs = images
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = duration, loop = loop, optimize = optimize)
    return images

# gif -> (channels, frame, height, width) tensor

def gif_to_tensor(
    path,
    channels = 3,
    transform = T.ToTensor() # NOTE: evaluated once at import time; ToTensor is stateless, so sharing it is safe
):
    """Load an animated gif at `path` into a tensor of shape
    (channels, frames, height, width)."""
    # fix: context manager closes the file handle promptly (PIL opens lazily
    # and otherwise keeps the descriptor open); all frames are materialized
    # into `tensors` before the file is closed
    with Image.open(path) as img:
        tensors = tuple(map(transform, seek_all_images(img, channels = channels)))
    return torch.stack(tensors, dim = 1)

# handle reading and writing mp4

def video_to_tensor(
    path: str,        # Path of the video to be imported
    num_frames = -1,  # Number of frames kept in the output tensor; <= 0 keeps all frames
    crop_size = None  # optional int or (x, y) center-crop applied to every frame
) -> Tensor:          # shape (channels, frames, height, width) — fix: old comment claimed a leading batch dim that was never produced
    """Decode a video file with OpenCV into a float tensor of shape
    (channels, frames, height, width).

    Fixes over the previous version:
    - the last decoded frame is no longer dropped (the old code concatenated
      `frames[:-1]`, and with the default num_frames = -1 the final slice
      `[:, :-1]` dropped yet another frame)
    - the VideoCapture handle is released
    """
    video = cv2.VideoCapture(path)

    frames = []

    while True:
        check, frame = video.read()

        # read() returns check = False once the stream is exhausted
        if not check:
            break

        if exists(crop_size):
            frame = crop_center(frame, *pair(crop_size))

        # add a leading frame axis so frames can be concatenated below
        frames.append(rearrange(frame, '... -> 1 ...'))

    video.release()

    frames = np.concatenate(frames, axis = 0) # (frames, height, width, channels)
    frames = rearrange(frames, 'f h w c -> c f h w')

    frames_torch = torch.tensor(frames).float()

    # only truncate when a positive frame count was requested
    if num_frames > 0:
        frames_torch = frames_torch[:, :num_frames]

    return frames_torch

def tensor_to_video(
    tensor: Tensor,        # Pytorch video tensor, shape (channels, frames, height, width); values cast with np.uint8, so expected in [0, 255] — TODO confirm with callers
    path: str,             # Path of the video to be saved
    fps = 25,              # Frames per second for the saved video
    video_format = 'MP4V'  # fourcc code; change for other containers/codecs
):
    """Write a (channels, frames, height, width) tensor to a video file
    with OpenCV, one frame at a time.

    Returns the (released) cv2.VideoWriter.
    NOTE(review): OpenCV expects BGR channel order; frames are written as-is,
    so confirm callers supply BGR (or accept swapped channels).
    """
    tensor = tensor.cpu()

    num_frames, height, width = tensor.shape[-3:]

    fourcc = cv2.VideoWriter_fourcc(*video_format)
    video = cv2.VideoWriter(path, fourcc, fps, (width, height))

    # fix: removed an unused `frames = []` local from the original
    for idx in range(num_frames):
        # (channels, height, width) -> (height, width, channels) uint8 frame
        numpy_frame = tensor[:, idx, :, :].numpy()
        numpy_frame = np.uint8(rearrange(numpy_frame, 'c h w -> h w c'))
        video.write(numpy_frame)

    video.release()

    cv2.destroyAllWindows()

    return video

def crop_center(
    img: Tensor,
    cropx: int, # Length of the final image in the x direction.
    cropy: int  # Length of the final image in the y direction.
) -> Tensor:
    """Crop a (height, width, channels) array to (cropy, cropx, channels),
    centered on the middle of the image."""
    height, width, _ = img.shape
    left = width // 2 - cropx // 2
    top = height // 2 - cropy // 2
    return img[top : (top + cropy), left : (left + cropx), :]

# video dataset

class VideoDataset(Dataset):
    """Dataset over gif / mp4 files found recursively under `folder`.

    Each item is a video tensor of shape (channels, frames, height, width),
    optionally forced to exactly `num_frames` frames.
    """

    def __init__(
        self,
        folder,
        image_size,
        channels = 3,
        num_frames = 17,
        horizontal_flip = False,
        force_num_frames = True,
        exts = ['gif', 'mp4']
    ):
        super().__init__()
        folder = Path(folder)
        assert folder.is_dir(), f'{str(folder)} must be a folder containing videos'
        self.folder = folder

        self.image_size = image_size
        self.channels = channels

        found = []
        for ext in exts:
            found.extend(folder.glob(f'**/*.{ext}'))
        self.paths = found

        print(f'{len(self.paths)} training samples found at {folder}')

        maybe_flip = T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity)

        self.transform = T.Compose([
            T.Resize(image_size),
            maybe_flip,
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

        # per-extension loaders turning a path into a video tensor

        self.gif_to_tensor = partial(gif_to_tensor, channels = self.channels, transform = self.transform)
        self.mp4_to_tensor = partial(video_to_tensor, crop_size = self.image_size)

        self.cast_num_frames_fn = partial(cast_num_frames, frames = num_frames) if force_num_frames else identity

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        video_path = self.paths[index]
        suffix = video_path.suffix

        if suffix == '.gif':
            video = self.gif_to_tensor(video_path)
        elif suffix == '.mp4':
            video = self.mp4_to_tensor(str(video_path))
        else:
            raise ValueError(f'unknown extension {suffix}')

        return self.cast_num_frames_fn(video)

# override dataloader to be able to collate strings

def collate_tensors_and_strings(data):
    """Collate a batch where each dataset item is either a bare tensor or a
    tuple mixing tensors and strings.

    Tensor fields are stacked; string fields become lists. A batch of bare
    tensors collates to a 1-tuple holding the stacked tensor.
    """
    if is_bearable(data, List[Tensor]):
        return (torch.stack(data),)

    collated = []

    # transpose: list of items -> per-field tuples
    for field in zip(*data):
        if is_bearable(field, Tuple[Tensor, ...]):
            collated.append(torch.stack(field))
        elif is_bearable(field, Tuple[str, ...]):
            collated.append(list(field))
        else:
            raise ValueError('detected invalid type being passed from dataset')

    return tuple(collated)

def DataLoader(*args, **kwargs):
    """Drop-in replacement for torch's DataLoader that always uses the
    string-aware collate function above."""
    return PytorchDataLoader(
        *args,
        collate_fn = collate_tensors_and_strings,
        **kwargs
    )
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.33'
__version__ = '0.0.34'
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
'einops>=0.7.0',
'ema-pytorch',
'kornia',
'opencv-python',
'pillow',
'numpy',
'vector-quantize-pytorch>=1.9.18',
'torch',
'torchvision'
Expand Down

0 comments on commit edbddc2

Please sign in to comment.