From edbddc2c25e3501eb9e5396d1c9d9655592eaa03 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 26 Oct 2023 08:19:30 -0700
Subject: [PATCH] bring in dataset and dataloader

---
 magvit2_pytorch/data.py    | 280 +++++++++++++++++++++++++++++++++++++
 magvit2_pytorch/version.py |   2 +-
 setup.py                   |   3 +
 3 files changed, 284 insertions(+), 1 deletion(-)
 create mode 100644 magvit2_pytorch/data.py

diff --git a/magvit2_pytorch/data.py b/magvit2_pytorch/data.py
new file mode 100644
index 0000000..edbd82e
--- /dev/null
+++ b/magvit2_pytorch/data.py
@@ -0,0 +1,280 @@
from pathlib import Path
from functools import partial

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader as PytorchDataLoader

import cv2
from PIL import Image
from torchvision import transforms as T, utils

from beartype.typing import Tuple, List
from beartype.door import is_bearable

import numpy as np

from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def identity(t, *args, **kwargs):
    return t

def pair(val):
    return val if isinstance(val, tuple) else (val, val)

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

def cast_num_frames(t, *, frames):
    f = t.shape[-3]

    if f == frames:
        return t

    if f > frames:
        return t[..., :frames, :, :]

    return pad_at_dim(t, (0, frames - f), dim = -3)

def convert_image_to_fn(img_type, image):
    if not exists(img_type) or image.mode == img_type:
        return image

    return image.convert(img_type)

# image related helper functions and dataset

class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        convert_image_to = None,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        folder = Path(folder)
        assert folder.is_dir(), f'{str(folder)} must be a folder containing images'
        self.folder = folder

        self.image_size = image_size
        self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(partial(convert_image_to_fn, convert_image_to)),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# handle reading and writing gif

def seek_all_images(img: Image.Image, channels = 3):
    mode = {
        1 : 'L',
        3 : 'RGB',
        4 : 'RGBA'
    }.get(channels)

    assert exists(mode), f'channels {channels} invalid'

    i = 0
    while True:
        try:
            img.seek(i)
            yield img.convert(mode)
        except EOFError:
            break
        i += 1

# tensor of shape (channels, frames, height, width) -> gif

def video_tensor_to_gif(
    tensor: Tensor,
    path: str,
    duration = 120,
    loop = 0,
    optimize = True
):
    # materialize the frames up front, so the returned images are not an exhausted iterator
    images = tuple(map(T.ToPILImage(), tensor.unbind(dim = 1)))
    first_img, *rest_imgs = images
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = duration, loop = loop, optimize = optimize)
    return images

# gif -> (channels, frames, height, width) tensor

def gif_to_tensor(
    path,
    channels = 3,
    transform = T.ToTensor()
):
    img = Image.open(path)
    tensors = tuple(map(transform, seek_all_images(img, channels = channels)))
    return torch.stack(tensors, dim = 1)
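
# a hypothetical round trip through the gif helpers above (file names are
# illustrative only):
#
#   video = gif_to_tensor('./clip.gif')            # (3, frames, height, width), values in [0., 1.]
#   video_tensor_to_gif(video, './clip-copy.gif')  # writes the frames back out as a gif
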
# handle reading and writing mp4

def video_to_tensor(
    path: str,              # Path of the video to be imported
    num_frames = -1,        # Number of frames to be stored in the output tensor, -1 keeps all frames
    crop_size = None
) -> Tensor:                # shape (channels, frames, height, width)

    video = cv2.VideoCapture(path)

    frames = []
    check = True

    while check:
        check, frame = video.read()

        if not check:
            continue

        if exists(crop_size):
            frame = crop_center(frame, *pair(crop_size))

        frames.append(rearrange(frame, '... -> 1 ...'))

    frames = np.array(np.concatenate(frames, axis = 0)) # convert list of frames to numpy array
    frames = rearrange(frames, 'f h w c -> c f h w')

    frames_torch = torch.tensor(frames).float()

    return frames_torch[:, :num_frames, :, :] if num_frames > 0 else frames_torch

def tensor_to_video(
    tensor: Tensor,         # Pytorch video tensor of shape (channels, frames, height, width)
    path: str,              # Path of the video to be saved
    fps = 25,               # Frames per second for the saved video
    video_format = 'MP4V'
):
    # Write the tensor out frame by frame.
    tensor = tensor.cpu()

    num_frames, height, width = tensor.shape[-3:]

    fourcc = cv2.VideoWriter_fourcc(*video_format) # Changes in this line can allow for different video formats.
    video = cv2.VideoWriter(path, fourcc, fps, (width, height))

    for idx in range(num_frames):
        numpy_frame = tensor[:, idx, :, :].numpy()
        numpy_frame = np.uint8(rearrange(numpy_frame, 'c h w -> h w c')) # expects pixel values in [0, 255]
        video.write(numpy_frame)

    video.release()

    cv2.destroyAllWindows()

    return video

def crop_center(
    img: np.ndarray,        # Image array as returned by cv2, shape (height, width, channels)
    cropx: int,             # Length of the final image in the x direction.
    cropy: int              # Length of the final image in the y direction.
) -> np.ndarray:
    y, x, c = img.shape
    startx = x // 2 - cropx // 2
    starty = y // 2 - cropy // 2
    return img[starty:(starty + cropy), startx:(startx + cropx), :]

# video dataset

class VideoDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        channels = 3,
        num_frames = 17,
        horizontal_flip = False,
        force_num_frames = True,
        exts = ['gif', 'mp4']
    ):
        super().__init__()
        folder = Path(folder)
        assert folder.is_dir(), f'{str(folder)} must be a folder containing videos'
        self.folder = folder

        self.image_size = image_size
        self.channels = channels
        self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Resize(image_size),
            T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

        # functions to transform video path to tensor

        self.gif_to_tensor = partial(gif_to_tensor, channels = self.channels, transform = self.transform)
        self.mp4_to_tensor = partial(video_to_tensor, crop_size = self.image_size)

        self.cast_num_frames_fn = partial(cast_num_frames, frames = num_frames) if force_num_frames else identity

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        ext = path.suffix

        if ext == '.gif':
            tensor = self.gif_to_tensor(path)
        elif ext == '.mp4':
            tensor = self.mp4_to_tensor(str(path))
        else:
            raise ValueError(f'unknown extension {ext}')

        return self.cast_num_frames_fn(tensor)
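
# a hypothetical sanity check of the dataset above (folder path and sizes are
# illustrative): a 30 frame mp4 would be truncated and a 10 frame gif padded,
# so every sample comes out with exactly num_frames frames
#
#   dataset = VideoDataset('./videos', image_size = 128, num_frames = 17)
#   dataset[0].shape  # -> torch.Size([3, 17, 128, 128])
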
# override dataloader to be able to collate strings

def collate_tensors_and_strings(data):
    if is_bearable(data, List[Tensor]):
        return (torch.stack(data),)

    data = zip(*data)
    output = []

    for datum in data:
        if is_bearable(datum, Tuple[Tensor, ...]):
            datum = torch.stack(datum)
        elif is_bearable(datum, Tuple[str, ...]):
            datum = list(datum)
        else:
            raise ValueError('detected invalid type being passed from dataset')

        output.append(datum)

    return tuple(output)

def DataLoader(*args, **kwargs):
    return PytorchDataLoader(*args, collate_fn = collate_tensors_and_strings, **kwargs)
diff --git a/magvit2_pytorch/version.py b/magvit2_pytorch/version.py
index d9edf17..9ae29b1 100644
--- a/magvit2_pytorch/version.py
+++ b/magvit2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.0.33'
+__version__ = '0.0.34'
diff --git a/setup.py b/setup.py
index e98b937..424dd69 100644
--- a/setup.py
+++ b/setup.py
@@ -25,6 +25,9 @@
     'einops>=0.7.0',
     'ema-pytorch',
     'kornia',
+    'opencv-python',
+    'pillow',
+    'numpy',
     'vector-quantize-pytorch>=1.9.18',
     'torch',
     'torchvision'
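
A minimal usage sketch for the new VideoDataset and collate-aware DataLoader,
assuming a local ./videos folder of gif or mp4 clips (the folder path, batch
size, and hyperparameters below are illustrative, not part of the patch):

    from magvit2_pytorch.data import VideoDataset, DataLoader

    dataset = VideoDataset(
        './videos',          # scanned recursively for .gif and .mp4 files
        image_size = 128,    # gifs are resized + center cropped, mp4s center cropped, to 128 x 128
        num_frames = 17      # clips are truncated or padded to 17 frames
    )

    dataloader = DataLoader(dataset, batch_size = 4, shuffle = True)

    (videos,) = next(iter(dataloader))  # collate_tensors_and_strings returns a tuple
    print(videos.shape)                 # torch.Size([4, 3, 17, 128, 128])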