forked from Isi-dev/ComfyUI-UniAnimate-W
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
59 changed files
with
6,915 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .datasets import * | ||
from .modules import * | ||
from .inferences import * |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .image_dataset import * | ||
from .video_dataset import * |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import os | ||
import cv2 | ||
import torch | ||
import random | ||
import logging | ||
import tempfile | ||
import numpy as np | ||
from copy import copy | ||
from PIL import Image | ||
from io import BytesIO | ||
from torch.utils.data import Dataset | ||
from ...utils.registry_class import DATASETS | ||
|
||
@DATASETS.register_class() | ||
class ImageDataset(Dataset): | ||
def __init__(self, | ||
data_list, | ||
data_dir_list, | ||
max_words=1000, | ||
vit_resolution=[224, 224], | ||
resolution=(384, 256), | ||
max_frames=1, | ||
transforms=None, | ||
vit_transforms=None, | ||
**kwargs): | ||
|
||
self.max_frames = max_frames | ||
self.resolution = resolution | ||
self.transforms = transforms | ||
self.vit_resolution = vit_resolution | ||
self.vit_transforms = vit_transforms | ||
|
||
image_list = [] | ||
for item_path, data_dir in zip(data_list, data_dir_list): | ||
lines = open(item_path, 'r').readlines() | ||
lines = [[data_dir, item.strip()] for item in lines] | ||
image_list.extend(lines) | ||
self.image_list = image_list | ||
|
||
def __len__(self): | ||
return len(self.image_list) | ||
|
||
def __getitem__(self, index): | ||
data_dir, file_path = self.image_list[index] | ||
img_key = file_path.split('|||')[0] | ||
try: | ||
ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path) | ||
except Exception as e: | ||
logging.info('{} get frames failed... with error: {}'.format(img_key, e)) | ||
caption = '' | ||
img_key = '' | ||
ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0]) | ||
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) | ||
video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) | ||
return ref_frame, vit_frame, video_data, caption, img_key | ||
|
||
def _get_image_data(self, data_dir, file_path): | ||
frame_list = [] | ||
img_key, caption = file_path.split('|||') | ||
file_path = os.path.join(data_dir, img_key) | ||
for _ in range(5): | ||
try: | ||
image = Image.open(file_path) | ||
if image.mode != 'RGB': | ||
image = image.convert('RGB') | ||
frame_list.append(image) | ||
break | ||
except Exception as e: | ||
logging.info('{} read video frame failed with error: {}'.format(img_key, e)) | ||
continue | ||
|
||
video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) | ||
try: | ||
if len(frame_list) > 0: | ||
mid_frame = frame_list[0] | ||
vit_frame = self.vit_transforms(mid_frame) | ||
frame_tensor = self.transforms(frame_list) | ||
video_data[:len(frame_list), ...] = frame_tensor | ||
else: | ||
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) | ||
except: | ||
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) | ||
ref_frame = copy(video_data[0]) | ||
|
||
return ref_frame, vit_frame, video_data, caption | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import os | ||
import cv2 | ||
import json | ||
import torch | ||
import random | ||
import logging | ||
import tempfile | ||
import numpy as np | ||
from copy import copy | ||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
from ...utils.registry_class import DATASETS | ||
|
||
|
||
@DATASETS.register_class() | ||
class VideoDataset(Dataset): | ||
def __init__(self, | ||
data_list, | ||
data_dir_list, | ||
max_words=1000, | ||
resolution=(384, 256), | ||
vit_resolution=(224, 224), | ||
max_frames=16, | ||
sample_fps=8, | ||
transforms=None, | ||
vit_transforms=None, | ||
get_first_frame=False, | ||
**kwargs): | ||
|
||
self.max_words = max_words | ||
self.max_frames = max_frames | ||
self.resolution = resolution | ||
self.vit_resolution = vit_resolution | ||
self.sample_fps = sample_fps | ||
self.transforms = transforms | ||
self.vit_transforms = vit_transforms | ||
self.get_first_frame = get_first_frame | ||
|
||
image_list = [] | ||
for item_path, data_dir in zip(data_list, data_dir_list): | ||
lines = open(item_path, 'r').readlines() | ||
lines = [[data_dir, item] for item in lines] | ||
image_list.extend(lines) | ||
self.image_list = image_list | ||
|
||
|
||
def __getitem__(self, index): | ||
data_dir, file_path = self.image_list[index] | ||
video_key = file_path.split('|||')[0] | ||
try: | ||
ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path) | ||
except Exception as e: | ||
logging.info('{} get frames failed... with error: {}'.format(video_key, e)) | ||
caption = '' | ||
video_key = '' | ||
ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0]) | ||
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) | ||
video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) | ||
return ref_frame, vit_frame, video_data, caption, video_key | ||
|
||
|
||
def _get_video_data(self, data_dir, file_path): | ||
video_key, caption = file_path.split('|||') | ||
file_path = os.path.join(data_dir, video_key) | ||
|
||
for _ in range(5): | ||
try: | ||
capture = cv2.VideoCapture(file_path) | ||
_fps = capture.get(cv2.CAP_PROP_FPS) | ||
_total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT) | ||
stride = round(_fps / self.sample_fps) | ||
cover_frame_num = (stride * self.max_frames) | ||
if _total_frame_num < cover_frame_num + 5: | ||
start_frame = 0 | ||
end_frame = _total_frame_num | ||
else: | ||
start_frame = random.randint(0, _total_frame_num-cover_frame_num-5) | ||
end_frame = start_frame + cover_frame_num | ||
|
||
pointer, frame_list = 0, [] | ||
while(True): | ||
ret, frame = capture.read() | ||
pointer +=1 | ||
if (not ret) or (frame is None): break | ||
if pointer < start_frame: continue | ||
if pointer >= end_frame - 1: break | ||
if (pointer - start_frame) % stride == 0: | ||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | ||
frame = Image.fromarray(frame) | ||
frame_list.append(frame) | ||
break | ||
except Exception as e: | ||
logging.info('{} read video frame failed with error: {}'.format(video_key, e)) | ||
continue | ||
|
||
video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) | ||
if self.get_first_frame: | ||
ref_idx = 0 | ||
else: | ||
ref_idx = int(len(frame_list)/2) | ||
try: | ||
if len(frame_list)>0: | ||
mid_frame = copy(frame_list[ref_idx]) | ||
vit_frame = self.vit_transforms(mid_frame) | ||
frames = self.transforms(frame_list) | ||
video_data[:len(frame_list), ...] = frames | ||
else: | ||
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) | ||
except: | ||
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) | ||
ref_frame = copy(frames[ref_idx]) | ||
|
||
return ref_frame, vit_frame, video_data, caption | ||
|
||
def __len__(self): | ||
return len(self.image_list) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .inference_unianimate_entrance import * | ||
from .inference_unianimate_long_entrance import * |
Binary file not shown.
Binary file not shown.
Binary file added
BIN
+12.6 KB
tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc
Binary file not shown.
Binary file added
BIN
+12.2 KB
tools/inferences/__pycache__/inference_unianimate_entrance.cpython-39.pyc
Binary file not shown.
Binary file added
BIN
+12.7 KB
tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-310.pyc
Binary file not shown.
Binary file added
BIN
+12.8 KB
tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-39.pyc
Binary file not shown.
Oops, something went wrong.