diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..33ef13c --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,3 @@ +from .datasets import * +from .modules import * +from .inferences import * diff --git a/tools/__pycache__/__init__.cpython-310.pyc b/tools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..74bb827 Binary files /dev/null and b/tools/__pycache__/__init__.cpython-310.pyc differ diff --git a/tools/__pycache__/__init__.cpython-39.pyc b/tools/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..5f55644 Binary files /dev/null and b/tools/__pycache__/__init__.cpython-39.pyc differ diff --git a/tools/datasets/__init__.py b/tools/datasets/__init__.py new file mode 100644 index 0000000..f1b217f --- /dev/null +++ b/tools/datasets/__init__.py @@ -0,0 +1,2 @@ +from .image_dataset import * +from .video_dataset import * diff --git a/tools/datasets/__pycache__/__init__.cpython-310.pyc b/tools/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..be3d985 Binary files /dev/null and b/tools/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/tools/datasets/__pycache__/__init__.cpython-39.pyc b/tools/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..b6726b1 Binary files /dev/null and b/tools/datasets/__pycache__/__init__.cpython-39.pyc differ diff --git a/tools/datasets/__pycache__/image_dataset.cpython-310.pyc b/tools/datasets/__pycache__/image_dataset.cpython-310.pyc new file mode 100644 index 0000000..77f0b69 Binary files /dev/null and b/tools/datasets/__pycache__/image_dataset.cpython-310.pyc differ diff --git a/tools/datasets/__pycache__/image_dataset.cpython-39.pyc b/tools/datasets/__pycache__/image_dataset.cpython-39.pyc new file mode 100644 index 0000000..e7ce1cb Binary files /dev/null and b/tools/datasets/__pycache__/image_dataset.cpython-39.pyc differ diff --git a/tools/datasets/__pycache__/video_dataset.cpython-310.pyc b/tools/datasets/__pycache__/video_dataset.cpython-310.pyc new file mode 100644 index 0000000..f8c0008 Binary files /dev/null and b/tools/datasets/__pycache__/video_dataset.cpython-310.pyc differ diff --git a/tools/datasets/__pycache__/video_dataset.cpython-39.pyc b/tools/datasets/__pycache__/video_dataset.cpython-39.pyc new file mode 100644 index 0000000..ceb6e95 Binary files /dev/null and b/tools/datasets/__pycache__/video_dataset.cpython-39.pyc differ diff --git a/tools/datasets/image_dataset.py b/tools/datasets/image_dataset.py new file mode 100644 index 0000000..d958e48 --- /dev/null +++ b/tools/datasets/image_dataset.py @@ -0,0 +1,86 @@ +import os +import cv2 +import torch +import random +import logging +import tempfile +import numpy as np +from copy import copy +from PIL import Image +from io import BytesIO +from torch.utils.data import Dataset +from ...utils.registry_class import DATASETS + +@DATASETS.register_class() +class ImageDataset(Dataset): + def __init__(self, + data_list, + data_dir_list, + max_words=1000, + vit_resolution=[224, 224], + resolution=(384, 256), + max_frames=1, + transforms=None, + vit_transforms=None, + **kwargs): + + self.max_frames = max_frames + self.resolution = resolution + self.transforms = transforms + self.vit_resolution = vit_resolution + self.vit_transforms = vit_transforms + + image_list = [] + for item_path, data_dir in zip(data_list, data_dir_list): + lines = open(item_path, 'r').readlines() + lines = [[data_dir, item.strip()] for item in lines] + image_list.extend(lines) + 
self.image_list = image_list + + def __len__(self): + return len(self.image_list) + + def __getitem__(self, index): + data_dir, file_path = self.image_list[index] + img_key = file_path.split('|||')[0] + try: + ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path) + except Exception as e: + logging.info('{} get frames failed... with error: {}'.format(img_key, e)) + caption = '' + img_key = '' + ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0]) + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) + return ref_frame, vit_frame, video_data, caption, img_key + + def _get_image_data(self, data_dir, file_path): + frame_list = [] + img_key, caption = file_path.split('|||') + file_path = os.path.join(data_dir, img_key) + for _ in range(5): + try: + image = Image.open(file_path) + if image.mode != 'RGB': + image = image.convert('RGB') + frame_list.append(image) + break + except Exception as e: + logging.info('{} read video frame failed with error: {}'.format(img_key, e)) + continue + + video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) + try: + if len(frame_list) > 0: + mid_frame = frame_list[0] + vit_frame = self.vit_transforms(mid_frame) + frame_tensor = self.transforms(frame_list) + video_data[:len(frame_list), ...] = frame_tensor + else: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + except: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + ref_frame = copy(video_data[0]) + + return ref_frame, vit_frame, video_data, caption + diff --git a/tools/datasets/video_dataset.py b/tools/datasets/video_dataset.py new file mode 100644 index 0000000..cdc45de --- /dev/null +++ b/tools/datasets/video_dataset.py @@ -0,0 +1,118 @@ +import os +import cv2 +import json +import torch +import random +import logging +import tempfile +import numpy as np +from copy import copy +from PIL import Image +from torch.utils.data import Dataset +from ...utils.registry_class import DATASETS + + +@DATASETS.register_class() +class VideoDataset(Dataset): + def __init__(self, + data_list, + data_dir_list, + max_words=1000, + resolution=(384, 256), + vit_resolution=(224, 224), + max_frames=16, + sample_fps=8, + transforms=None, + vit_transforms=None, + get_first_frame=False, + **kwargs): + + self.max_words = max_words + self.max_frames = max_frames + self.resolution = resolution + self.vit_resolution = vit_resolution + self.sample_fps = sample_fps + self.transforms = transforms + self.vit_transforms = vit_transforms + self.get_first_frame = get_first_frame + + image_list = [] + for item_path, data_dir in zip(data_list, data_dir_list): + lines = open(item_path, 'r').readlines() + lines = [[data_dir, item] for item in lines] + image_list.extend(lines) + self.image_list = image_list + + + def __getitem__(self, index): + data_dir, file_path = self.image_list[index] + video_key = file_path.split('|||')[0] + try: + ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path) + except Exception as e: + logging.info('{} get frames failed... 
with error: {}'.format(video_key, e)) + caption = '' + video_key = '' + ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0]) + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) + return ref_frame, vit_frame, video_data, caption, video_key + + + def _get_video_data(self, data_dir, file_path): + video_key, caption = file_path.split('|||') + file_path = os.path.join(data_dir, video_key) + + for _ in range(5): + try: + capture = cv2.VideoCapture(file_path) + _fps = capture.get(cv2.CAP_PROP_FPS) + _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT) + stride = round(_fps / self.sample_fps) + cover_frame_num = (stride * self.max_frames) + if _total_frame_num < cover_frame_num + 5: + start_frame = 0 + end_frame = _total_frame_num + else: + start_frame = random.randint(0, _total_frame_num-cover_frame_num-5) + end_frame = start_frame + cover_frame_num + + pointer, frame_list = 0, [] + while(True): + ret, frame = capture.read() + pointer +=1 + if (not ret) or (frame is None): break + if pointer < start_frame: continue + if pointer >= end_frame - 1: break + if (pointer - start_frame) % stride == 0: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frame_list.append(frame) + break + except Exception as e: + logging.info('{} read video frame failed with error: {}'.format(video_key, e)) + continue + + video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) + if self.get_first_frame: + ref_idx = 0 + else: + ref_idx = int(len(frame_list)/2) + try: + if len(frame_list)>0: + mid_frame = copy(frame_list[ref_idx]) + vit_frame = self.vit_transforms(mid_frame) + frames = self.transforms(frame_list) + video_data[:len(frame_list), ...] 
= frames
+            else:
+                vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
+        except:
+            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
+        # video_data is zero-initialised above, so this index is valid even if no frames were decoded
+        ref_frame = copy(video_data[ref_idx])
+
+        return ref_frame, vit_frame, video_data, caption
+
+    def __len__(self):
+        return len(self.image_list)
+
+
diff --git a/tools/inferences/__init__.py b/tools/inferences/__init__.py
new file mode 100644
index 0000000..db0383b
--- /dev/null
+++ b/tools/inferences/__init__.py
@@ -0,0 +1,2 @@
+from .inference_unianimate_entrance import *
+from .inference_unianimate_long_entrance import *
diff --git a/tools/inferences/__pycache__/__init__.cpython-310.pyc b/tools/inferences/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000..cac82ef
Binary files /dev/null and b/tools/inferences/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tools/inferences/__pycache__/__init__.cpython-39.pyc b/tools/inferences/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..8d0533c
Binary files /dev/null and b/tools/inferences/__pycache__/__init__.cpython-39.pyc differ
diff --git a/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc b/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc
new file mode 100644
index 0000000..9fc212a
Binary files /dev/null and b/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc differ
diff --git a/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-39.pyc b/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-39.pyc
new file mode 100644
index 0000000..dd13e73
Binary files /dev/null and b/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-39.pyc differ
diff --git a/tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-310.pyc b/tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-310.pyc
new file mode 100644
index 0000000..3dba0b9
Binary files /dev/null and b/tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-310.pyc differ
diff --git a/tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-39.pyc b/tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-39.pyc
new file mode 100644
index 0000000..f1994dd
Binary files /dev/null and b/tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-39.pyc differ
diff --git a/tools/inferences/inference_unianimate_entrance.py b/tools/inferences/inference_unianimate_entrance.py
new file mode 100644
index 0000000..14ceb8f
--- /dev/null
+++ b/tools/inferences/inference_unianimate_entrance.py
@@ -0,0 +1,546 @@
+'''
+/*
+*Copyright (c) 2021, Alibaba Group;
+*Licensed under the Apache License, Version 2.0 (the "License");
+*you may not use this file except in compliance with the License.
+*You may obtain a copy of the License at
+
+* http://www.apache.org/licenses/LICENSE-2.0
+
+*Unless required by applicable law or agreed to in writing, software
+*distributed under the License is distributed on an "AS IS" BASIS,
+*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+*See the License for the specific language governing permissions and
+*limitations under the License.
+*/ +''' + +import os +import re +import os.path as osp +import sys +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4])) +import json +import math +import torch +# import pynvml +import logging +import numpy as np +from PIL import Image +import torch.cuda.amp as amp +from importlib import reload +import torch.distributed as dist +import torch.multiprocessing as mp +import random +from einops import rearrange +import torchvision.transforms as T +from torch.nn.parallel import DistributedDataParallel + +from ...utils import transforms as data +from ..modules.config import cfg +from ...utils.seed import setup_seed +from ...utils.multi_port import find_free_port +from ...utils.assign_cfg import assign_signle_cfg +from ...utils.distributed import generalized_all_gather, all_reduce +from ...utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col +from ...tools.modules.autoencoder import get_first_stage_encoding +from ...utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION +from copy import copy +import cv2 + + +# @INFER_ENGINE.register_function() +def inference_unianimate_entrance(steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg_update, **kwargs): + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + if not 'MASTER_ADDR' in os.environ: + os.environ['MASTER_ADDR']='localhost' + os.environ['MASTER_PORT']= find_free_port() + cfg.pmi_rank = int(os.getenv('RANK', 0)) + cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) + + if cfg.debug: + cfg.gpus_per_machine = 1 + cfg.world_size = 1 + else: + cfg.gpus_per_machine = torch.cuda.device_count() + cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine + + if cfg.world_size == 1: + return worker(0, steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg, cfg_update) + else: + return mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update)) + return cfg + + +def make_masked_images(imgs, masks): + masked_imgs = [] + for i, mask in enumerate(masks): + # concatenation + masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1)) + return torch.stack(masked_imgs, dim=0) + +def load_video_frames(ref_image_tensor, ref_pose_tensor, pose_tensors, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval=1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]): + for _ in range(5): + try: + num_poses = len(pose_tensors) + numpyFrames = [] + numpyPoses = [] + + # Convert tensors to numpy arrays and prepare lists + for i in range(num_poses): + frame = ref_image_tensor.squeeze(0).cpu().numpy() # Convert to numpy array + # if i == 0: + # print(f'ref image is ({frame})') + numpyFrames.append(frame) + + pose = pose_tensors[i].squeeze(0).cpu().numpy() # Convert to numpy array + numpyPoses.append(pose) + + # Convert reference pose tensor to numpy array + pose_ref = ref_pose_tensor.squeeze(0).cpu().numpy() # Convert to numpy array + + # Sample max_frames poses for video generation + stride = frame_interval + total_frame_num = len(numpyFrames) + cover_frame_num = (stride * (max_frames - 1) + 1) + + if total_frame_num < cover_frame_num: + print(f'_total_frame_num ({total_frame_num}) is smaller than cover_frame_num ({cover_frame_num}), the sampled frame interval is changed') + start_frame = 0 + end_frame = total_frame_num + stride 
= max((total_frame_num - 1) // (max_frames - 1), 1) + end_frame = stride * max_frames + else: + start_frame = 0 + end_frame = start_frame + cover_frame_num + + frame_list = [] + dwpose_list = [] + + print(f'end_frame is ({end_frame})') + + for i_index in range(start_frame, end_frame, stride): + if i_index < len(numpyFrames): # Check index within bounds + i_frame = numpyFrames[i_index] + i_dwpose = numpyPoses[i_index] + + # Convert numpy arrays to PIL images + # i_frame = np.clip(i_frame, 0, 1) + i_frame = (i_frame - i_frame.min()) / (i_frame.max() - i_frame.min()) #Trying this in place of clip + i_frame = Image.fromarray((i_frame * 255).astype(np.uint8)) + i_frame = i_frame.convert('RGB') + # i_dwpose = np.clip(i_dwpose, 0, 1) + i_dwpose = (i_dwpose - i_dwpose.min()) / (i_dwpose.max() - i_dwpose.min()) #Trying this in place of clip + i_dwpose = Image.fromarray((i_dwpose * 255).astype(np.uint8)) + i_dwpose = i_dwpose.convert('RGB') + + # if i_index == 0: + # print(f'i_frame is ({np.array(i_frame)})') + + frame_list.append(i_frame) + dwpose_list.append(i_dwpose) + + if frame_list: + # random_ref_frame = np.clip(numpyFrames[0], 0, 1) + random_ref_frame = (numpyFrames[0] - numpyFrames[0].min()) / (numpyFrames[0].max() - numpyFrames[0].min()) #Trying this in place of clip + random_ref_frame = Image.fromarray((random_ref_frame * 255).astype(np.uint8)) + if random_ref_frame.mode != 'RGB': + random_ref_frame = random_ref_frame.convert('RGB') + # random_ref_dwpose = np.clip(pose_ref, 0, 1) + random_ref_dwpose = (pose_ref - pose_ref.min()) / (pose_ref.max() - pose_ref.min()) #Trying this in place of clip + random_ref_dwpose = Image.fromarray((random_ref_dwpose * 255).astype(np.uint8)) + if random_ref_dwpose.mode != 'RGB': + random_ref_dwpose = random_ref_dwpose.convert('RGB') + + # Apply transforms + ref_frame = frame_list[0] + vit_frame = vit_transforms(ref_frame) + random_ref_frame_tmp = train_trans_pose(random_ref_frame) + random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose) + misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0) + video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0) + dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0) + + # Initialize tensors + video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + + # Copy data to tensors + video_data[:len(frame_list), ...] = video_data_tmp + misc_data[:len(frame_list), ...] = misc_data_tmp + dwpose_data[:len(frame_list), ...] = dwpose_data_tmp + random_ref_frame_data[:, ...] = random_ref_frame_tmp + random_ref_dwpose_data[:, ...] 
= random_ref_dwpose_tmp + + return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data + + except Exception as e: + logging.info(f'Error reading video frame: {e}') + continue + + return None, None, None, None, None, None # Return default values if all attempts fail + + +def worker(gpu, steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg, cfg_update): + ''' + Inference worker for each gpu + ''' + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + cfg.gpu = gpu + cfg.seed = int(cfg.seed) + cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu + setup_seed(cfg.seed + cfg.rank) + + if not cfg.debug: + torch.cuda.set_device(gpu) + torch.backends.cudnn.benchmark = True + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + torch.backends.cudnn.benchmark = False + if not dist.is_initialized(): + dist.init_process_group(backend='gloo', world_size=cfg.world_size, rank=cfg.rank) + + # [Log] Save logging and make log dir + # log_dir = generalized_all_gather(cfg.log_dir)[0] + inf_name = osp.basename(cfg.cfg_file).split('.')[0] + # test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1] + + cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name)) + os.makedirs(cfg.log_dir, exist_ok=True) + log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank)) + cfg.log_file = log_file + reload(logging) + logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] %(levelname)s: %(message)s', + handlers=[ + logging.FileHandler(filename=log_file), + logging.StreamHandler(stream=sys.stdout)]) + # logging.info(cfg) + logging.info(f"Running UniAnimate inference on gpu {gpu}") + + # [Diffusion] + diffusion = DIFFUSION.build(cfg.Diffusion) + + # [Data] Data Transform + train_trans = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std) + ]) + + train_trans_pose = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + ] + ) + + # Defines transformations for data to be fed into a Vision Transformer (ViT) model. 
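+    # Unlike train_trans above, these use cfg.vit_resolution and the CLIP mean/std (cfg.vit_mean, cfg.vit_std), since vit_frame is later passed to clip_encoder.encode_image.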
+    vit_transforms = T.Compose([
+        data.Resize(cfg.vit_resolution),
+        T.ToTensor(),
+        T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
+
+    # [Model] embedder
+    clip_encoder = EMBEDDER.build(cfg.embedder)
+    clip_encoder.model.to(gpu)
+    with torch.no_grad():
+        _, _, zero_y = clip_encoder(text="")
+
+
+    # [Model] autoencoder
+    autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
+    autoencoder.eval() # freeze
+    for param in autoencoder.parameters():
+        param.requires_grad = False
+    autoencoder.cuda()
+
+    # [Model] UNet
+    if "config" in cfg.UNet:
+        cfg.UNet["config"] = cfg
+    cfg.UNet["zero_y"] = zero_y
+    model = MODEL.build(cfg.UNet)
+    # Here comes the UniAnimate model
+    # inferences folder
+    current_directory = os.path.dirname(os.path.abspath(__file__))
+    # tools folder
+    parent_directory = os.path.dirname(current_directory)
+    # uniAnimate folder
+    root_directory = os.path.dirname(parent_directory)
+    unifiedModel = os.path.join(root_directory, 'checkpoints/unianimate_16f_32f_non_ema_223000.pth')
+    state_dict = torch.load(unifiedModel, map_location='cpu')
+    if 'state_dict' in state_dict:
+        state_dict = state_dict['state_dict']
+    if 'step' in state_dict:
+        resume_step = state_dict['step']
+    else:
+        resume_step = 0
+    status = model.load_state_dict(state_dict, strict=True)
+    logging.info('Load model from {} with status {}'.format(unifiedModel, status))
+    model = model.to(gpu)
+    model.eval()
+    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
+        print("Avoiding DistributedDataParallel to reduce memory usage")
+        model.to(torch.float16)
+    else:
+        model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
+    torch.cuda.empty_cache()
+
+
+    # Where the input image and pose images come in
+    test_list = cfg.test_list_path
+    num_videos = len(test_list)
+    logging.info(f'There are {num_videos} videos. 
with {cfg.round} times') + # test_list = [item for item in test_list for _ in range(cfg.round)] + test_list = [item for _ in range(cfg.round) for item in test_list] + + # for idx, file_path in enumerate(test_list): + + # You can start inputs here for any user interface + # Inputs will be ref_image_key, pose_seq_key, frame_interval, max_frames, resolution + # cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2] + + manual_seed = int(cfg.seed + cfg.rank) + setup_seed(manual_seed) + # logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...") + + # initialize reference_image, pose_sequence, frame_interval, max_frames, resolution_x, + vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(reference_image, ref_pose, pose_sequence, train_trans, vit_transforms, train_trans_pose, max_frames, frame_interval, resolution) + misc_data = misc_data.unsqueeze(0).to(gpu) + vit_frame = vit_frame.unsqueeze(0).to(gpu) + dwpose_data = dwpose_data.unsqueeze(0).to(gpu) + random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu) + random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu) + + + + ### save for visualization + misc_backups = copy(misc_data) + frames_num = misc_data.shape[1] + misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w') + mv_data_video = [] + + + ### local image (first frame) + image_local = [] + if 'local_image' in cfg.video_compositions: + frames_num = misc_data.shape[1] + bs_vd_local = misc_data.shape[0] + # create a repeated version of the first frame across all frames and assign to image_local + image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1) + image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + if hasattr(cfg, "latent_local_image") and cfg.latent_local_image: + with torch.no_grad(): # Disable gradient calculation + temporal_length = frames_num + # The encoder compresses the input data into a lower-dimensional latent representation, often called a "latent vector" or "encoding." 
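+                # Only the first frame is passed through the VAE here; the resulting latent is repeated across temporal_length frames below to form the local_image condition.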
+ encoder_posterior = autoencoder.encode(video_data[:,0]) + local_image_data = get_first_stage_encoding(encoder_posterior).detach() #use without affecting the gradients of the original model + image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + + + ### encode the video_data + # bs_vd = misc_data.shape[0] + misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w') + # misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0) + + + with torch.no_grad(): + + random_ref_frame = [] + if 'randomref' in cfg.video_compositions: + random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref: + + temporal_length = random_ref_frame_data.shape[1] + encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5)) + random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach() + random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + + + if 'dwpose' in cfg.video_compositions: + bs_vd_local = dwpose_data.shape[0] + dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local) + if 'randomref_pose' in cfg.video_compositions: + dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1) + dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local) + + + y_visual = [] + if 'image' in cfg.video_compositions: + with torch.no_grad(): + vit_frame = vit_frame.squeeze(1) + y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024] + y_visual0 = y_visual.clone() + + # print(torch.get_default_dtype()) + + with amp.autocast(enabled=True): + # pynvml.nvmlInit() + # handle=pynvml.nvmlDeviceGetHandleByIndex(0) + # meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle) + cur_seed = torch.initial_seed() + # logging.info(f"Current seed {cur_seed} ...") + + print(f"Number of frames to denoise: {frames_num}") + noise = torch.randn([1, 4, frames_num, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)]) + noise = noise.to(gpu) + # print(f"noise: {noise.shape}") + + + if hasattr(cfg.Diffusion, "noise_strength"): + b, c, f, _, _= noise.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device) + noise = noise + cfg.Diffusion.noise_strength * offset_noise + # print(f"offset_noise dtype: {offset_noise.dtype}") + # print(f' offset_noise is ({offset_noise})') + + + + # add a noise prior + noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 949), noise=noise) + + # construct model inputs (CFG) + full_model_kwargs=[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local[:], + 'image': None if len(y_visual) == 0 else y_visual0[:], + 'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:], + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + # for visualization + full_model_kwargs_vis =[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local_clone[:], + 'image': None, + 'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3], + }, + { + 'y': None, + "local_image": None, + 'image': 
None, + 'randomref': None, + 'dwpose': None, + }] + + + partial_keys = [ + ['image', 'randomref', "dwpose"], + ] + + if useFirstFrame: + partial_keys = [ + ['image', 'local_image', "dwpose"], + ] + print('Using First Frame Conditioning!') + + + for partial_keys_one in partial_keys: + model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs, + use_fps_condition = cfg.use_fps_condition) + model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs_vis, + use_fps_condition = cfg.use_fps_condition) + noise_one = noise + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + clip_encoder.cpu() # add this line + autoencoder.cpu() # add this line + torch.cuda.empty_cache() # add this line + + # print(f' noise_one is ({noise_one})') + + + video_data = diffusion.ddim_sample_loop( + noise=noise_one, + model=model.eval(), + model_kwargs=model_kwargs_one, + guide_scale=cfg.guide_scale, + ddim_timesteps=steps, + eta=0.0) + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + # if run forward of autoencoder or clip_encoder second times, load them again + clip_encoder.cuda() + autoencoder.cuda() + video_data = 1. / cfg.scale_factor * video_data + video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') + chunk_size = min(cfg.decoder_bs, video_data.shape[0]) + video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0) + decode_data = [] + for vd_data in video_data_list: + gen_frames = autoencoder.decode(vd_data) + decode_data.append(gen_frames) + video_data = torch.cat(decode_data, dim=0) + video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float() + + # Check sth + + # print(f' video_data is of shape ({video_data.shape})') + # print(f' video_data is ({video_data})') + + del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]] + del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]] + + video_data = extract_image_tensors(video_data.cpu(), cfg.mean, cfg.std) + + # synchronize to finish some processes + if not cfg.debug: + torch.cuda.synchronize() + dist.barrier() + + return video_data + +@torch.no_grad() +def extract_image_tensors(video_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + # Unnormalize the video tensor + mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1) # ncfhw + std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1) # ncfhw + video_tensor = video_tensor.mul_(std).add_(mean) # unnormalize back to [0,1] + video_tensor.clamp_(0, 1) + + images = rearrange(video_tensor, 'b c f h w -> b f h w c') + images = images.squeeze(0) + images_t = [] + for img in images: + img_array = np.array(img) # Convert PIL Image to numpy array + img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float() # Convert to tensor and CHW format + img_tensor = img_tensor.permute(0, 2, 3, 1) + images_t.append(img_tensor) + + logging.info('Images data extracted!') + images_t = torch.cat(images_t, dim=0) + return images_t + +def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False): + if use_fps_condition is True: + partial_keys.append('fps') + partial_model_kwargs = [{}, {}] + for partial_key in partial_keys: + partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key] + partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key] + return partial_model_kwargs \ No newline at end of file diff --git 
a/tools/inferences/inference_unianimate_long_entrance.py b/tools/inferences/inference_unianimate_long_entrance.py new file mode 100644 index 0000000..5e841f5 --- /dev/null +++ b/tools/inferences/inference_unianimate_long_entrance.py @@ -0,0 +1,508 @@ +''' +/* +*Copyright (c) 2021, Alibaba Group; +*Licensed under the Apache License, Version 2.0 (the "License"); +*you may not use this file except in compliance with the License. +*You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +*Unless required by applicable law or agreed to in writing, software +*distributed under the License is distributed on an "AS IS" BASIS, +*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +*See the License for the specific language governing permissions and +*limitations under the License. +*/ +''' + +import os +import re +import os.path as osp +import sys +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4])) +import json +import math +import torch +# import pynvml +import logging +import cv2 +import numpy as np +from PIL import Image +from tqdm import tqdm +import torch.cuda.amp as amp +from importlib import reload +import torch.distributed as dist +import torch.multiprocessing as mp +import random +from einops import rearrange +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from torch.nn.parallel import DistributedDataParallel + +from ...utils import transforms as data +from ..modules.config import cfg +from ...utils.seed import setup_seed +from ...utils.multi_port import find_free_port +from ...utils.assign_cfg import assign_signle_cfg +from ...utils.distributed import generalized_all_gather, all_reduce +from ...utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col +from ...tools.modules.autoencoder import get_first_stage_encoding +from ...utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION +from copy import copy +import cv2 + + +@INFER_ENGINE.register_function() +def inference_unianimate_long_entrance(cfg_update, **kwargs): + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + if not 'MASTER_ADDR' in os.environ: + os.environ['MASTER_ADDR']='localhost' + os.environ['MASTER_PORT']= find_free_port() + cfg.pmi_rank = int(os.getenv('RANK', 0)) + cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) + + if cfg.debug: + cfg.gpus_per_machine = 1 + cfg.world_size = 1 + else: + cfg.gpus_per_machine = torch.cuda.device_count() + cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine + + if cfg.world_size == 1: + worker(0, cfg, cfg_update) + else: + mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update)) + return cfg + + +def make_masked_images(imgs, masks): + masked_imgs = [] + for i, mask in enumerate(masks): + # concatenation + masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1)) + return torch.stack(masked_imgs, dim=0) + +def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval = 1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]): + + for _ in range(5): + try: + dwpose_all = {} + frames_all = {} + for ii_index in sorted(os.listdir(pose_file_path)): + if ii_index != "ref_pose.jpg": + dwpose_all[ii_index] = Image.open(pose_file_path+"/"+ii_index) + frames_all[ii_index] = 
Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path),cv2.COLOR_BGR2RGB)) + # frames_all[ii_index] = Image.open(ref_image_path) + + pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg")) + first_eq_ref = False + + # sample max_frames poses for video generation + stride = frame_interval + _total_frame_num = len(frames_all) + if max_frames == "None": + max_frames = (_total_frame_num-1)//frame_interval + 1 + cover_frame_num = (stride * (max_frames-1)+1) + if _total_frame_num < cover_frame_num: + print('_total_frame_num is smaller than cover_frame_num, the sampled frame interval is changed') + start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame + end_frame = _total_frame_num + stride = max((_total_frame_num-1//(max_frames-1)),1) + end_frame = stride*max_frames + else: + start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame + end_frame = start_frame + cover_frame_num + + frame_list = [] + dwpose_list = [] + random_ref_frame = frames_all[list(frames_all.keys())[0]] + if random_ref_frame.mode != 'RGB': + random_ref_frame = random_ref_frame.convert('RGB') + random_ref_dwpose = pose_ref + if random_ref_dwpose.mode != 'RGB': + random_ref_dwpose = random_ref_dwpose.convert('RGB') + for i_index in range(start_frame, end_frame, stride): + if i_index == start_frame and first_eq_ref: + i_key = list(frames_all.keys())[i_index] + i_frame = frames_all[i_key] + + if i_frame.mode != 'RGB': + i_frame = i_frame.convert('RGB') + i_dwpose = frames_pose_ref + if i_dwpose.mode != 'RGB': + i_dwpose = i_dwpose.convert('RGB') + frame_list.append(i_frame) + dwpose_list.append(i_dwpose) + else: + # added + if first_eq_ref: + i_index = i_index - stride + + i_key = list(frames_all.keys())[i_index] + i_frame = frames_all[i_key] + if i_frame.mode != 'RGB': + i_frame = i_frame.convert('RGB') + i_dwpose = dwpose_all[i_key] + if i_dwpose.mode != 'RGB': + i_dwpose = i_dwpose.convert('RGB') + frame_list.append(i_frame) + dwpose_list.append(i_dwpose) + have_frames = len(frame_list)>0 + middle_indix = 0 + if have_frames: + ref_frame = frame_list[middle_indix] + vit_frame = vit_transforms(ref_frame) + random_ref_frame_tmp = train_trans_pose(random_ref_frame) + random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose) + misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0) + video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0) + dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0) + + video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) # [32, 3, 512, 768] + random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + if have_frames: + video_data[:len(frame_list), ...] = video_data_tmp + misc_data[:len(frame_list), ...] = misc_data_tmp + dwpose_data[:len(frame_list), ...] = dwpose_data_tmp + random_ref_frame_data[:,...] = random_ref_frame_tmp + random_ref_dwpose_data[:,...] 
= random_ref_dwpose_tmp + + break + + except Exception as e: + logging.info('{} read video frame failed with error: {}'.format(pose_file_path, e)) + continue + + return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames + + + +def worker(gpu, cfg, cfg_update): + ''' + Inference worker for each gpu + ''' + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + cfg.gpu = gpu + cfg.seed = int(cfg.seed) + cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu + setup_seed(cfg.seed + cfg.rank) + + if not cfg.debug: + torch.cuda.set_device(gpu) + torch.backends.cudnn.benchmark = True + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + torch.backends.cudnn.benchmark = False + dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank) + + # [Log] Save logging and make log dir + log_dir = generalized_all_gather(cfg.log_dir)[0] + inf_name = osp.basename(cfg.cfg_file).split('.')[0] + test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1] + + cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name)) + os.makedirs(cfg.log_dir, exist_ok=True) + log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank)) + cfg.log_file = log_file + reload(logging) + logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] %(levelname)s: %(message)s', + handlers=[ + logging.FileHandler(filename=log_file), + logging.StreamHandler(stream=sys.stdout)]) + logging.info(cfg) + logging.info(f"Running UniAnimate inference on gpu {gpu}") + + # [Diffusion] + diffusion = DIFFUSION.build(cfg.Diffusion) + + # [Data] Data Transform + train_trans = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std) + ]) + + train_trans_pose = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + ] + ) + + vit_transforms = T.Compose([ + data.Resize(cfg.vit_resolution), + T.ToTensor(), + T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + # [Model] embedder + clip_encoder = EMBEDDER.build(cfg.embedder) + clip_encoder.model.to(gpu) + with torch.no_grad(): + _, _, zero_y = clip_encoder(text="") + + + # [Model] auotoencoder + autoencoder = AUTO_ENCODER.build(cfg.auto_encoder) + autoencoder.eval() # freeze + for param in autoencoder.parameters(): + param.requires_grad = False + autoencoder.cuda() + + # [Model] UNet + if "config" in cfg.UNet: + cfg.UNet["config"] = cfg + cfg.UNet["zero_y"] = zero_y + model = MODEL.build(cfg.UNet) + state_dict = torch.load(cfg.test_model, map_location='cpu') + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + if 'step' in state_dict: + resume_step = state_dict['step'] + else: + resume_step = 0 + status = model.load_state_dict(state_dict, strict=True) + logging.info('Load model from {} with status {}'.format(cfg.test_model, status)) + model = model.to(gpu) + model.eval() + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + model.to(torch.float16) + else: + model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model + torch.cuda.empty_cache() + + + + test_list = cfg.test_list_path + num_videos = len(test_list) + logging.info(f'There are {num_videos} videos. 
with {cfg.round} times') + test_list = [item for _ in range(cfg.round) for item in test_list] + + for idx, file_path in enumerate(test_list): + cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2] + + manual_seed = int(cfg.seed + cfg.rank + idx//num_videos) + setup_seed(manual_seed) + logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...") + + + vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames = load_video_frames(ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose, max_frames=cfg.max_frames, frame_interval =cfg.frame_interval, resolution=cfg.resolution) + cfg.max_frames_new = max_frames + misc_data = misc_data.unsqueeze(0).to(gpu) + vit_frame = vit_frame.unsqueeze(0).to(gpu) + dwpose_data = dwpose_data.unsqueeze(0).to(gpu) + random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu) + random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu) + + ### save for visualization + misc_backups = copy(misc_data) + frames_num = misc_data.shape[1] + misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w') + mv_data_video = [] + + + ### local image (first frame) + image_local = [] + if 'local_image' in cfg.video_compositions: + frames_num = misc_data.shape[1] + bs_vd_local = misc_data.shape[0] + image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1) + image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + if hasattr(cfg, "latent_local_image") and cfg.latent_local_image: + with torch.no_grad(): + temporal_length = frames_num + encoder_posterior = autoencoder.encode(video_data[:,0]) + local_image_data = get_first_stage_encoding(encoder_posterior).detach() + image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + + + ### encode the video_data + bs_vd = misc_data.shape[0] + misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w') + misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0) + + + with torch.no_grad(): + + random_ref_frame = [] + if 'randomref' in cfg.video_compositions: + random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref: + + temporal_length = random_ref_frame_data.shape[1] + encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5)) + random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach() + random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + + + if 'dwpose' in cfg.video_compositions: + bs_vd_local = dwpose_data.shape[0] + dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local) + if 'randomref_pose' in cfg.video_compositions: + dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1) + dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local) + + + y_visual = [] + if 'image' in cfg.video_compositions: + with torch.no_grad(): + vit_frame = vit_frame.squeeze(1) + y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024] + y_visual0 = y_visual.clone() + + + with amp.autocast(enabled=True): + # 
pynvml.nvmlInit() + # handle=pynvml.nvmlDeviceGetHandleByIndex(0) + # meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle) + cur_seed = torch.initial_seed() + logging.info(f"Current seed {cur_seed} ..., cfg.max_frames_new: {cfg.max_frames_new} ....") + + noise = torch.randn([1, 4, cfg.max_frames_new, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)]) + noise = noise.to(gpu) + + # add a noise prior + noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise) + + if hasattr(cfg.Diffusion, "noise_strength"): + b, c, f, _, _= noise.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device) + noise = noise + cfg.Diffusion.noise_strength * offset_noise + + # construct model inputs (CFG) + full_model_kwargs=[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local[:], + 'image': None if len(y_visual) == 0 else y_visual0[:], + 'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:], + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + # for visualization + full_model_kwargs_vis =[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local_clone[:], + 'image': None, + 'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3], + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + + partial_keys = [ + ['image', 'randomref', "dwpose"], + ] + if hasattr(cfg, "partial_keys") and cfg.partial_keys: + partial_keys = cfg.partial_keys + + for partial_keys_one in partial_keys: + model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs, + use_fps_condition = cfg.use_fps_condition) + model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs_vis, + use_fps_condition = cfg.use_fps_condition) + noise_one = noise + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + clip_encoder.cpu() # add this line + autoencoder.cpu() # add this line + torch.cuda.empty_cache() # add this line + + video_data = diffusion.ddim_sample_loop( + noise=noise_one, + context_size=cfg.context_size, + context_stride=cfg.context_stride, + context_overlap=cfg.context_overlap, + model=model.eval(), + model_kwargs=model_kwargs_one, + guide_scale=cfg.guide_scale, + ddim_timesteps=cfg.ddim_timesteps, + eta=0.0, + context_batch_size=getattr(cfg, "context_batch_size", 1) + ) + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + # if run forward of autoencoder or clip_encoder second times, load them again + clip_encoder.cuda() + autoencoder.cuda() + + + video_data = 1. 
/ cfg.scale_factor * video_data # [1, 4, h, w] + video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') + chunk_size = min(cfg.decoder_bs, video_data.shape[0]) + video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0) + decode_data = [] + for vd_data in video_data_list: + gen_frames = autoencoder.decode(vd_data) + decode_data.append(gen_frames) + video_data = torch.cat(decode_data, dim=0) + video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float() + + text_size = cfg.resolution[-1] + cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0]) # .replace(' ', '_') + name = f'seed_{cur_seed}' + for ii in partial_keys_one: + name = name + "_" + ii + file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4' + local_path = os.path.join(cfg.log_dir, f'{file_name}') + os.makedirs(os.path.dirname(local_path), exist_ok=True) + captions = "human" + del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]] + del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]] + + save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups, + cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps) + + # try: + # save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size) + # logging.info('Save video to dir %s:' % (local_path)) + # except Exception as e: + # logging.info(f'Step: save text or video error with {e}') + + logging.info('Congratulations! The inference is completed!') + # synchronize to finish some processes + if not cfg.debug: + torch.cuda.synchronize() + dist.barrier() + +def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False): + + if use_fps_condition is True: + partial_keys.append('fps') + + partial_model_kwargs = [{}, {}] + for partial_key in partial_keys: + partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key] + partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key] + + return partial_model_kwargs diff --git a/tools/modules/__init__.py b/tools/modules/__init__.py new file mode 100644 index 0000000..db82a43 --- /dev/null +++ b/tools/modules/__init__.py @@ -0,0 +1,7 @@ +from .clip_embedder import FrozenOpenCLIPEmbedder +from .autoencoder import DiagonalGaussianDistribution, AutoencoderKL +from .clip_embedder import * +from .autoencoder import * +from .unet import * +from .diffusions import * +from .embedding_manager import * diff --git a/tools/modules/__pycache__/__init__.cpython-310.pyc b/tools/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..5502869 Binary files /dev/null and b/tools/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/tools/modules/__pycache__/__init__.cpython-39.pyc b/tools/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..c4c514b Binary files /dev/null and b/tools/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/tools/modules/__pycache__/autoencoder.cpython-310.pyc b/tools/modules/__pycache__/autoencoder.cpython-310.pyc new file mode 100644 index 0000000..87ec4ad Binary files /dev/null and b/tools/modules/__pycache__/autoencoder.cpython-310.pyc differ diff --git a/tools/modules/__pycache__/autoencoder.cpython-39.pyc b/tools/modules/__pycache__/autoencoder.cpython-39.pyc new file mode 100644 index 0000000..e4ffeb8 Binary files /dev/null and 
b/tools/modules/__pycache__/autoencoder.cpython-39.pyc differ diff --git a/tools/modules/__pycache__/clip_embedder.cpython-310.pyc b/tools/modules/__pycache__/clip_embedder.cpython-310.pyc new file mode 100644 index 0000000..b534ca1 Binary files /dev/null and b/tools/modules/__pycache__/clip_embedder.cpython-310.pyc differ diff --git a/tools/modules/__pycache__/clip_embedder.cpython-39.pyc b/tools/modules/__pycache__/clip_embedder.cpython-39.pyc new file mode 100644 index 0000000..73c8c16 Binary files /dev/null and b/tools/modules/__pycache__/clip_embedder.cpython-39.pyc differ diff --git a/tools/modules/__pycache__/config.cpython-310.pyc b/tools/modules/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000..33c5510 Binary files /dev/null and b/tools/modules/__pycache__/config.cpython-310.pyc differ diff --git a/tools/modules/__pycache__/config.cpython-39.pyc b/tools/modules/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000..6d86594 Binary files /dev/null and b/tools/modules/__pycache__/config.cpython-39.pyc differ diff --git a/tools/modules/__pycache__/embedding_manager.cpython-310.pyc b/tools/modules/__pycache__/embedding_manager.cpython-310.pyc new file mode 100644 index 0000000..9704a22 Binary files /dev/null and b/tools/modules/__pycache__/embedding_manager.cpython-310.pyc differ diff --git a/tools/modules/__pycache__/embedding_manager.cpython-39.pyc b/tools/modules/__pycache__/embedding_manager.cpython-39.pyc new file mode 100644 index 0000000..5a41dc5 Binary files /dev/null and b/tools/modules/__pycache__/embedding_manager.cpython-39.pyc differ diff --git a/tools/modules/autoencoder.py b/tools/modules/autoencoder.py new file mode 100644 index 0000000..756d188 --- /dev/null +++ b/tools/modules/autoencoder.py @@ -0,0 +1,698 @@ +import os +import torch +import logging +import collections +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from ...utils.registry_class import AUTO_ENCODER,DISTRIBUTION + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.no_grad() +def get_first_stage_encoding(encoder_posterior, scale_factor=0.18215): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return scale_factor * z + + +@AUTO_ENCODER.register_class() +class AutoencoderKL(nn.Module): + def __init__(self, + ddconfig, + embed_dim, + pretrained=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + use_vid_decoder=False, + **kwargs): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + + if pretrained is not None: 
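+            # NOTE: the supplied path is overridden below with the local checkpoints/v2-1_512-ema-pruned.ckpt, from which init_from_ckpt keeps only the first_stage_model (VAE) weights.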
+ # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/v2-1_512-ema-pruned.ckpt') + self.init_from_ckpt(pretrained, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + sd_new = collections.OrderedDict() + for k in keys: + if k.find('first_stage_model') >= 0: + k_new = k.split('first_stage_model.')[-1] + sd_new[k_new] = sd[k] + self.load_state_dict(sd_new, strict=True) + logging.info(f"Restored from {path}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def encode_firsr_stage(self, x, scale_factor=1.0): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + z = get_first_stage_encoding(posterior, scale_factor) + return z + + def encode_ms(self, x): + hs = self.encoder(x, True) + h = hs[-1] + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + hs[-1] = h + return hs + + def decode(self, z, **kwargs): + z = self.post_quant_conv(z) + dec = self.decoder(z, **kwargs) + return dec + + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
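+        # min-max rescale the projected segmentation map to [-1, 1]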
+ return x + + +@AUTO_ENCODER.register_class() +class AutoencoderVideo(AutoencoderKL): + def __init__(self, + ddconfig, + embed_dim, + pretrained=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + use_vid_decoder=True, + learn_logvar=False, + **kwargs): + use_vid_decoder = True + super().__init__(ddconfig, embed_dim, pretrained, ignore_keys, image_key, colorize_nlabels, monitor, ema_decay, learn_logvar, use_vid_decoder, **kwargs) + + def decode(self, z, **kwargs): + # z = self.post_quant_conv(z) + dec = self.decoder(z, **kwargs) + return dec + + def encode(self, x): + h = self.encoder(x) + # moments = self.quant_conv(h) + moments = h + posterior = DiagonalGaussianDistribution(moments) + return posterior + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + + + +@DISTRIBUTION.register_class() +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +# -------------------------------modules-------------------------------- + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, 
+ kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = 
with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, return_feat=False): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if return_feat: + hs[-1] = h + return hs + else: + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + 
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels, curr_res, curr_res) + # logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z, **kwargs): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + + + diff --git a/tools/modules/clip_embedder.py b/tools/modules/clip_embedder.py new file mode 100644 index 0000000..cc711ce --- /dev/null +++ b/tools/modules/clip_embedder.py @@ -0,0 +1,241 @@ +import os +import torch +import logging +import open_clip +import numpy as np +import torch.nn as nn +import torchvision.transforms as T + +from ...utils.registry_class import EMBEDDER + + +@EMBEDDER.register_class() +class FrozenOpenCLIPEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77, + freeze=True, layer="last"): 
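+ # The `pretrained` argument is overridden below with the repo-local
+ # checkpoints/open_clip_pytorch_model.bin, and the visual branch is deleted
+ # so only the text transformer is kept.
+ # Illustrative usage (a sketch; the output shape assumes the default ViT-H-14 arch):
+ #   embedder = FrozenOpenCLIPEmbedder(pretrained='', layer='penultimate').to('cuda')
+ #   context = embedder(['a person dancing'])   # [1, 77, 1024] token features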
+ super().__init__() + assert layer in self.LAYERS + + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/open_clip_pytorch_model.bin') + + + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +@EMBEDDER.register_class() +class FrozenOpenCLIPVisualEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, pretrained, vit_resolution=(224, 224), arch="ViT-H-14", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/open_clip_pytorch_model.bin') + + + model, _, preprocess = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=pretrained) + + del model.transformer + self.model = model + data_white = np.ones((vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8)*255 + self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0) + + self.device = device + self.max_length = max_length # 77 + if freeze: + self.freeze() + self.layer = layer # 'penultimate' + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, image): + # tokens = open_clip.tokenize(text) + z = self.model.encode_image(image.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, 
attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + + +@EMBEDDER.register_class() +class FrozenOpenCLIPTextVisualEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77, + freeze=True, layer="last", **kwargs): + super().__init__() + assert layer in self.LAYERS + + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/open_clip_pytorch_model.bin') + + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + + def forward(self, image=None, text=None): + + xi = self.model.encode_image(image.to(self.device)) if image is not None else None + tokens = open_clip.tokenize(text) + xt, x = self.encode_with_transformer(tokens.to(self.device)) + return xi, xt, x + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection + return xt, x + + + def encode_image(self, image): + return self.model.visual(image) + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + + return self(text) \ No newline at end of file diff --git a/tools/modules/config.py b/tools/modules/config.py new file mode 100644 index 0000000..9a8cc40 --- /dev/null +++ b/tools/modules/config.py @@ -0,0 +1,206 @@ +import torch +import logging +import os.path as osp +from datetime import datetime +from easydict import EasyDict +import os + +cfg = EasyDict(__name__='Config: VideoLDM Decoder') + +# -------------------------------distributed training-------------------------- +pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) +gpus_per_machine = torch.cuda.device_count() +world_size = pmi_world_size * gpus_per_machine +# 
----------------------------------------------------------------------------- + + +# ---------------------------Dataset Parameter--------------------------------- +cfg.mean = [0.5, 0.5, 0.5] +cfg.std = [0.5, 0.5, 0.5] +cfg.max_words = 1000 +cfg.num_workers = 8 +cfg.prefetch_factor = 2 + +# PlaceHolder +cfg.resolution = [448, 256] +cfg.vit_out_dim = 1024 +cfg.vit_resolution = 336 +cfg.depth_clamp = 10.0 +cfg.misc_size = 384 +cfg.depth_std = 20.0 + +cfg.save_fps = 8 + +cfg.frame_lens = [32, 32, 32, 1] +cfg.sample_fps = [4, ] +cfg.vid_dataset = { + 'type': 'VideoBaseDataset', + 'data_list': [], + 'max_words': cfg.max_words, + 'resolution': cfg.resolution} +cfg.img_dataset = { + 'type': 'ImageBaseDataset', + 'data_list': ['laion_400m',], + 'max_words': cfg.max_words, + 'resolution': cfg.resolution} + +cfg.batch_sizes = { + str(1):256, + str(4):4, + str(8):4, + str(16):4} +# ----------------------------------------------------------------------------- + + +# ---------------------------Mode Parameters----------------------------------- +# Diffusion +cfg.Diffusion = { + 'type': 'DiffusionDDIM', + 'schedule': 'cosine', # cosine + 'schedule_param': { + 'num_timesteps': 1000, + 'cosine_s': 0.008, + 'zero_terminal_snr': True, + }, + 'mean_type': 'v', # [v, eps] + 'loss_type': 'mse', + 'var_type': 'fixed_small', + 'rescale_timesteps': False, + 'noise_strength': 0.1, + 'ddim_timesteps': 50 +} +cfg.ddim_timesteps = 50 # official: 250 +cfg.use_div_loss = False +# classifier-free guidance +cfg.p_zero = 0.9 +cfg.guide_scale = 3.0 + +# clip vision encoder +cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073] +cfg.vit_std = [0.26862954, 0.26130258, 0.27577711] + +# sketch +cfg.sketch_mean = [0.485, 0.456, 0.406] +cfg.sketch_std = [0.229, 0.224, 0.225] +# cfg.misc_size = 256 +cfg.depth_std = 20.0 +cfg.depth_clamp = 10.0 +cfg.hist_sigma = 10.0 + +# Model +cfg.scale_factor = 0.18215 +cfg.use_checkpoint = True +cfg.use_sharded_ddp = False +cfg.use_fsdp = False +cfg.use_fp16 = True +cfg.temporal_attention = True + +cfg.UNet = { + 'type': 'UNetSD', + 'in_dim': 4, + 'dim': 320, + 'y_dim': cfg.vit_out_dim, + 'context_dim': 1024, + 'out_dim': 8, + 'dim_mult': [1, 2, 4, 4], + 'num_heads': 8, + 'head_dim': 64, + 'num_res_blocks': 2, + 'attn_scales': [1 / 1, 1 / 2, 1 / 4], + 'dropout': 0.1, + 'temporal_attention': cfg.temporal_attention, + 'temporal_attn_times': 1, + 'use_checkpoint': cfg.use_checkpoint, + 'use_fps_condition': False, + 'use_sim_mask': False +} + +# auotoencoder from stabel diffusion +cfg.guidances = [] +cfg.auto_encoder = { + 'type': 'AutoencoderKL', + 'ddconfig': { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0, + 'video_kernel_size': [3, 1, 1] + }, + 'embed_dim': 4, + 'pretrained': 'models/v2-1_512-ema-pruned.ckpt' +} +# clip embedder +cfg.embedder = { + 'type': 'FrozenOpenCLIPEmbedder', + 'layer': 'penultimate', + 'pretrained': 'models/open_clip_pytorch_model.bin' +} +# ----------------------------------------------------------------------------- + +# ---------------------------Training Settings--------------------------------- +# training and optimizer +cfg.ema_decay = 0.9999 +cfg.num_steps = 600000 +cfg.lr = 5e-5 +cfg.weight_decay = 0.0 +cfg.betas = (0.9, 0.999) +cfg.eps = 1.0e-8 +cfg.chunk_size = 16 +cfg.decoder_bs = 8 +cfg.alpha = 0.7 +cfg.save_ckp_interval = 1000 + +# scheduler +cfg.warmup_steps = 10 +cfg.decay_mode = 'cosine' + +# acceleration 
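+# EMA weights are only maintained for multi-process training; the
+# world_size < 2 check below disables them for single-GPU runs.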
+cfg.use_ema = True +if world_size<2: + cfg.use_ema = False +cfg.load_from = None +# ----------------------------------------------------------------------------- + + +# ----------------------------Pretrain Settings--------------------------------- +cfg.Pretrain = { + 'type': 'pretrain_specific_strategies', + 'fix_weight': False, + 'grad_scale': 0.2, + 'resume_checkpoint': 'models/jiuniu_0267000.pth', + 'sd_keys_path': 'models/stable_diffusion_image_key_temporal_attention_x1.json', +} +# ----------------------------------------------------------------------------- + + +# -----------------------------Visual------------------------------------------- +# Visual videos +cfg.viz_interval = 1000 +cfg.visual_train = { + 'type': 'VisualTrainTextImageToVideo', +} +cfg.visual_inference = { + 'type': 'VisualGeneratedVideos', +} +cfg.inference_list_path = '' + +# logging +cfg.log_interval = 100 + +### Default log_dir +cfg.log_dir = 'outputs/' +# ----------------------------------------------------------------------------- + + +# ---------------------------Others-------------------------------------------- +# seed +cfg.seed = 8888 +cfg.negative_prompt = 'Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms' +# ----------------------------------------------------------------------------- + diff --git a/tools/modules/diffusions/__init__.py b/tools/modules/diffusions/__init__.py new file mode 100644 index 0000000..c025248 --- /dev/null +++ b/tools/modules/diffusions/__init__.py @@ -0,0 +1 @@ +from .diffusion_ddim import * diff --git a/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc b/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..49d69df Binary files /dev/null and b/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc differ diff --git a/tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc b/tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..1daaa1f Binary files /dev/null and b/tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc differ diff --git a/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc new file mode 100644 index 0000000..1604c91 Binary files /dev/null and b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc differ diff --git a/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-39.pyc b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-39.pyc new file mode 100644 index 0000000..c86b6eb Binary files /dev/null and b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-39.pyc differ diff --git a/tools/modules/diffusions/__pycache__/losses.cpython-310.pyc b/tools/modules/diffusions/__pycache__/losses.cpython-310.pyc new file mode 100644 index 0000000..6267e74 Binary files /dev/null and b/tools/modules/diffusions/__pycache__/losses.cpython-310.pyc differ diff --git a/tools/modules/diffusions/__pycache__/losses.cpython-39.pyc b/tools/modules/diffusions/__pycache__/losses.cpython-39.pyc new file mode 100644 index 0000000..465f747 Binary files /dev/null and b/tools/modules/diffusions/__pycache__/losses.cpython-39.pyc differ diff --git a/tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc b/tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc new file mode 100644 index 0000000..c052b4c Binary files /dev/null and 
b/tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc differ diff --git a/tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc b/tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc new file mode 100644 index 0000000..0e87214 Binary files /dev/null and b/tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc differ diff --git a/tools/modules/diffusions/diffusion_ddim.py b/tools/modules/diffusions/diffusion_ddim.py new file mode 100644 index 0000000..43a17d3 --- /dev/null +++ b/tools/modules/diffusions/diffusion_ddim.py @@ -0,0 +1,1121 @@ +import torch +import math + +from ....utils.registry_class import DIFFUSION +from .schedules import beta_schedule, sigma_schedule +from .losses import kl_divergence, discretized_gaussian_log_likelihood +# from .dpm_solver import NoiseScheduleVP, model_wrapper_guided_diffusion, model_wrapper, DPM_Solver +from typing import Callable, List, Optional +import numpy as np + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + if tensor.device != x.device: + tensor = tensor.to(x.device) + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + +@DIFFUSION.register_class() +class DiffusionDDIMSR(object): + def __init__(self, reverse_diffusion, forward_diffusion, **kwargs): + from .diffusion_gauss import GaussianDiffusion + self.reverse_diffusion = GaussianDiffusion(sigmas=sigma_schedule(reverse_diffusion.schedule, **reverse_diffusion.schedule_param), + prediction_type=reverse_diffusion.mean_type) + self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param), + prediction_type=forward_diffusion.mean_type) + + +@DIFFUSION.register_class() +class DiffusionDPM(object): + def __init__(self, forward_diffusion, **kwargs): + from .diffusion_gauss import GaussianDiffusion + self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param), + prediction_type=forward_diffusion.mean_type) + + +@DIFFUSION.register_class() +class DiffusionDDIM(object): + def __init__(self, + schedule='linear_sd', + schedule_param={}, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon = 1e-12, + rescale_timesteps=False, + noise_strength=0.0, + **kwargs): + + assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v'] + assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small'] + assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier'] + + betas = beta_schedule(schedule, **schedule_param) + assert min(betas) > 0 and max(betas) <= 1 + + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type # eps + self.var_type = var_type # 'fixed_small' + self.loss_type = loss_type # mse + self.epsilon = epsilon # 1e-12 + self.rescale_timesteps = rescale_timesteps # False + self.noise_strength = noise_strength # 0.0 + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = 
torch.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + + def sample_loss(self, x0, noise=None): + if noise is None: + noise = torch.randn_like(x0) + if self.noise_strength > 0: + b, c, f, _, _= x0.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device) + noise = noise + self.noise_strength * offset_noise + return noise + + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + # noise = torch.randn_like(x0) if noise is None else noise + noise = self.sample_loss(x0, noise) + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \ + _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). 
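+ When `guide_scale` is not None, `model_kwargs` must be a two-element list
+ [conditional_kwargs, unconditional_kwargs]; the conditional and unconditional
+ predictions are combined with classifier-free guidance,
+ u_out + guide_scale * (y_out - u_out), before the mean and variance are derived.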
+ """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2 + out = torch.cat([ + u_out[:, :dim] + guide_scale * (y_out[:, :dim] - u_out[:, :dim]), + y_out[:, dim:]], dim=1) # guide_scale=9.0 + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \ + _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'v': + x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
+ """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + from tqdm import tqdm + for step in tqdm(steps): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta) + # from ipdb import set_trace; set_trace() + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. 
+ - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None): + + # noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32] + noise = self.sample_loss(x0, noise) + + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse + out = model(xt, self._scale_timesteps(t), 
**model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small' + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + # target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0], + 'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type] + if loss_mask is not None: + loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same) + loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w + # use masked diffusion + loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + else: + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + if weight is not None: + loss = loss*weight + + # div loss + if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1: + + # derive x0 + x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + + # # derive xt_1, set eta=0 as ddim + # alphas_prev = _i(self.alphas_cumprod, (t - 1).clamp(0), xt) + # direction = torch.sqrt(1 - alphas_prev) * out + # xt_1 = torch.sqrt(alphas_prev) * x0_ + direction + + # ncfhw, std on f + div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4) + # print(div_loss,loss) + loss = loss+div_loss + + # total loss + loss = loss + loss_vlb + elif self.loss_type in ['charbonnier']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + loss = torch.sqrt((out - target)**2 + self.epsilon) + if weight is not None: + loss = loss*weight + loss = loss.flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, x0, model, 
model_kwargs={}, clamp=None, percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + # noise = torch.randn_like(x0) + noise = self.sample_loss(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t + #return t.float() + + + + + + +@DIFFUSION.register_class() +class DiffusionDDIMLong(object): + def __init__(self, + schedule='linear_sd', + schedule_param={}, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon = 1e-12, + rescale_timesteps=False, + noise_strength=0.0, + **kwargs): + + assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v'] + assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small'] + assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier'] + + betas = beta_schedule(schedule, **schedule_param) + assert min(betas) > 0 and max(betas) <= 1 + + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type # v + self.var_type = var_type # 'fixed_small' + self.loss_type = loss_type # mse + self.epsilon = epsilon # 1e-12 + self.rescale_timesteps = rescale_timesteps # False + self.noise_strength = noise_strength + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = 
(1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + + def sample_loss(self, x0, noise=None): + if noise is None: + noise = torch.randn_like(x0) + if self.noise_strength > 0: + b, c, f, _, _= x0.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device) + noise = noise + self.noise_strength * offset_noise + return noise + + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + # noise = torch.randn_like(x0) if noise is None else noise + noise = self.sample_loss(x0, noise) + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \ + _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1): + r"""Distribution of p(x_{t-1} | x_t). 
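+ Long-video variant: the temporal axis is split by context_scheduler into
+ overlapping windows of context_size frames, with the last window remapped to
+ cover the tail of the clip. Each window is denoised separately, per-frame
+ predictions are accumulated and averaged with a visit counter, and
+ classifier-free guidance is applied to the averaged conditional and
+ unconditional predictions.
+ NOTE: the context_scheduler call below hard-codes context_stride=1 and
+ context_overlap=4 regardless of the arguments passed in.
+ Illustrative call of the surrounding sampler (a sketch; cond_kwargs / uncond_kwargs
+ are placeholder conditioning dicts):
+   x = diffusion.ddim_sample_loop(noise, context_size=32, context_stride=1,
+       context_overlap=4, model=model, model_kwargs=[cond_kwargs, uncond_kwargs],
+       guide_scale=2.5, ddim_timesteps=50)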
+ """ + noise = xt + context_queue = list( + context_scheduler( + 0, + 31, + noise.shape[2], + context_size=context_size, + context_stride=1, + context_overlap=4, + ) + ) + context_step = min( + context_stride, int(np.ceil(np.log2(noise.shape[2] / context_size))) + 1 + ) + # replace the final segment to improve temporal consistency + num_frames = noise.shape[2] + context_queue[-1] = [ + e % num_frames + for e in range(num_frames - context_size * context_step, num_frames, context_step) + ] + + import math + # context_batch_size = 1 + num_context_batches = math.ceil(len(context_queue) / context_batch_size) + global_context = [] + for i in range(num_context_batches): + global_context.append( + context_queue[ + i * context_batch_size : (i + 1) * context_batch_size + ] + ) + noise_pred = torch.zeros_like(noise) + noise_pred_uncond = torch.zeros_like(noise) + counter = torch.zeros( + (1, 1, xt.shape[2], 1, 1), + device=xt.device, + dtype=xt.dtype, + ) + + for i_index, context in enumerate(global_context): + + + latent_model_input = torch.cat([xt[:, :, c] for c in context]) + bs_context = len(context) + + model_kwargs_new = [{ + 'y': None, + "local_image": None if not model_kwargs[0].__contains__('local_image') else torch.cat([model_kwargs[0]["local_image"][:, :, c] for c in context]), + 'image': None if not model_kwargs[0].__contains__('image') else model_kwargs[0]["image"].repeat(bs_context, 1, 1), + 'dwpose': None if not model_kwargs[0].__contains__('dwpose') else torch.cat([model_kwargs[0]["dwpose"][:, :, [0]+[ii+1 for ii in c]] for c in context]), + 'randomref': None if not model_kwargs[0].__contains__('randomref') else torch.cat([model_kwargs[0]["randomref"][:, :, c] for c in context]), + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + if guide_scale is None: + out = model(latent_model_input, self._scale_timesteps(t), **model_kwargs) + for j, c in enumerate(context): + noise_pred[:, :, c] = noise_pred[:, :, c] + out + counter[:, :, c] = counter[:, :, c] + 1 + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + # assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[0]) + u_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[1]) + dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2 + for j, c in enumerate(context): + noise_pred[:, :, c] = noise_pred[:, :, c] + y_out[j:j+1] + noise_pred_uncond[:, :, c] = noise_pred_uncond[:, :, c] + u_out[j:j+1] + counter[:, :, c] = counter[:, :, c] + 1 + + noise_pred = noise_pred / counter + noise_pred_uncond = noise_pred_uncond / counter + out = torch.cat([ + noise_pred_uncond[:, :dim] + guide_scale * (noise_pred[:, :dim] - noise_pred_uncond[:, :dim]), + noise_pred[:, dim:]], dim=1) # guide_scale=2.5 + + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt) + 
log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \ + _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'v': + x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale, context_size, context_stride, context_overlap, context_batch_size) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, noise, context_size, context_stride, context_overlap, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_batch_size=1): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! 
Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + from tqdm import tqdm + + for step in tqdm(steps): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta, context_size=context_size, context_stride=context_stride, context_overlap=context_overlap, context_batch_size=context_batch_size) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
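+        The update is a pseudo linear multistep (Adams-Bashforth) step whose
+        order grows with the number of cached eps values from previous steps:
+        2nd order on the first step, up to 4th order once three are cached.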
+ """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None): + + # noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32] + noise = self.sample_loss(x0, noise) + + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small' + 
out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + # target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0], + 'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type] + if loss_mask is not None: + loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same) + loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w + # use masked diffusion + loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + else: + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + if weight is not None: + loss = loss*weight + + # div loss + if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1: + + # derive x0 + x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + + + # ncfhw, std on f + div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4) + # print(div_loss,loss) + loss = loss+div_loss + + # total loss + loss = loss + loss_vlb + elif self.loss_type in ['charbonnier']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + loss = torch.sqrt((out - target)**2 + self.epsilon) + if weight is not None: + loss = loss*weight + loss = loss.flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. 
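+
+        Returns a dict with per-timestep tensors 'vlb', 'mse' (eps prediction
+        error) and 'x0_mse' of shape [B, T], plus per-sample
+        'prior_bits_per_dim' and 'total_bits_per_dim' entries.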
+ """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + # noise = torch.randn_like(x0) + noise = self.sample_loss(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t + #return t.float() + + + +def ordered_halving(val): + bin_str = f"{val:064b}" + bin_flip = bin_str[::-1] + as_int = int(bin_flip, 2) + + return as_int / (1 << 64) + + +def context_scheduler( + step: int = ..., + num_steps: Optional[int] = None, + num_frames: int = ..., + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = False, +): + if num_frames <= context_size: + yield list(range(num_frames)) + return + + context_stride = min( + context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 + ) + + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(step))) + for j in range( + int(ordered_halving(step) * context_step) + pad, + num_frames + pad + (0 if closed_loop else -context_overlap), + (context_size * context_step - context_overlap), + ): + + yield [ + e % num_frames + for e in range(j, j + context_size * context_step, context_step) + ] + diff --git a/tools/modules/diffusions/diffusion_gauss.py b/tools/modules/diffusions/diffusion_gauss.py new file mode 100644 index 0000000..430ab3d --- /dev/null +++ b/tools/modules/diffusions/diffusion_gauss.py @@ -0,0 +1,498 @@ +""" +GaussianDiffusion wraps operators for denoising diffusion models, including the +diffusion and denoising processes, as well as the loss evaluation. +""" +import torch +import torchsde +import random +from tqdm.auto import trange + + +__all__ = ['GaussianDiffusion'] + + +def _i(tensor, t, x): + """ + Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t.to(tensor.device)].view(shape).to(x.device) + + +class BatchedBrownianTree: + """ + A wrapper around torchsde.BrownianTree that enables batches of entropy. 
+ """ + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2 ** 63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree( + t0, w0, t1, entropy=s, **kwargs + ) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """ + A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0 = self.transform(torch.as_tensor(sigma_min)) + t1 = self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0 = self.transform(torch.as_tensor(sigma)) + t1 = self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +def get_scalings(sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1. ** 2) ** 0.5 + return c_out, c_in + + +@torch.no_grad() +def sample_dpmpp_2m_sde( + noise, + model, + sigmas, + eta=1., + s_noise=1., + solver_type='midpoint', + show_progress=True +): + """ + DPM-Solver++ (2M) SDE. 
+ """ + assert solver_type in {'heun', 'midpoint'} + + x = noise * sigmas[0] + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[sigmas < float('inf')].max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) + old_denoised = None + h_last = None + + for i in trange(len(sigmas) - 1, disable=not show_progress): + if sigmas[i] == float('inf'): + # Euler method + denoised = model(noise, sigmas[i]) + x = denoised + sigmas[i + 1] * noise + else: + _, c_in = get_scalings(sigmas[i]) + denoised = model(x * c_in, sigmas[i]) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \ + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \ + (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * \ + (1 / r) * (denoised - old_denoised) + + x = x + noise_sampler( + sigmas[i], + sigmas[i + 1] + ) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x + + +class GaussianDiffusion(object): + + def __init__(self, sigmas, prediction_type='eps'): + assert prediction_type in {'x0', 'eps', 'v'} + self.sigmas = sigmas.float() # noise coefficients + self.alphas = torch.sqrt(1 - sigmas ** 2).float() # signal coefficients + self.num_timesteps = len(sigmas) + self.prediction_type = prediction_type + + def diffuse(self, x0, t, noise=None): + """ + Add Gaussian noise to signal x0 according to: + q(x_t | x_0) = N(x_t | alpha_t x_0, sigma_t^2 I). + """ + noise = torch.randn_like(x0) if noise is None else noise + xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise + return xt + + def denoise( + self, + xt, + t, + s, + model, + model_kwargs={}, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None + ): + """ + Apply one step of denoising from the posterior distribution q(x_s | x_t, x0). + Since x0 is not available, estimate the denoising results using the learned + distribution p(x_s | x_t, \hat{x}_0 == f(x_t)). + """ + s = t - 1 if s is None else s + + # hyperparams + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_s = _i(self.alphas, s.clamp(0), xt) + alphas_s[s < 0] = 1. 
+ sigmas_s = torch.sqrt(1 - alphas_s ** 2) + + # precompute variables + betas = 1 - (alphas / alphas_s) ** 2 + coef1 = betas * alphas_s / sigmas ** 2 + coef2 = (alphas * sigmas_s ** 2) / (alphas_s * sigmas ** 2) + var = betas * (sigmas_s / sigmas) ** 2 + log_var = torch.log(var).clamp_(-20, 20) + + # prediction + if guide_scale is None: + assert isinstance(model_kwargs, dict) + out = model(xt, t=t, **model_kwargs) + else: + # classifier-free guidance (arXiv:2207.12598) + # model_kwargs[0]: conditional kwargs + # model_kwargs[1]: non-conditional kwargs + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, t=t, **model_kwargs[0]) + if guide_scale == 1.: + out = y_out + else: + u_out = model(xt, t=t, **model_kwargs[1]) + out = u_out + guide_scale * (y_out - u_out) + + # rescale the output according to arXiv:2305.08891 + if guide_rescale is not None: + assert guide_rescale >= 0 and guide_rescale <= 1 + ratio = (y_out.flatten(1).std(dim=1) / ( + out.flatten(1).std(dim=1) + 1e-12 + )).view((-1, ) + (1, ) * (y_out.ndim - 1)) + out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0 + + # compute x0 + if self.prediction_type == 'x0': + x0 = out + elif self.prediction_type == 'eps': + x0 = (xt - sigmas * out) / alphas + elif self.prediction_type == 'v': + x0 = alphas * xt - sigmas * out + else: + raise NotImplementedError( + f'prediction_type {self.prediction_type} not implemented' + ) + + # restrict the range of x0 + if percentile is not None: + # NOTE: percentile should only be used when data is within range [-1, 1] + assert percentile > 0 and percentile <= 1 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1) + s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1)) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + + # recompute eps using the restricted x0 + eps = (xt - alphas * x0) / sigmas + + # compute mu (mean of posterior distribution) using the restricted x0 + mu = coef1 * x0 + coef2 * xt + return mu, var, log_var, x0, eps + + @torch.no_grad() + def sample( + self, + noise, + model, + model_kwargs={}, + condition_fn=None, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None, + solver='euler_a', + steps=20, + t_max=None, + t_min=None, + discretization=None, + discard_penultimate_step=None, + return_intermediate=None, + show_progress=False, + seed=-1, + **kwargs + ): + # sanity check + assert isinstance(steps, (int, torch.LongTensor)) + assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) + assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) + assert discretization in (None, 'leading', 'linspace', 'trailing') + assert discard_penultimate_step in (None, True, False) + assert return_intermediate in (None, 'x0', 'xt') + + # function of diffusion solver + solver_fn = { + # 'heun': sample_heun, + 'dpmpp_2m_sde': sample_dpmpp_2m_sde + }[solver] + + # options + schedule = 'karras' if 'karras' in solver else None + discretization = discretization or 'linspace' + seed = seed if seed >= 0 else random.randint(0, 2 ** 31) + if isinstance(steps, torch.LongTensor): + discard_penultimate_step = False + if discard_penultimate_step is None: + discard_penultimate_step = True if solver in ( + 'dpm2', + 'dpm2_ancestral', + 'dpmpp_2m_sde', + 'dpm2_karras', + 'dpm2_ancestral_karras', + 'dpmpp_2m_sde_karras' + ) else False + + # function for denoising xt to get x0 + intermediates = [] + def model_fn(xt, sigma): + # denoising + t = 
self._sigma_to_t(sigma).repeat(len(xt)).round().long() + x0 = self.denoise( + xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp, + percentile + )[-2] + + # collect intermediate outputs + if return_intermediate == 'xt': + intermediates.append(xt) + elif return_intermediate == 'x0': + intermediates.append(x0) + return x0 + + # get timesteps + if isinstance(steps, int): + steps += 1 if discard_penultimate_step else 0 + t_max = self.num_timesteps - 1 if t_max is None else t_max + t_min = 0 if t_min is None else t_min + + # discretize timesteps + if discretization == 'leading': + steps = torch.arange( + t_min, t_max + 1, (t_max - t_min + 1) / steps + ).flip(0) + elif discretization == 'linspace': + steps = torch.linspace(t_max, t_min, steps) + elif discretization == 'trailing': + steps = torch.arange(t_max, t_min - 1, -((t_max - t_min + 1) / steps)) + else: + raise NotImplementedError( + f'{discretization} discretization not implemented' + ) + steps = steps.clamp_(t_min, t_max) + steps = torch.as_tensor(steps, dtype=torch.float32, device=noise.device) + + # get sigmas + sigmas = self._t_to_sigma(steps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if schedule == 'karras': + if sigmas[0] == float('inf'): + sigmas = karras_schedule( + n=len(steps) - 1, + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas[sigmas < float('inf')].max().item(), + rho=7. + ).to(sigmas) + sigmas = torch.cat([ + sigmas.new_tensor([float('inf')]), sigmas, sigmas.new_zeros([1]) + ]) + else: + sigmas = karras_schedule( + n=len(steps), + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas.max().item(), + rho=7. + ).to(sigmas) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if discard_penultimate_step: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + + # sampling + x0 = solver_fn( + noise, + model_fn, + sigmas, + show_progress=show_progress, + **kwargs + ) + return (x0, intermediates) if return_intermediate is not None else x0 + + @torch.no_grad() + def ddim_reverse_sample( + self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + guide_rescale=None, + ddim_timesteps=20, + reverse_steps=600 + ): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = reverse_steps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0, eps = self.denoise( + xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp, + percentile + ) + # derive variables + s = (t + stride).clamp(0, reverse_steps-1) + # hyperparams + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_s = _i(self.alphas, s.clamp(0), xt) + alphas_s[s < 0] = 1. 
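+        # s = t + stride is already clamped to be non-negative above, so this
+        # guard only mirrors denoise(); the inversion step below is deterministic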
+ sigmas_s = torch.sqrt(1 - alphas_s ** 2) + + # reverse sample + mu = alphas_s * x0 + sigmas_s * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop( + self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + guide_rescale=None, + ddim_timesteps=20, + reverse_steps=600 + ): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, reverse_steps, reverse_steps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, guide_rescale, ddim_timesteps, reverse_steps) + return xt + + def _sigma_to_t(self, sigma): + if sigma == float('inf'): + t = torch.full_like(sigma, len(self.sigmas) - 1) + else: + log_sigmas = torch.sqrt( + self.sigmas ** 2 / (1 - self.sigmas ** 2) + ).log().to(sigma) + log_sigma = sigma.log() + dists = log_sigma - log_sigmas[:, None] + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp( + max=log_sigmas.shape[0] - 2 + ) + high_idx = low_idx + 1 + low, high = log_sigmas[low_idx], log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + t = t.view(sigma.shape) + if t.ndim == 0: + t = t.unsqueeze(0) + return t + + def _t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + log_sigmas = torch.sqrt(self.sigmas ** 2 / (1 - self.sigmas ** 2)).log().to(t) + log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx] + log_sigma[torch.isnan(log_sigma) | torch.isinf(log_sigma)] = float('inf') + return log_sigma.exp() + + def prev_step(self, model_out, t, xt, inference_steps=50): + prev_t = t - self.num_timesteps // inference_steps + + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_prev = _i(self.alphas, prev_t.clamp(0), xt) + alphas_prev[prev_t < 0] = 1. + sigmas_prev = torch.sqrt(1 - alphas_prev ** 2) + + x0 = alphas * xt - sigmas * model_out + eps = (xt - alphas * x0) / sigmas + prev_sample = alphas_prev * x0 + sigmas_prev * eps + return prev_sample + + def next_step(self, model_out, t, xt, inference_steps=50): + t, next_t = min(t - self.num_timesteps // inference_steps, 999), t + + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_next = _i(self.alphas, next_t.clamp(0), xt) + alphas_next[next_t < 0] = 1. + sigmas_next = torch.sqrt(1 - alphas_next ** 2) + + x0 = alphas * xt - sigmas * model_out + eps = (xt - alphas * x0) / sigmas + next_sample = alphas_next * x0 + sigmas_next * eps + return next_sample + + def get_noise_pred_single(self, xt, t, model, model_kwargs): + assert isinstance(model_kwargs, dict) + out = model(xt, t=t, **model_kwargs) + return out + + diff --git a/tools/modules/diffusions/losses.py b/tools/modules/diffusions/losses.py new file mode 100644 index 0000000..d3188d8 --- /dev/null +++ b/tools/modules/diffusions/losses.py @@ -0,0 +1,28 @@ +import torch +import math + +__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] + +def kl_divergence(mu1, logvar1, mu2, logvar2): + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mu1 - mu2) ** 2) * torch.exp(-logvar2)) + +def standard_normal_cdf(x): + r"""A fast approximation of the cumulative distribution function of the standard normal. 
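+
+    Uses the tanh-based approximation
+    0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))),
+    the same form used for the approximate GELU.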
+ """ + return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, + log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs diff --git a/tools/modules/diffusions/schedules.py b/tools/modules/diffusions/schedules.py new file mode 100644 index 0000000..4e15870 --- /dev/null +++ b/tools/modules/diffusions/schedules.py @@ -0,0 +1,166 @@ +import math +import torch + + +def beta_schedule(schedule='cosine', + num_timesteps=1000, + zero_terminal_snr=False, + **kwargs): + # compute betas + betas = { + # 'logsnr_cosine_interp': logsnr_cosine_interp_schedule, + 'linear': linear_schedule, + 'linear_sd': linear_sd_schedule, + 'quadratic': quadratic_schedule, + 'cosine': cosine_schedule + }[schedule](num_timesteps, **kwargs) + + if zero_terminal_snr and abs(betas.max() - 1.0) > 0.0001: + betas = rescale_zero_terminal_snr(betas) + + return betas + + +def sigma_schedule(schedule='cosine', + num_timesteps=1000, + zero_terminal_snr=False, + **kwargs): + # compute betas + betas = { + 'logsnr_cosine_interp': logsnr_cosine_interp_schedule, + 'linear': linear_schedule, + 'linear_sd': linear_sd_schedule, + 'quadratic': quadratic_schedule, + 'cosine': cosine_schedule + }[schedule](num_timesteps, **kwargs) + if schedule == 'logsnr_cosine_interp': + sigma = betas + else: + sigma = betas_to_sigmas(betas) + if zero_terminal_snr and abs(sigma.max() - 1.0) > 0.0001: + sigma = rescale_zero_terminal_snr(sigma) + + return sigma + + +def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs): + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + ast_beta = last_beta or scale * 0.02 + return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64) + +def logsnr_cosine_interp_schedule( + num_timesteps, + scale_min=2, + scale_max=4, + logsnr_min=-15, + logsnr_max=15, + **kwargs): + return logsnrs_to_sigmas( + _logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max)) + +def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs): + return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2 + + +def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs): + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2 + + +def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs): + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2 + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + + +# def cosine_schedule(n, cosine_s=0.008, **kwargs): +# ramp = torch.linspace(0, 1, n + 1) +# square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2 +# betas = (1 - 
square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999) +# return betas_to_sigmas(betas) + + +def betas_to_sigmas(betas): + return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0)) + + +def sigmas_to_betas(sigmas): + square_alphas = 1 - sigmas**2 + betas = 1 - torch.cat( + [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]]) + return betas + + + +def sigmas_to_logsnrs(sigmas): + square_sigmas = sigmas**2 + return torch.log(square_sigmas / (1 - square_sigmas)) + + +def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15): + t_min = math.atan(math.exp(-0.5 * logsnr_min)) + t_max = math.atan(math.exp(-0.5 * logsnr_max)) + t = torch.linspace(1, 0, n) + logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) + return logsnrs + + +def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2): + logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max) + logsnrs += 2 * math.log(1 / scale) + return logsnrs + +def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0): + ramp = torch.linspace(1, 0, n) + min_inv_rho = sigma_min**(1 / rho) + max_inv_rho = sigma_max**(1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho + sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2)) + return sigmas + +def _logsnr_cosine_interp(n, + logsnr_min=-15, + logsnr_max=15, + scale_min=2, + scale_max=4): + t = torch.linspace(1, 0, n) + logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min) + logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max) + logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max + return logsnrs + + +def logsnrs_to_sigmas(logsnrs): + return torch.sqrt(torch.sigmoid(-logsnrs)) + + +def rescale_zero_terminal_snr(betas): + """ + Rescale Schedule to Zero Terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + + # Store old values. 8 alphas_bar_sqrt_0 = a + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + diff --git a/tools/modules/embedding_manager.py b/tools/modules/embedding_manager.py new file mode 100644 index 0000000..763f3dd --- /dev/null +++ b/tools/modules/embedding_manager.py @@ -0,0 +1,179 @@ +import torch +from torch import nn +import torch.nn.functional as F +import open_clip + +from functools import partial +from ...utils.registry_class import EMBEDMANAGER + +DEFAULT_PLACEHOLDER_TOKEN = ["*"] + +PROGRESSIVE_SCALE = 2000 + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +def get_clip_token_for_string(string): + tokens = open_clip.tokenize(string) + + return tokens[0, 1] + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0] + + +@EMBEDMANAGER.register_class() +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + temporal_prompt_length=1, + token_dim=1024, + **kwargs + ): + super().__init__() + + self.string_to_token_dict = {} + + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = nn.ParameterDict() # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.model.token_embedding.cpu()) + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_clip_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_clip_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn(init_word_token) + + token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) + self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) + else: + token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + + if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE + else: + max_step_tokens = self.max_vectors_per_token + + 
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) + + placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] + new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def forward_with_text_img( + self, + tokenized_text, + embedded_text, + embedded_img, + ): + device = tokenized_text.device + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + embedded_img + placeholder_embedding + return embedded_text + + def forward_with_text( + self, + tokenized_text, + embedded_text + ): + device = tokenized_text.device + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + placeholder_embedding + return embedded_text + + def save(self, ckpt_path): + torch.save({"string_to_token": self.string_to_token_dict, + "string_to_param": self.string_to_param_dict}, ckpt_path) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + string_to_token = ckpt["string_to_token"] + string_to_param = ckpt["string_to_param"] + for string, token in string_to_token.items(): + self.string_to_token_dict[string] = token + for string, param in string_to_param.items(): + self.string_to_param_dict[string] = param + + def get_embedding_norms_squared(self): + all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0. 
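+        # accumulate a quadratic penalty between each optimized placeholder
+        # embedding and its frozen initial embedding, averaged over placeholders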
+ num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings + + return loss \ No newline at end of file diff --git a/tools/modules/unet/__init__.py b/tools/modules/unet/__init__.py new file mode 100644 index 0000000..3d755e9 --- /dev/null +++ b/tools/modules/unet/__init__.py @@ -0,0 +1,2 @@ +from .unet_unianimate import * + diff --git a/tools/modules/unet/__pycache__/__init__.cpython-310.pyc b/tools/modules/unet/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..e9905b3 Binary files /dev/null and b/tools/modules/unet/__pycache__/__init__.cpython-310.pyc differ diff --git a/tools/modules/unet/__pycache__/__init__.cpython-39.pyc b/tools/modules/unet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..141624b Binary files /dev/null and b/tools/modules/unet/__pycache__/__init__.cpython-39.pyc differ diff --git a/tools/modules/unet/__pycache__/unet_unianimate.cpython-310.pyc b/tools/modules/unet/__pycache__/unet_unianimate.cpython-310.pyc new file mode 100644 index 0000000..41b338a Binary files /dev/null and b/tools/modules/unet/__pycache__/unet_unianimate.cpython-310.pyc differ diff --git a/tools/modules/unet/__pycache__/unet_unianimate.cpython-39.pyc b/tools/modules/unet/__pycache__/unet_unianimate.cpython-39.pyc new file mode 100644 index 0000000..d71fc81 Binary files /dev/null and b/tools/modules/unet/__pycache__/unet_unianimate.cpython-39.pyc differ diff --git a/tools/modules/unet/__pycache__/util.cpython-310.pyc b/tools/modules/unet/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000..674f94e Binary files /dev/null and b/tools/modules/unet/__pycache__/util.cpython-310.pyc differ diff --git a/tools/modules/unet/__pycache__/util.cpython-39.pyc b/tools/modules/unet/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000..d500b1f Binary files /dev/null and b/tools/modules/unet/__pycache__/util.cpython-39.pyc differ diff --git a/tools/modules/unet/mha_flash.py b/tools/modules/unet/mha_flash.py new file mode 100644 index 0000000..5edfe0e --- /dev/null +++ b/tools/modules/unet/mha_flash.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn +import torch.cuda.amp as amp +import torch.nn.functional as F +import math +import os +import time +import numpy as np +import random + +# from flash_attn.flash_attention import FlashAttention +class FlashAttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(FlashAttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + if self.head_dim <= 128 and (self.head_dim % 8) == 0: + new_scale = math.pow(head_dim, -0.5) + self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + # self.apply(self._init_weight) + + + def 
_init_weight(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.15) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=0.15) + if module.bias is not None: + module.bias.data.zero_() + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. + """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device) + q = torch.cat([q, cq], dim=-1) + + qkv = torch.cat([q,k,v], dim=1) + origin_dtype = qkv.dtype + qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous() + out, _ = self.flash_attn(qkv) + out.to(origin_dtype) + + if context is not None: + out = out[:, :-4, :, :] + out = out.permute(0, 2, 3, 1).reshape(b, c, h, w) + + # output + x = self.proj(out) + return x + identity + +if __name__ == '__main__': + batch_size = 8 + flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda() + + x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda() + context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda() + # context = None + flash_net.eval() + + with amp.autocast(enabled=True): + # warm up + for i in range(5): + y = flash_net(x, context) + torch.cuda.synchronize() + s1 = time.time() + for i in range(10): + y = flash_net(x, context) + torch.cuda.synchronize() + s2 = time.time() + + print(f'Average cost time {(s2-s1)*1000/10} ms') \ No newline at end of file diff --git a/tools/modules/unet/unet_unianimate.py b/tools/modules/unet/unet_unianimate.py new file mode 100644 index 0000000..097b52f --- /dev/null +++ b/tools/modules/unet/unet_unianimate.py @@ -0,0 +1,659 @@ +import math +import torch +# import xformers +# import xformers.ops +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +from ....lib.rotary_embedding_torch import RotaryEmbedding +from fairscale.nn.checkpoint import checkpoint_wrapper + +from .util import * +# from .mha_flash import FlashAttentionBlock +from ....utils.registry_class import MODEL + + +USE_TEMPORAL_TRANSFORMER = True + + + +class PreNormattention(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + x + +class PreNormattention_qkv(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, q, k, v, **kwargs): + return self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs) + q + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + b, n, _, h = 
*x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Attention_qkv(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_q = nn.Linear(dim, inner_dim, bias = False) + self.to_k = nn.Linear(dim, inner_dim, bias = False) + self.to_v = nn.Linear(dim, inner_dim, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, q, k, v): + b, n, _, h = *q.shape, self.heads + bk = k.shape[0] + + q = self.to_q(q) + k = self.to_k(k) + v = self.to_v(v) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + k = rearrange(k, 'b n (h d) -> b h n d', b=bk, h = h) + v = rearrange(v, 'b n (h d) -> b h n d', b=bk, h = h) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class PostNormattention(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.norm(self.fn(x, **kwargs) + x) + + + + +class Transformer_v2(nn.Module): + def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1): + super().__init__() + self.layers = nn.ModuleList([]) + self.depth = depth + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)), + FeedForward(dim, mlp_dim, dropout = dropout_ffn), + ])) + def forward(self, x): + for attn, ff in self.layers[:1]: + x = attn(x) + x = ff(x) + x + if self.depth > 1: + for attn, ff in self.layers[1:]: + x = attn(x) + x = ff(x) + x + return x + + +class DropPath(nn.Module): + r"""DropPath but without rescaling and supports optional all-zero and/or all-keep. 
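+
+    Each call zeros the residual branch of a random subset of batch samples
+    (roughly a fraction p), without rescaling the survivors; the optional
+    zero / keep boolean masks passed to forward() force specific samples to
+    always be dropped or always kept.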
+ """ + def __init__(self, p): + super(DropPath, self).__init__() + self.p = p + + def forward(self, *args, zero=None, keep=None): + if not self.training: + return args[0] if len(args) == 1 else args + + # params + x = args[0] + b = x.size(0) + n = (torch.rand(b) < self.p).sum() + + # non-zero and non-keep mask + mask = x.new_ones(b, dtype=torch.bool) + if keep is not None: + mask[keep] = False + if zero is not None: + mask[zero] = False + + # drop-path index + index = torch.where(mask)[0] + index = index[torch.randperm(len(index))[:n]] + if zero is not None: + index = torch.cat([index, torch.where(zero)[0]], dim=0) + + # drop-path multiplier + multiplier = x.new_ones(b) + multiplier[index] = 0.0 + output = tuple(u * self.broadcast(multiplier, u) for u in args) + return output[0] if len(args) == 1 else output + + def broadcast(self, src, dst): + assert src.size(0) == dst.size(0) + shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1) + return src.view(shape) + + + + +@MODEL.register_class() +class UNetSD_UniAnimate(nn.Module): + + def __init__(self, + config=None, + in_dim=4, + dim=512, + y_dim=512, + context_dim=1024, + hist_dim = 156, + concat_dim = 8, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + use_scale_shift_norm=True, + dropout=0.1, + temporal_attn_times=1, + temporal_attention = True, + use_checkpoint=False, + use_image_dataset=False, + use_fps_condition= False, + use_sim_mask = False, + misc_dropout = 0.5, + training=True, + inpainting=True, + p_all_zero=0.1, + p_all_keep=0.1, + zero_y = None, + black_image_feature = None, + adapter_transformer_layers = 1, + num_tokens=4, + **kwargs + ): + embed_dim = dim * 4 + num_heads=num_heads if num_heads else dim//32 + super(UNetSD_UniAnimate, self).__init__() + self.zero_y = zero_y + self.black_image_feature = black_image_feature + self.cfg = config + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.num_tokens = num_tokens + self.hist_dim = hist_dim + self.concat_dim = concat_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + + self.num_heads = num_heads + + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.use_scale_shift_norm = use_scale_shift_norm + self.temporal_attn_times = temporal_attn_times + self.temporal_attention = temporal_attention + self.use_checkpoint = use_checkpoint + self.use_image_dataset = use_image_dataset + self.use_fps_condition = use_fps_condition + self.use_sim_mask = use_sim_mask + self.training=training + self.inpainting = inpainting + self.video_compositions = self.cfg.video_compositions + self.misc_dropout = misc_dropout + self.p_all_zero = p_all_zero + self.p_all_keep = p_all_keep + + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + self.resolution = config.resolution + + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + if 'image' in self.video_compositions: + self.pre_image_condition = nn.Sequential( + nn.Linear(self.context_dim, self.context_dim), + nn.SiLU(), + nn.Linear(self.context_dim, self.context_dim*self.num_tokens)) + + + if 'local_image' in self.video_compositions: + self.local_image_embedding = nn.Sequential( + 
nn.Conv2d(3, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.local_image_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + if 'dwpose' in self.video_compositions: + self.dwpose_embedding = nn.Sequential( + nn.Conv2d(3, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.dwpose_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + if 'randomref_pose' in self.video_compositions: + randomref_dim = 4 + self.randomref_pose2_embedding = nn.Sequential( + nn.Conv2d(3, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=2, padding=1)) + self.randomref_pose2_embedding_after = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + if 'randomref' in self.video_compositions: + randomref_dim = 4 + self.randomref_embedding2 = nn.Sequential( + nn.Conv2d(randomref_dim, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=1, padding=1)) + self.randomref_embedding_after2 = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + ### Condition Dropout + self.misc_dropout = DropPath(misc_dropout) + + + if temporal_attention and not USE_TEMPORAL_TRANSFORMER: + self.rotary_emb = RotaryEmbedding(min(32, head_dim)) + self.time_rel_pos_bias = RelativePositionBias(heads = num_heads, max_distance = 32) # realistically will not be able to generate that many frames of video... yet + + if self.use_fps_condition: + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + # encoder + self.input_blocks = nn.ModuleList() + self.pre_image = nn.Sequential() + init_block = nn.ModuleList([nn.Conv2d(self.in_dim + concat_dim, dim, 3, padding=1)]) + + #### need an initial temporal attention? 
+ if temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + init_block.append(TemporalTransformer(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset)) + else: + init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset)) + + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + + block = nn.ModuleList([ResBlock(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,)]) + + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim, + disable_self_attn=False, use_linear=True + ) + ) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append(TemporalTransformer(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset)) + else: + block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + out_dim, True, dims=2, out_channels=out_dim + ) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.input_blocks.append(downsample) + + # middle + self.middle_block = nn.ModuleList([ + ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,), + SpatialTransformer( + out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim, + disable_self_attn=False, use_linear=True + )]) + + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + self.middle_block.append( + TemporalTransformer( + out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + ) + ) + else: + self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times)) + + self.middle_block.append(ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False)) + + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + + block = nn.ModuleList([ResBlock(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024, + disable_self_attn=False, use_linear=True + ) + ) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, 
multiply_zero=use_image_dataset + ) + ) + else: + block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + upsample = Upsample(out_dim, True, dims=2.0, out_channels=out_dim) + scale *= 2.0 + block.append(upsample) + self.output_blocks.append(block) + + # head + self.out = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.out[-1].weight) + + def forward(self, + x, + t, + y = None, + depth = None, + image = None, + motion = None, + local_image = None, + single_sketch = None, + masked = None, + canny = None, + sketch = None, + dwpose = None, + randomref = None, + histogram = None, + fps = None, + video_mask = None, + focus_present_mask = None, + prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + mask_last_frame_num = 0 # mask last frame num + ): + + + assert self.inpainting or masked is None, 'inpainting is not supported' + + batch, c, f, h, w= x.shape + frames = f + device = x.device + self.batch = batch + + #### image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored + if mask_last_frame_num > 0: + focus_present_mask = None + video_mask[-mask_last_frame_num:] = False + else: + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device)) + + if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER: + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device) + else: + time_rel_pos_bias = None + + + # all-zero and all-keep masks + zero = torch.zeros(batch, dtype=torch.bool).to(x.device) + keep = torch.zeros(batch, dtype=torch.bool).to(x.device) + if self.training: + nzero = (torch.rand(batch) < self.p_all_zero).sum() + nkeep = (torch.rand(batch) < self.p_all_keep).sum() + index = torch.randperm(batch) + zero[index[0:nzero]] = True + keep[index[nzero:nzero + nkeep]] = True + assert not (zero & keep).any() + misc_dropout = partial(self.misc_dropout, zero = zero, keep = keep) + + + concat = x.new_zeros(batch, self.concat_dim, f, h, w) + + + # local_image_embedding (first frame) + if local_image is not None: + local_image = rearrange(local_image, 'b c f h w -> (b f) c h w') + local_image = self.local_image_embedding(local_image) + + h = local_image.shape[2] + local_image = self.local_image_embedding_after(rearrange(local_image, '(b f) c h w -> (b h w) f c', b = batch)) + local_image = rearrange(local_image, '(b h w) f c -> b c f h w', b = batch, h = h) + + concat = concat + misc_dropout(local_image) + + if dwpose is not None: + if 'randomref_pose' in self.video_compositions: + dwpose_random_ref = dwpose[:,:,:1].clone() + dwpose = dwpose[:,:,1:] + dwpose = rearrange(dwpose, 'b c f h w -> (b f) c h w') + dwpose = self.dwpose_embedding(dwpose) + + h = dwpose.shape[2] + dwpose = self.dwpose_embedding_after(rearrange(dwpose, '(b f) c h w -> (b h w) f c', b = batch)) + dwpose = rearrange(dwpose, '(b h w) f c -> b c f h w', b = batch, h = h) + concat = concat + misc_dropout(dwpose) + + randomref_b = x.new_zeros(batch, self.concat_dim+4, 1, h, w) + if randomref is not None: + randomref = rearrange(randomref[:,:,:1,], 'b c f h w -> 
(b f) c h w') + randomref = self.randomref_embedding2(randomref) + + h = randomref.shape[2] + randomref = self.randomref_embedding_after2(rearrange(randomref, '(b f) c h w -> (b h w) f c', b = batch)) + if 'randomref_pose' in self.video_compositions: + dwpose_random_ref = rearrange(dwpose_random_ref, 'b c f h w -> (b f) c h w') + dwpose_random_ref = self.randomref_pose2_embedding(dwpose_random_ref) + dwpose_random_ref = self.randomref_pose2_embedding_after(rearrange(dwpose_random_ref, '(b f) c h w -> (b h w) f c', b = batch)) + randomref = randomref + dwpose_random_ref + + randomref_a = rearrange(randomref, '(b h w) f c -> b c f h w', b = batch, h = h) + randomref_b = randomref_b + randomref_a + + + x = torch.cat([randomref_b, torch.cat([x, concat], dim=1)], dim=2) + x = rearrange(x, 'b c f h w -> (b f) c h w') + x = self.pre_image(x) + x = rearrange(x, '(b f) c h w -> b c f h w', b = batch) + + # embeddings + if self.use_fps_condition and fps is not None: + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + self.fps_embedding(sinusoidal_embedding(fps, self.dim)) + else: + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + + context = x.new_zeros(batch, 0, self.context_dim) + + + if image is not None: + y_context = self.zero_y.repeat(batch, 1, 1) + context = torch.cat([context, y_context], dim=1) + + image_context = misc_dropout(self.pre_image_condition(image).view(-1, self.num_tokens, self.context_dim)) # torch.cat([y[:,:-1,:], self.pre_image_condition(y[:,-1:,:]) ], dim=1) + context = torch.cat([context, image_context], dim=1) + else: + y_context = self.zero_y.repeat(batch, 1, 1) + context = torch.cat([context, y_context], dim=1) + image_context = torch.zeros_like(self.zero_y.repeat(batch, 1, 1))[:,:self.num_tokens] + context = torch.cat([context, image_context], dim=1) + + # repeat f times for spatial e and context + e = e.repeat_interleave(repeats=f+1, dim=0) + context = context.repeat_interleave(repeats=f+1, dim=0) + + + + ## always in shape (b f) c h w, except for temporal layer + x = rearrange(x, 'b c f h w -> (b f) c h w') + # encoder + xs = [] + for block in self.input_blocks: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask) + xs.append(x) + + # middle + for block in self.middle_block: + x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask) + + # decoder + for block in self.output_blocks: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None) + + # head + x = self.out(x) + + # reshape back to (b c f h w) + x = rearrange(x, '(b f) c h w -> b c f h w', b = batch) + return x[:,:,1:] + + def _forward_single(self, module, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference=None): + if isinstance(module, ResidualBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, reference) + elif isinstance(module, ResBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, self.batch) + elif isinstance(module, SpatialTransformer): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, TemporalTransformer): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = 
module(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, CrossAttention): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, MemoryEfficientCrossAttention): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, BasicTransformerBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, FeedForward): + x = module(x, context) + elif isinstance(module, Upsample): + x = module(x) + elif isinstance(module, Downsample): + x = module(x) + elif isinstance(module, Resample): + x = module(x, reference) + elif isinstance(module, TemporalAttentionBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = module(x, time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalAttentionMultiBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = module(x, time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, InitTemporalConvBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalConvBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference) + else: + x = module(x) + return x + + + diff --git a/tools/modules/unet/util.py b/tools/modules/unet/util.py new file mode 100644 index 0000000..1d6c6e6 --- /dev/null +++ b/tools/modules/unet/util.py @@ -0,0 +1,1741 @@ +import math +import torch +import xformers +# # import open_clip +# import xformers.ops +import torch.nn as nn +from torch import einsum +from einops import rearrange +from functools import partial +import torch.nn.functional as F +import torch.nn.init as init +from ....lib.rotary_embedding_torch import RotaryEmbedding +from fairscale.nn.checkpoint import checkpoint_wrapper + +# from .mha_flash import FlashAttentionBlock +# from utils.registry_class import MODEL + + +### load all keys started with prefix and replace them with new_prefix +def load_Block(state, prefix, new_prefix=None): + if new_prefix is None: + new_prefix = prefix + + state_dict = {} + state = {key:value for key,value in state.items() if prefix in key} + for key,value in state.items(): + new_key = key.replace(prefix, new_prefix) + state_dict[new_key]=value + return state_dict + + +def load_2d_pretrained_state_dict(state,cfg): + + new_state_dict = {} + + dim = cfg.unet_dim + num_res_blocks = cfg.unet_res_blocks + temporal_attention = cfg.temporal_attention + temporal_conv = cfg.temporal_conv + dim_mult = cfg.unet_dim_mult + attn_scales = cfg.unet_attn_scales + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + 
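+ # Hedged example of how the channel lists above expand, assuming dim=512 and
+ # dim_mult=[1, 2, 3, 4] (the defaults used by UNetSD_UniAnimate in this repo; the
+ # actual values come from cfg.unet_dim / cfg.unet_dim_mult):
+ #   enc_dims = [512, 512, 1024, 1536, 2048]
+ #   dec_dims = [2048, 2048, 1536, 1024, 512]
+ # shortcut_dims and scale only mirror the bookkeeping of the UNet constructor, so the
+ # encoder/decoder block indices remapped below stay aligned with the 2D checkpoint.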
#embeddings + state_dict = load_Block(state,prefix=f'time_embedding') + new_state_dict.update(state_dict) + state_dict = load_Block(state,prefix=f'y_embedding') + new_state_dict.update(state_dict) + state_dict = load_Block(state,prefix=f'context_embedding') + new_state_dict.update(state_dict) + + encoder_idx = 0 + ### init block + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0') + new_state_dict.update(state_dict) + encoder_idx += 1 + + shortcut_dims.append(dim) + for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + idx = 0 + idx_ = 0 + # residual (+attention) blocks + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ = 2 + + if scale in attn_scales: + # block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim)) + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}') + new_state_dict.update(state_dict) + # if temporal_attention: + # block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + in_dim = out_dim + encoder_idx += 1 + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + # downsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 0.5, dropout) + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0') + new_state_dict.update(state_dict) + + shortcut_dims.append(out_dim) + scale /= 2.0 + encoder_idx += 1 + + # middle + # self.middle = nn.ModuleList([ + # ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none'), + # TemporalConvBlock(out_dim), + # AttentionBlock(out_dim, context_dim, num_heads, head_dim)]) + # if temporal_attention: + # self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + # elif temporal_conv: + # self.middle.append(TemporalConvBlock(out_dim,dropout=dropout)) + # self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none')) + # self.middle.append(TemporalConvBlock(out_dim)) + + + # middle + middle_idx = 0 + # self.middle = nn.ModuleList([ + # ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout), + # AttentionBlock(out_dim, context_dim, num_heads, head_dim)]) + state_dict = load_Block(state,prefix=f'middle.{middle_idx}') + new_state_dict.update(state_dict) + middle_idx += 2 + + state_dict = load_Block(state,prefix=f'middle.1',new_prefix=f'middle.{middle_idx}') + new_state_dict.update(state_dict) + middle_idx += 1 + + for _ in range(cfg.temporal_attn_times): + # self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + middle_idx += 1 + + # self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout)) + state_dict = load_Block(state,prefix=f'middle.2',new_prefix=f'middle.{middle_idx}') + new_state_dict.update(state_dict) + middle_idx += 2 + + + decoder_idx = 0 + for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + idx = 0 + idx_ = 0 + # residual (+attention) blocks + # block = nn.ModuleList([ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout)]) + state_dict = 
load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ += 2 + if scale in attn_scales: + # block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim)) + state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ += 1 + for _ in range(cfg.temporal_attn_times): + # block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + idx_ +=1 + + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + + # upsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 2.0, dropout) + state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ += 2 + + scale *= 2.0 + # block.append(upsample) + # self.decoder.append(block) + decoder_idx += 1 + + # head + # self.head = nn.Sequential( + # nn.GroupNorm(32, out_dim), + # nn.SiLU(), + # nn.Conv3d(out_dim, self.out_dim, (1,3,3), padding=(0,1,1))) + state_dict = load_Block(state,prefix=f'head') + new_state_dict.update(state_dict) + + return new_state_dict + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, + torch.pow(10000, -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device = device, dtype = torch.bool) + elif prob == 0: + return torch.zeros(shape, device = device, dtype = torch.bool) + else: + mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < prob + ### aviod mask all, which will cause find_unused_parameters error + if mask.all(): + mask[0]=False + return mask + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, max_bs=4096, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.max_bs = max_bs + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + if q.shape[0] > self.max_bs: + q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0) + k_list 
= torch.chunk(k, k.shape[0] // self.max_bs, dim=0) + v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0) + out_list = [] + for q_1, k_1, v_1 in zip(q_list, k_list, v_list): + out = xformers.ops.memory_efficient_attention( + q_1, k_1, v_1, attn_bias=None, op=self.attention_op) + out_list.append(out) + out = torch.cat(out_list, dim=0) + else: + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads = 8, + num_buckets = 32, + max_distance = 128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class SpatialTransformerWithAdapter(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, + adapter_list=[], adapter_position_list=['', 'parallel', ''], + adapter_hidden_dim=None): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, + adapter_list=adapter_list, adapter_position_list=adapter_position_list, + adapter_hidden_dim=adapter_hidden_dim) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + +import os +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class Adapter(nn.Module): + def __init__(self, in_dim, hidden_dim, condition_dim=None): + super().__init__() + self.down_linear = nn.Linear(in_dim, hidden_dim) + self.up_linear = nn.Linear(hidden_dim, in_dim) + self.condition_dim = condition_dim + if condition_dim is not None: + self.condition_linear = nn.Linear(condition_dim, in_dim) + + init.zeros_(self.up_linear.weight) + init.zeros_(self.up_linear.bias) + + def forward(self, x, condition=None, condition_lam=1): + x_in = x + if self.condition_dim is not None and condition is not None: + x = x + condition_lam * self.condition_linear(condition) + x = self.down_linear(x) + x = F.gelu(x) + x = self.up_linear(x) + x += x_in + return x + + +class MemoryEfficientCrossAttention_attemask(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=xformers.ops.LowerTriangularMask(), op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + + +class BasicTransformerBlock_attemask(nn.Module): + # ATTENTION_MODES = { + # "softmax": CrossAttention, # vanilla attention + # "softmax-xformers": MemoryEfficientCrossAttention + # } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + # assert attn_mode in self.ATTENTION_MODES + # attn_cls = CrossAttention + attn_cls = MemoryEfficientCrossAttention_attemask + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, 
dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward_(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerBlockWithAdapter(nn.Module): + # ATTENTION_MODES = { + # "softmax": CrossAttention, # vanilla attention + # "softmax-xformers": MemoryEfficientCrossAttention + # } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False, + adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'], adapter_hidden_dim=None, adapter_condition_dim=None + ): + super().__init__() + # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + # assert attn_mode in self.ATTENTION_MODES + # attn_cls = CrossAttention + attn_cls = MemoryEfficientCrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + # adapter + self.adapter_list = adapter_list + self.adapter_position_list = adapter_position_list + hidden_dim = dim//2 if not adapter_hidden_dim else adapter_hidden_dim + if "self_attention" in adapter_list: + self.attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim) + if "cross_attention" in adapter_list: + self.cross_attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim) + if "feedforward" in adapter_list: + self.ff_adapter = Adapter(dim, hidden_dim, adapter_condition_dim) + + + def forward_(self, x, context=None, adapter_condition=None, adapter_condition_lam=1): + return checkpoint(self._forward, (x, context, adapter_condition, adapter_condition_lam), self.parameters(), self.checkpoint) + + def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1): + if "self_attention" in self.adapter_list: + if self.adapter_position_list[0] == 'parallel': + # parallel + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + self.attn_adapter(x, adapter_condition, adapter_condition_lam) + elif self.adapter_position_list[0] == 'serial': + # serial + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn_adapter(x, adapter_condition, adapter_condition_lam) + else: + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + + if "cross_attention" in self.adapter_list: + if self.adapter_position_list[1] == 'parallel': + # parallel + x = self.attn2(self.norm2(x), context=context) + self.cross_attn_adapter(x, adapter_condition, adapter_condition_lam) + elif self.adapter_position_list[1] == 'serial': + x = self.attn2(self.norm2(x), context=context) + x + x = self.cross_attn_adapter(x, 
adapter_condition, adapter_condition_lam) + else: + x = self.attn2(self.norm2(x), context=context) + x + + if "feedforward" in self.adapter_list: + if self.adapter_position_list[2] == 'parallel': + x = self.ff(self.norm3(x)) + self.ff_adapter(x, adapter_condition, adapter_condition_lam) + elif self.adapter_position_list[2] == 'serial': + x = self.ff(self.norm3(x)) + x + x = self.ff_adapter(x, adapter_condition, adapter_condition_lam) + else: + x = self.ff(self.norm3(x)) + x + + return x + +class BasicTransformerBlock(nn.Module): + # ATTENTION_MODES = { + # "softmax": CrossAttention, # vanilla attention + # "softmax-xformers": MemoryEfficientCrossAttention + # } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + # assert attn_mode in self.ATTENTION_MODES + # attn_cls = CrossAttention + attn_cls = MemoryEfficientCrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward_(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class UpsampleSR600(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + # TODO: to match input_blocks, remove elements of two sides + x = x[..., 1:-1, :] + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + use_temporal_conv=True, + use_image_dataset=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + nn.Conv2d(channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock_v2(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset) + # self.temopral_conv_2 = TemporalConvBlock(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset) + + def forward(self, x, emb, batch_size): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return self._forward(x, emb, batch_size) + + def _forward(self, x, emb, batch_size): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv: + h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size) + h = self.temopral_conv(h) + # h = self.temopral_conv_2(h) + h = rearrange(h, 'b c f h w -> (b f) c h w') + return h + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, mode): + assert mode in ['none', 'upsample', 'downsample'] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.mode = mode + + def forward(self, x, reference=None): + if self.mode == 'upsample': + assert reference is not None + x = F.interpolate(x, size=reference.shape[-2:], mode='nearest') + elif self.mode == 'downsample': + x = F.adaptive_avg_pool2d(x, output_size=tuple(u // 2 for u in x.shape[-2:])) + return x + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, embed_dim, out_dim, use_scale_shift_norm=True, + mode='none', dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.mode = mode + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, mode) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e, reference=None): + identity = self.resample(x, reference) + x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference)) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. 
+ """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class TemporalAttentionBlock(nn.Module): + def __init__( + self, + dim, + heads = 4, + dim_head = 32, + rotary_emb = None, + use_image_dataset = False, + use_sim_mask = False + ): + super().__init__() + # consider num_heads first, as pos_bias needs fixed num_heads + # heads = dim // dim_head if dim_head else heads + dim_head = dim // heads + assert heads * dim_head == dim + self.use_image_dataset = use_image_dataset + self.use_sim_mask = use_sim_mask + + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.norm = nn.GroupNorm(32, dim) + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3)#, bias = False) + self.to_out = nn.Linear(hidden_dim, dim)#, bias = False) + + # nn.init.zeros_(self.to_out.weight) + # nn.init.zeros_(self.to_out.bias) + + def forward( + self, + x, + pos_bias = None, + focus_present_mask = None, + video_mask = None + ): + + identity = x + n, height, device = x.shape[2], x.shape[-2], x.device + + x = self.norm(x) + x = rearrange(x, 'b c f h w -> b (h w) f c') + + qkv = self.to_qkv(x).chunk(3, dim = -1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output + values = qkv[-1] + out = self.to_out(values) + out = rearrange(out, 'b (h w) f c -> b c f h w', h = height) + + return out + identity + + # split out heads + # q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h = self.heads) + # shape [b (hw) h n c/h], n=f + q= rearrange(qkv[0], '... n (h d) -> ... h n d', h = self.heads) + k= rearrange(qkv[1], '... n (h d) -> ... h n d', h = self.heads) + v= rearrange(qkv[2], '... n (h d) -> ... h n d', h = self.heads) + + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + # shape [b (hw) h n n], n=f + sim = torch.einsum('... h i d, ... h j d -> ... 
h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + # print(sim.shape,pos_bias.shape) + sim = sim + pos_bias + + if (focus_present_mask is None and video_mask is not None): + #video_mask: [B, n] + mask = video_mask[:, None, :] * video_mask[:, :, None] # [b,n,n] + mask = mask.unsqueeze(1).unsqueeze(1) #[b,1,1,n,n] + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + elif exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones((n, n), device = device, dtype = torch.bool) + attend_self_mask = torch.eye(n, device = device, dtype = torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + if self.use_sim_mask: + sim_mask = torch.tril(torch.ones((n, n), device = device, dtype = torch.bool), diagonal=0) + sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max) + + # numerical stability + sim = sim - sim.amax(dim = -1, keepdim = True).detach() + attn = sim.softmax(dim = -1) + + # aggregate values + + out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... n (h d)') + out = self.to_out(out) + + out = rearrange(out, 'b (h w) f c -> b c f h w', h = height) + + if self.use_image_dataset: + out = identity + 0*out + else: + out = identity + out + return out + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, only_self_att=True, multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange(x, '(b f) c h w -> b (h w) f c', 
f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + # x = rearrange(x, 'bhw f c -> bhw c f').contiguous() + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + + +class TemporalTransformerWithAdapter(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, only_self_att=True, multiply_zero=False, + adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'], + adapter_hidden_dim=None, adapter_condition_dim=None): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + checkpoint=use_checkpoint, adapter_list=adapter_list, adapter_position_list=adapter_position_list, + adapter_hidden_dim=adapter_hidden_dim, adapter_condition_dim=adapter_condition_dim) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not 
self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if adapter_condition is not None: + b_cond, f_cond, c_cond = adapter_condition.shape + adapter_condition = adapter_condition.unsqueeze(1).unsqueeze(1).repeat(1, h, w, 1, 1) + adapter_condition = adapter_condition.reshape(b_cond*h*w, f_cond, c_cond) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x, adapter_condition=adapter_condition, adapter_condition_lam=adapter_condition_lam) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + # x = rearrange(x, 'bhw f c -> bhw c f').contiguous() + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class PreNormattention(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + x + +class TransformerV2(nn.Module): + def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1): + super().__init__() + self.layers = nn.ModuleList([]) + self.depth = depth + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)), + FeedForward(dim, mlp_dim, dropout = dropout_ffn), + ])) + def forward(self, x): + # if self.depth + for attn, ff in self.layers[:1]: + x = attn(x) + x = ff(x) + x + if self.depth 
> 1: + for attn, ff in self.layers[1:]: + x = attn(x) + x = ff(x) + x + return x + +class TemporalTransformer_attemask(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, only_self_att=True, multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock_attemask(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + # x = rearrange(x, 'bhw f c -> bhw c f').contiguous() + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + +class TemporalAttentionMultiBlock(nn.Module): + def __init__( + self, + 
dim, + heads=4, + dim_head=32, + rotary_emb=None, + use_image_dataset=False, + use_sim_mask=False, + temporal_attn_times=1, + ): + super().__init__() + self.att_layers = nn.ModuleList( + [TemporalAttentionBlock(dim, heads, dim_head, rotary_emb, use_image_dataset, use_sim_mask) + for _ in range(temporal_attn_times)] + ) + + def forward( + self, + x, + pos_bias = None, + focus_present_mask = None, + video_mask = None + ): + for layer in self.att_layers: + x = layer(x, pos_bias, focus_present_mask, video_mask) + return x + + +class InitTemporalConvBlock(nn.Module): + + def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False): + super(InitTemporalConvBlock, self).__init__() + if out_dim is None: + out_dim = in_dim  # int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + + # zero out the last layer params, so the conv block is identity + # nn.init.zeros_(self.conv1[-1].weight) + # nn.init.zeros_(self.conv1[-1].bias) + nn.init.zeros_(self.conv[-1].weight) + nn.init.zeros_(self.conv[-1].bias) + + def forward(self, x): + identity = x + x = self.conv(x) + if self.use_image_dataset: + x = identity + 0*x + else: + x = identity + x + return x + +class TemporalConvBlock(nn.Module): + + def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False): + super(TemporalConvBlock, self).__init__() + if out_dim is None: + out_dim = in_dim  # int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + + # zero out the last layer params, so the conv block is identity + # nn.init.zeros_(self.conv1[-1].weight) + # nn.init.zeros_(self.conv1[-1].bias) + nn.init.zeros_(self.conv2[-1].weight) + nn.init.zeros_(self.conv2[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + if self.use_image_dataset: + x = identity + 0*x + else: + x = identity + x + return x + +class TemporalConvBlock_v2(nn.Module): + def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False): + super(TemporalConvBlock_v2, self).__init__() + if out_dim is None: + out_dim = in_dim  # int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + + # zero out the last layer params, so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x 
= self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + if self.use_image_dataset: + x = identity + 0.0 * x + else: + x = identity + x + return x + + +class DropPath(nn.Module): + r"""Sample-level DropPath (stochastic depth) without rescaling; the optional zero/keep boolean masks force selected batch elements to always drop or always keep their branch. + """ + def __init__(self, p): + super(DropPath, self).__init__() + self.p = p + + def forward(self, *args, zero=None, keep=None): + if not self.training: + return args[0] if len(args) == 1 else args + + # batch size and the number of samples to drop this step + x = args[0] + b = x.size(0) + n = (torch.rand(b) < self.p).sum() + + # candidate mask: samples not forced to keep or to zero + mask = x.new_ones(b, dtype=torch.bool) + if keep is not None: + mask[keep] = False + if zero is not None: + mask[zero] = False + + # randomly pick up to n candidates to drop, plus any forced-zero samples + index = torch.where(mask)[0] + index = index[torch.randperm(len(index))[:n]] + if zero is not None: + index = torch.cat([index, torch.where(zero)[0]], dim=0) + + # per-sample multiplier: 0 for dropped samples, 1 otherwise (no rescaling) + multiplier = x.new_ones(b) + multiplier[index] = 0.0 + output = tuple(u * self.broadcast(multiplier, u) for u in args) + return output[0] if len(args) == 1 else output + + def broadcast(self, src, dst): + assert src.size(0) == dst.size(0) + shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1) + return src.view(shape) + + 
+ 
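The temporal transformer variants in this hunk all follow one reshaping pattern: the spatial grid is folded into the batch axis, attention runs over the frame axis independently at every spatial location, and the result is reshaped back to b c f h w before the residual add (or the zero-multiplied add for image-only batches). A minimal illustrative sketch of that pattern, with toy shapes and a plain nn.MultiheadAttention standing in for the module's BasicTransformerBlock:

import torch
import torch.nn as nn
from einops import rearrange

# Toy shapes; attention is applied over the frame axis f, one sequence per pixel.
b, c, f, h, w = 2, 8, 16, 4, 4
x = torch.randn(b, c, f, h, w)

# (b, c, f, h, w) -> (b*h*w, f, c): fold the spatial grid into the batch axis
seq = rearrange(x, 'b c f h w -> (b h w) f c')

attn = nn.MultiheadAttention(embed_dim=c, num_heads=2, batch_first=True)
out, _ = attn(seq, seq, seq)  # temporal self-attention across the f frames

# restore the video layout and add the residual, as the blocks above do
out = rearrange(out, '(b h w) f c -> b c f h w', b=b, h=h, w=w)
out = x + out
print(out.shape)  # torch.Size([2, 8, 16, 4, 4])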
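DropPath above is sample-level stochastic depth without rescaling: during training a random subset of batch elements has the given tensor(s) zeroed, the zero/keep masks force specific samples to always drop or always keep their branch, and at eval time the module is the identity. A small sanity-check sketch of that behaviour (the import path is an assumption, adjust it to wherever this diff places the class):

import torch
from tools.modules.unet.util import DropPath  # hypothetical path for this diff's module

drop = DropPath(p=0.5)
drop.train()

x = torch.ones(6, 4)
keep = torch.zeros(6, dtype=torch.bool); keep[0] = True  # index 0: never dropped
zero = torch.zeros(6, dtype=torch.bool); zero[1] = True  # index 1: always dropped

out = drop(x, zero=zero, keep=keep)
assert torch.equal(out[0], x[0])  # forced-keep sample passes through unscaled
assert torch.all(out[1] == 0)     # forced-zero sample is zeroed entirely

drop.eval()
assert torch.equal(drop(x), x)    # identity at eval time (no rescaling)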