Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Isi-dev authored Aug 1, 2024
1 parent 34e2d36 commit 00a5b59
Show file tree
Hide file tree
Showing 59 changed files with 6,915 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .datasets import *
from .modules import *
from .inferences import *
Binary file added tools/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added tools/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
2 changes: 2 additions & 0 deletions tools/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .image_dataset import *
from .video_dataset import *
Binary file added tools/datasets/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added tools/datasets/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
86 changes: 86 additions & 0 deletions tools/datasets/image_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
import cv2
import torch
import random
import logging
import tempfile
import numpy as np
from copy import copy
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset
from ...utils.registry_class import DATASETS

@DATASETS.register_class()
class ImageDataset(Dataset):
    """Single-image dataset that mimics the video-dataset sample layout.

    Each list file referenced by ``data_list`` contains one sample per line in
    the form ``<relative_image_path>|||<caption>``; the matching entry of
    ``data_dir_list`` is the root directory the relative path is joined to.

    ``__getitem__`` returns ``(ref_frame, vit_frame, video_data, caption,
    img_key)`` where ``video_data`` is a zero-padded ``(max_frames, 3, H, W)``
    tensor holding the (single) transformed image.
    """

    def __init__(self,
            data_list,
            data_dir_list,
            max_words=1000,
            vit_resolution=[224, 224],
            resolution=(384, 256),
            max_frames=1,
            transforms=None,
            vit_transforms=None,
            **kwargs):
        # NOTE: resolutions are (width, height); tensors below are built as
        # (C, H, W), hence the [1]/[0] index swap when allocating.
        self.max_frames = max_frames
        self.resolution = resolution
        self.transforms = transforms
        self.vit_resolution = vit_resolution
        self.vit_transforms = vit_transforms

        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            # Close the list file deterministically (the original leaked the
            # handle by calling open(...).readlines()).
            with open(item_path, 'r') as list_file:
                lines = [[data_dir, item.strip()] for item in list_file]
            image_list.extend(lines)
        self.image_list = image_list

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Load one sample; on any failure return all-zero tensors and
        empty caption/key so a bad file never aborts a training epoch."""
        data_dir, file_path = self.image_list[index]
        img_key = file_path.split('|||')[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path)
        except Exception as e:
            logging.info('{} get frames failed... with error: {}'.format(img_key, e))
            caption = ''
            img_key = ''
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        return ref_frame, vit_frame, video_data, caption, img_key

    def _get_image_data(self, data_dir, file_path):
        """Read the image, apply both transform pipelines, and pack the
        result into a zero-padded ``(max_frames, 3, H, W)`` tensor."""
        frame_list = []
        img_key, caption = file_path.split('|||')
        file_path = os.path.join(data_dir, img_key)
        # Retry a few times: transient read errors (e.g. on network storage)
        # are logged and retried rather than failing the sample outright.
        for _ in range(5):
            try:
                image = Image.open(file_path)
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                frame_list.append(image)
                break
            except Exception as e:
                logging.info('{} read video frame failed with error: {}'.format(img_key, e))
                continue

        video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        try:
            if len(frame_list) > 0:
                mid_frame = frame_list[0]
                vit_frame = self.vit_transforms(mid_frame)
                frame_tensor = self.transforms(frame_list)
                video_data[:len(frame_list), ...] = frame_tensor
            else:
                vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        except Exception as e:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # propagate; transform failures degrade to a zero ViT frame.
            logging.info('{} apply transforms failed with error: {}'.format(img_key, e))
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        ref_frame = copy(video_data[0])

        return ref_frame, vit_frame, video_data, caption

118 changes: 118 additions & 0 deletions tools/datasets/video_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import os
import cv2
import json
import torch
import random
import logging
import tempfile
import numpy as np
from copy import copy
from PIL import Image
from torch.utils.data import Dataset
from ...utils.registry_class import DATASETS


@DATASETS.register_class()
class VideoDataset(Dataset):
    """Dataset that samples a clip of ``max_frames`` frames from each video.

    Each list file referenced by ``data_list`` contains one sample per line in
    the form ``<relative_video_path>|||<caption>``; the matching entry of
    ``data_dir_list`` is the root directory the relative path is joined to.

    ``__getitem__`` returns ``(ref_frame, vit_frame, video_data, caption,
    video_key)`` where ``video_data`` is a zero-padded ``(max_frames, 3, H, W)``
    tensor of frames sampled at roughly ``sample_fps``.
    """

    def __init__(self,
            data_list,
            data_dir_list,
            max_words=1000,
            resolution=(384, 256),
            vit_resolution=(224, 224),
            max_frames=16,
            sample_fps=8,
            transforms=None,
            vit_transforms=None,
            get_first_frame=False,
            **kwargs):
        # NOTE: resolutions are (width, height); tensors below are built as
        # (C, H, W), hence the [1]/[0] index swap when allocating.
        self.max_words = max_words
        self.max_frames = max_frames
        self.resolution = resolution
        self.vit_resolution = vit_resolution
        self.sample_fps = sample_fps
        self.transforms = transforms
        self.vit_transforms = vit_transforms
        self.get_first_frame = get_first_frame

        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            # Close the list file deterministically (the original leaked the
            # handle) and strip the trailing newline so it does not end up in
            # the caption (consistent with ImageDataset).
            with open(item_path, 'r') as list_file:
                lines = [[data_dir, item.strip()] for item in list_file]
            image_list.extend(lines)
        self.image_list = image_list

    def __getitem__(self, index):
        """Load one sample; on any failure return all-zero tensors and
        empty caption/key so a bad file never aborts a training epoch."""
        data_dir, file_path = self.image_list[index]
        video_key = file_path.split('|||')[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path)
        except Exception as e:
            logging.info('{} get frames failed... with error: {}'.format(video_key, e))
            caption = ''
            video_key = ''
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        return ref_frame, vit_frame, video_data, caption, video_key

    def _get_video_data(self, data_dir, file_path):
        """Decode a random window of the video, sub-sample it to roughly
        ``sample_fps``, and pack the frames into a zero-padded tensor."""
        video_key, caption = file_path.split('|||')
        file_path = os.path.join(data_dir, video_key)

        frame_list = []
        # Retry a few times: transient decode errors are logged and retried
        # rather than failing the sample outright.
        for _ in range(5):
            capture = None
            try:
                capture = cv2.VideoCapture(file_path)
                _fps = capture.get(cv2.CAP_PROP_FPS)
                # cv2 properties are floats; cast so random.randint and the
                # frame-index arithmetic below get real ints.
                _total_frame_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
                # Clamp to >= 1: for videos slower than sample_fps the rounded
                # stride would be 0 and the modulo below would divide by zero.
                stride = max(1, round(_fps / self.sample_fps))
                cover_frame_num = stride * self.max_frames
                if _total_frame_num < cover_frame_num + 5:
                    start_frame = 0
                    end_frame = _total_frame_num
                else:
                    # Random temporal crop with a small safety margin.
                    start_frame = random.randint(0, _total_frame_num - cover_frame_num - 5)
                    end_frame = start_frame + cover_frame_num

                pointer = 0
                while True:
                    ret, frame = capture.read()
                    pointer += 1
                    if (not ret) or (frame is None):
                        break
                    if pointer < start_frame:
                        continue
                    if pointer >= end_frame - 1:
                        break
                    if (pointer - start_frame) % stride == 0:
                        # OpenCV decodes BGR; convert for PIL/transforms.
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frame_list.append(Image.fromarray(frame))
                break
            except Exception as e:
                frame_list = []
                logging.info('{} read video frame failed with error: {}'.format(video_key, e))
                continue
            finally:
                # Always release the decoder handle; the original leaked one
                # OS video handle per sample.
                if capture is not None:
                    capture.release()

        video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        if self.get_first_frame:
            ref_idx = 0
        else:
            ref_idx = int(len(frame_list) / 2)
        # Keep the reference index inside the padded tensor even if the video
        # yielded more frames than expected.
        ref_idx = min(ref_idx, self.max_frames - 1)
        try:
            if len(frame_list) > 0:
                mid_frame = copy(frame_list[ref_idx])
                vit_frame = self.vit_transforms(mid_frame)
                frames = self.transforms(frame_list)
                video_data[:len(frame_list), ...] = frames
            else:
                vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        except Exception as e:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # propagate; transform failures degrade to a zero ViT frame.
            logging.info('{} apply transforms failed with error: {}'.format(video_key, e))
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        # BUGFIX: the original indexed ``frames[ref_idx]``, which raises
        # NameError whenever decoding or the transforms failed (``frames``
        # unbound). Index the always-defined padded tensor instead, matching
        # ImageDataset; the values are identical when the transforms succeed.
        ref_frame = copy(video_data[ref_idx])

        return ref_frame, vit_frame, video_data, caption

    def __len__(self):
        return len(self.image_list)


2 changes: 2 additions & 0 deletions tools/inferences/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .inference_unianimate_entrance import *
from .inference_unianimate_long_entrance import *
Binary file not shown.
Binary file added tools/inferences/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 00a5b59

Please sign in to comment.