diff --git a/utils/__pycache__/assign_cfg.cpython-310.pyc b/utils/__pycache__/assign_cfg.cpython-310.pyc new file mode 100644 index 0000000..68b17a8 Binary files /dev/null and b/utils/__pycache__/assign_cfg.cpython-310.pyc differ diff --git a/utils/__pycache__/config.cpython-310.pyc b/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000..ece984f Binary files /dev/null and b/utils/__pycache__/config.cpython-310.pyc differ diff --git a/utils/__pycache__/distributed.cpython-310.pyc b/utils/__pycache__/distributed.cpython-310.pyc new file mode 100644 index 0000000..684f19c Binary files /dev/null and b/utils/__pycache__/distributed.cpython-310.pyc differ diff --git a/utils/__pycache__/logging.cpython-310.pyc b/utils/__pycache__/logging.cpython-310.pyc new file mode 100644 index 0000000..1d7f26c Binary files /dev/null and b/utils/__pycache__/logging.cpython-310.pyc differ diff --git a/utils/__pycache__/multi_port.cpython-310.pyc b/utils/__pycache__/multi_port.cpython-310.pyc new file mode 100644 index 0000000..6680883 Binary files /dev/null and b/utils/__pycache__/multi_port.cpython-310.pyc differ diff --git a/utils/__pycache__/registry.cpython-310.pyc b/utils/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000..828895d Binary files /dev/null and b/utils/__pycache__/registry.cpython-310.pyc differ diff --git a/utils/__pycache__/registry_class.cpython-310.pyc b/utils/__pycache__/registry_class.cpython-310.pyc new file mode 100644 index 0000000..f2e76fe Binary files /dev/null and b/utils/__pycache__/registry_class.cpython-310.pyc differ diff --git a/utils/__pycache__/seed.cpython-310.pyc b/utils/__pycache__/seed.cpython-310.pyc new file mode 100644 index 0000000..52b04f1 Binary files /dev/null and b/utils/__pycache__/seed.cpython-310.pyc differ diff --git a/utils/__pycache__/transforms.cpython-310.pyc b/utils/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000..baa0c06 Binary files /dev/null and b/utils/__pycache__/transforms.cpython-310.pyc differ diff --git a/utils/__pycache__/video_op.cpython-310.pyc b/utils/__pycache__/video_op.cpython-310.pyc new file mode 100644 index 0000000..655f728 Binary files /dev/null and b/utils/__pycache__/video_op.cpython-310.pyc differ diff --git a/utils/assign_cfg.py b/utils/assign_cfg.py new file mode 100644 index 0000000..e911a68 --- /dev/null +++ b/utils/assign_cfg.py @@ -0,0 +1,78 @@ +import os, yaml +from copy import deepcopy, copy + + +# def get prior and ldm config +def assign_prior_mudule_cfg(cfg): + ''' + ''' + # + prior_cfg = deepcopy(cfg) + vldm_cfg = deepcopy(cfg) + + with open(cfg.prior_cfg, 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + prior_cfg[k].update(v) + else: + prior_cfg[k] = v + + with open(cfg.vldm_cfg, 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + vldm_cfg[k].update(v) + else: + vldm_cfg[k] = v + + return prior_cfg, vldm_cfg + + +# def get prior and ldm config +def assign_vldm_vsr_mudule_cfg(cfg): + ''' + ''' + # + vldm_cfg = deepcopy(cfg) + vsr_cfg = deepcopy(cfg) + + with open(cfg.vldm_cfg, 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + vldm_cfg[k].update(v) + 
else: + vldm_cfg[k] = v + + with open(cfg.vsr_cfg, 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + vsr_cfg[k].update(v) + else: + vsr_cfg[k] = v + + return vldm_cfg, vsr_cfg + + +# def get prior and ldm config +def assign_signle_cfg(cfg, _cfg_update, tname): + ''' + ''' + # + vldm_cfg = deepcopy(cfg) + if os.path.exists(_cfg_update[tname]): + with open(_cfg_update[tname], 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + vldm_cfg[k].update(v) + else: + vldm_cfg[k] = v + return vldm_cfg \ No newline at end of file diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000..ea587b1 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,243 @@ +import os +import yaml +import json +import copy +import argparse + +from ..utils import logging +# logger = logging.get_logger(__name__) + +class Config(object): + def __init__(self, load=True, cfg_dict=None, cfg_level=None): + self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "") + + current_directory = os.path.dirname(os.path.abspath(__file__)) + parent_directory = os.path.dirname(current_directory) + self.config_file_loc = os.path.join(parent_directory, 'configs/UniAnimate_infer.yaml') + + if load: + self.args = self._parse_args() + # logger.info("Loading config from {}.".format(self.args.cfg_file)) + self.need_initialization = True + cfg_base = self._load_yaml(self.args) # self._initialize_cfg() + cfg_dict = self._load_yaml(self.args) + cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict) + cfg_dict = self._update_from_args(cfg_dict) + self.cfg_dict = cfg_dict + self._update_dict(cfg_dict) + + def _parse_args(self): + parser = argparse.ArgumentParser( + description="Argparser for configuring the codebase" + ) + parser.add_argument( + "--cfg", + dest="cfg_file", + help="Path to the configuration file", + default= self.config_file_loc + ) + parser.add_argument( + "--init_method", + help="Initialization method, includes TCP or shared file-system", + default="tcp://localhost:9999", + type=str, + ) + parser.add_argument( + '--debug', + action='store_true', + default=False, + help='Output debug information' + ) + parser.add_argument( + '--windows-standalone-build', + action='store_true', + default=False, + help='Indicates if the build is a standalone build for Windows' + ) + parser.add_argument( + "opts", + help="Other configurations", + default=None, + nargs=argparse.REMAINDER + ) + return parser.parse_args() + + + def _path_join(self, path_list): + path = "" + for p in path_list: + path+= p + '/' + return path[:-1] + + def _update_from_args(self, cfg_dict): + args = self.args + for var in vars(args): + cfg_dict[var] = getattr(args, var) + return cfg_dict + + def _initialize_cfg(self): + if self.need_initialization: + self.need_initialization = False + if os.path.exists('./configs/base.yaml'): + with open("./configs/base.yaml", 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + else: + with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + return cfg + + def _load_yaml(self, args, file_name=""): + assert args.cfg_file is not None + if not file_name == "": # reading from base file + with open(file_name, 'r') as f: + cfg = yaml.load(f.read(), 
Loader=yaml.SafeLoader) + else: + if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]: + args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./") + with open(args.cfg_file, 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + file_name = args.cfg_file + + if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys(): + # return cfg if the base file is being accessed + cfg = self._merge_cfg_from_command_update(args, cfg) + return cfg + + if "_BASE" in cfg.keys(): + if cfg["_BASE"][1] == '.': + prev_count = cfg["_BASE"].count('..') + cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:]) + else: + cfg_base_file = cfg["_BASE"].replace( + "./", + args.cfg_file.replace(args.cfg_file.split('/')[-1], "") + ) + cfg_base = self._load_yaml(args, cfg_base_file) + cfg = self._merge_cfg_from_base(cfg_base, cfg) + else: + if "_BASE_RUN" in cfg.keys(): + if cfg["_BASE_RUN"][1] == '.': + prev_count = cfg["_BASE_RUN"].count('..') + cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:]) + else: + cfg_base_file = cfg["_BASE_RUN"].replace( + "./", + args.cfg_file.replace(args.cfg_file.split('/')[-1], "") + ) + cfg_base = self._load_yaml(args, cfg_base_file) + cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True) + if "_BASE_MODEL" in cfg.keys(): + if cfg["_BASE_MODEL"][1] == '.': + prev_count = cfg["_BASE_MODEL"].count('..') + cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:]) + else: + cfg_base_file = cfg["_BASE_MODEL"].replace( + "./", + args.cfg_file.replace(args.cfg_file.split('/')[-1], "") + ) + cfg_base = self._load_yaml(args, cfg_base_file) + cfg = self._merge_cfg_from_base(cfg_base, cfg) + cfg = self._merge_cfg_from_command(args, cfg) + return cfg + + def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False): + for k,v in cfg_new.items(): + if k in cfg_base.keys(): + if isinstance(v, dict): + self._merge_cfg_from_base(cfg_base[k], v) + else: + cfg_base[k] = v + else: + if "BASE" not in k or preserve_base: + cfg_base[k] = v + return cfg_base + + def _merge_cfg_from_command_update(self, args, cfg): + if len(args.opts) == 0: + return cfg + + assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format( + args.opts, len(args.opts) + ) + keys = args.opts[0::2] + vals = args.opts[1::2] + + for key, val in zip(keys, vals): + cfg[key] = val + + return cfg + + def _merge_cfg_from_command(self, args, cfg): + assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format( + args.opts, len(args.opts) + ) + keys = args.opts[0::2] + vals = args.opts[1::2] + + # maximum supported depth 3 + for idx, key in enumerate(keys): + key_split = key.split('.') + assert len(key_split) <= 4, 'Key depth error. 
\nMaximum depth: 3\n Get depth: {}'.format( + len(key_split) + ) + assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format( + key_split[0] + ) + if len(key_split) == 2: + assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( + key + ) + elif len(key_split) == 3: + assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( + key + ) + assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( + key + ) + elif len(key_split) == 4: + assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( + key + ) + assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( + key + ) + assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existant key: {}.'.format( + key + ) + if len(key_split) == 1: + cfg[key_split[0]] = vals[idx] + elif len(key_split) == 2: + cfg[key_split[0]][key_split[1]] = vals[idx] + elif len(key_split) == 3: + cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx] + elif len(key_split) == 4: + cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx] + return cfg + + def _update_dict(self, cfg_dict): + def recur(key, elem): + if type(elem) is dict: + return key, Config(load=False, cfg_dict=elem, cfg_level=key) + else: + if type(elem) is str and elem[1:3]=="e-": + elem = float(elem) + return key, elem + dic = dict(recur(k, v) for k, v in cfg_dict.items()) + self.__dict__.update(dic) + + def get_args(self): + return self.args + + def __repr__(self): + return "{}\n".format(self.dump()) + + def dump(self): + return json.dumps(self.cfg_dict, indent=2) + + def deep_copy(self): + return copy.deepcopy(self) + +# if __name__ == '__main__': +# # debug +# cfg = Config(load=True) +# print(cfg.DATA) \ No newline at end of file diff --git a/utils/distributed.py b/utils/distributed.py new file mode 100644 index 0000000..cba28ba --- /dev/null +++ b/utils/distributed.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
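For reference, a minimal sketch (not part of this diff) of how the Config class and its trailing KEY VALUE overrides in utils/config.py above appear to be used. The script name and the max_frames / seed keys are placeholder assumptions, and the import path may differ depending on how the package is laid out:

    # Illustrative shell invocation: --cfg picks the YAML, trailing pairs override
    # (optionally dot-separated) keys, e.g.
    #   python inference.py --cfg configs/UniAnimate_infer.yaml max_frames 32 seed 11
    from utils.config import Config

    cfg = Config(load=True)      # parses --cfg plus the trailing KEY VALUE overrides
    print(cfg.max_frames)        # nested dicts are exposed as attribute-style Config objects
    # Note: override values appear to stay strings unless they look like scientific
    # notation (e.g. '5e-5'), which _update_dict converts to float.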
+ +import torch +import torch.nn.functional as F +import torch.distributed as dist +import functools +import pickle +import numpy as np +from collections import OrderedDict +from torch.autograd import Function + +__all__ = ['is_dist_initialized', + 'get_world_size', + 'get_rank', + 'new_group', + 'destroy_process_group', + 'barrier', + 'broadcast', + 'all_reduce', + 'reduce', + 'gather', + 'all_gather', + 'reduce_dict', + 'get_global_gloo_group', + 'generalized_all_gather', + 'generalized_gather', + 'scatter', + 'reduce_scatter', + 'send', + 'recv', + 'isend', + 'irecv', + 'shared_random_seed', + 'diff_all_gather', + 'diff_all_reduce', + 'diff_scatter', + 'diff_copy', + 'spherical_kmeans', + 'sinkhorn'] + +#-------------------------------- Distributed operations --------------------------------# + +def is_dist_initialized(): + return dist.is_available() and dist.is_initialized() + +def get_world_size(group=None): + return dist.get_world_size(group) if is_dist_initialized() else 1 + +def get_rank(group=None): + return dist.get_rank(group) if is_dist_initialized() else 0 + +def new_group(ranks=None, **kwargs): + if is_dist_initialized(): + return dist.new_group(ranks, **kwargs) + return None + +def destroy_process_group(): + if is_dist_initialized(): + dist.destroy_process_group() + +def barrier(group=None, **kwargs): + if get_world_size(group) > 1: + dist.barrier(group, **kwargs) + +def broadcast(tensor, src, group=None, **kwargs): + if get_world_size(group) > 1: + return dist.broadcast(tensor, src, group, **kwargs) + +def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs): + if get_world_size(group) > 1: + return dist.all_reduce(tensor, op, group, **kwargs) + +def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs): + if get_world_size(group) > 1: + return dist.reduce(tensor, dst, op, group, **kwargs) + +def gather(tensor, dst=0, group=None, **kwargs): + rank = get_rank() # global rank + world_size = get_world_size(group) + if world_size == 1: + return [tensor] + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] if rank == dst else None + dist.gather(tensor, tensor_list, dst, group, **kwargs) + return tensor_list + +def all_gather(tensor, uniform_size=True, group=None, **kwargs): + world_size = get_world_size(group) + if world_size == 1: + return [tensor] + assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()' + + if uniform_size: + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group, **kwargs) + return tensor_list + else: + # collect tensor shapes across GPUs + shape = tuple(tensor.shape) + shape_list = generalized_all_gather(shape, group) + + # flatten the tensor + tensor = tensor.reshape(-1) + size = int(np.prod(shape)) + size_list = [int(np.prod(u)) for u in shape_list] + max_size = max(size_list) + + # pad to maximum size + if size != max_size: + padding = tensor.new_zeros(max_size - size) + tensor = torch.cat([tensor, padding], dim=0) + + # all_gather + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group, **kwargs) + + # reshape tensors + tensor_list = [t[:n].view(s) for t, n, s in zip( + tensor_list, size_list, shape_list)] + return tensor_list + +@torch.no_grad() +def reduce_dict(input_dict, group=None, reduction='mean', **kwargs): + assert reduction in ['mean', 'sum'] + world_size = get_world_size(group) + if world_size == 1: + return input_dict + + # ensure that the orders of keys are 
consistent across processes
+    if isinstance(input_dict, OrderedDict):
+        keys = list(input_dict.keys())
+    else:
+        keys = sorted(input_dict.keys())
+    vals = [input_dict[key] for key in keys]
+    vals = torch.stack(vals, dim=0)
+    dist.reduce(vals, dst=0, group=group, **kwargs)
+    if dist.get_rank(group) == 0 and reduction == 'mean':
+        vals /= world_size
+    dist.broadcast(vals, src=0, group=group, **kwargs)
+    reduced_dict = type(input_dict)([
+        (key, val) for key, val in zip(keys, vals)])
+    return reduced_dict
+
+@functools.lru_cache()
+def get_global_gloo_group():
+    backend = dist.get_backend()
+    assert backend in ['gloo', 'nccl']
+    if backend == 'nccl':
+        return dist.new_group(backend='gloo')
+    else:
+        return dist.group.WORLD
+
+def _serialize_to_tensor(data, group):
+    backend = dist.get_backend(group)
+    assert backend in ['gloo', 'nccl']
+    device = torch.device('cpu' if backend == 'gloo' else 'cuda')
+
+    buffer = pickle.dumps(data)
+    if len(buffer) > 1024 ** 3:
+        import logging  # logging is only needed for this warning
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            'Rank {} trying to all-gather {:.2f} GB of data on device '
+            '{}'.format(get_rank(), len(buffer) / (1024 ** 3), device))
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device=device)
+    return tensor
+
+def _pad_to_largest_tensor(tensor, group):
+    world_size = dist.get_world_size(group=group)
+    assert world_size >= 1, \
+        'gather/all_gather must be called from ranks within ' \
+        'the given group!'
+    local_size = torch.tensor(
+        [tensor.numel()], dtype=torch.int64, device=tensor.device)
+    size_list = [torch.zeros(
+        [1], dtype=torch.int64, device=tensor.device)
+        for _ in range(world_size)]
+
+    # gather tensors and compute the maximum size
+    dist.all_gather(size_list, local_size, group=group)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # pad tensors to the same size
+    if local_size != max_size:
+        padding = torch.zeros(
+            (max_size - local_size, ),
+            dtype=torch.uint8, device=tensor.device)
+        tensor = torch.cat((tensor, padding), dim=0)
+    return size_list, tensor
+
+def generalized_all_gather(data, group=None):
+    if get_world_size(group) == 1:
+        return [data]
+    if group is None:
+        group = get_global_gloo_group()
+
+    tensor = _serialize_to_tensor(data, group)
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+    max_size = max(size_list)
+
+    # receiving tensors from all ranks
+    tensor_list = [torch.empty(
+        (max_size, ), dtype=torch.uint8, device=tensor.device)
+        for _ in size_list]
+    dist.all_gather(tensor_list, tensor, group=group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+    return data_list
+
+def generalized_gather(data, dst=0, group=None):
+    world_size = get_world_size(group)
+    if world_size == 1:
+        return [data]
+    if group is None:
+        group = get_global_gloo_group()
+    rank = dist.get_rank()  # global rank
+
+    tensor = _serialize_to_tensor(data, group)
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+    # receiving tensors from all ranks to dst
+    if rank == dst:
+        max_size = max(size_list)
+        tensor_list = [torch.empty(
+            (max_size, ), dtype=torch.uint8, device=tensor.device)
+            for _ in size_list]
+        dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+        data_list = []
+        for size, tensor in zip(size_list, tensor_list):
+            buffer = tensor.cpu().numpy().tobytes()[:size]
+            data_list.append(pickle.loads(buffer))
+        return data_list
+    else:
+        dist.gather(tensor, [], dst=dst, group=group)
+        return []
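A minimal usage sketch (not part of this diff) for the reduction/gather helpers above; it assumes torch.distributed has already been initialized, that the module is importable as utils.distributed, and that tensors live on the backend's device — the values are placeholders:

    import torch
    from utils import distributed as du

    stats = {'loss': torch.tensor(0.731), 'grad_norm': torch.tensor(2.4)}
    stats = du.reduce_dict(stats, reduction='mean')             # rank-averaged, identical on every rank
    meta = du.generalized_all_gather({'rank': du.get_rank()})   # list of picklable objects, one per rank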
+
+def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
+    r"""NOTE: only supports CPU tensor communication.
+    """
+    if get_world_size(group) > 1:
+        return dist.scatter(data, scatter_list, src, group, **kwargs)
+
+def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        return dist.reduce_scatter(output, input_list, op, group, **kwargs)
+
+def send(tensor, dst, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()'
+        return dist.send(tensor, dst, group, **kwargs)
+
+def recv(tensor, src=None, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()'
+        return dist.recv(tensor, src, group, **kwargs)
+
+def isend(tensor, dst, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()'
+        return dist.isend(tensor, dst, group, **kwargs)
+
+def irecv(tensor, src=None, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()'
+        return dist.irecv(tensor, src, group, **kwargs)
+
+def shared_random_seed(group=None):
+    seed = np.random.randint(2 ** 31)
+    all_seeds = generalized_all_gather(seed, group)
+    return all_seeds[0]
+
+#-------------------------------- Differentiable operations --------------------------------#
+
+def _all_gather(x):
+    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
+        return x
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    tensors = [torch.empty_like(x) for _ in range(world_size)]
+    tensors[rank] = x
+    dist.all_gather(tensors, x)
+    return torch.cat(tensors, dim=0).contiguous()
+
+def _all_reduce(x):
+    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
+        return x
+    dist.all_reduce(x)
+    return x
+
+def _split(x):
+    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
+        return x
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    return x.chunk(world_size, dim=0)[rank].contiguous()
+
+class DiffAllGather(Function):
+    r"""Differentiable all-gather.
+    """
+    @staticmethod
+    def symbolic(graph, input):
+        return _all_gather(input)
+
+    @staticmethod
+    def forward(ctx, input):
+        return _all_gather(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _split(grad_output)
+
+class DiffAllReduce(Function):
+    r"""Differentiable all-reduce.
+    """
+    @staticmethod
+    def symbolic(graph, input):
+        return _all_reduce(input)
+
+    @staticmethod
+    def forward(ctx, input):
+        return _all_reduce(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output
+
+class DiffScatter(Function):
+    r"""Differentiable scatter.
+    """
+    @staticmethod
+    def symbolic(graph, input):
+        return _split(input)
+
+    @staticmethod
+    def forward(ctx, input):
+        return _split(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _all_gather(grad_output)
+
+class DiffCopy(Function):
+    r"""Differentiable copy that reduces all gradients during backward.
+ """ + @staticmethod + def symbolic(graph, input): + return input + + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + return _all_reduce(grad_output) + +diff_all_gather = DiffAllGather.apply +diff_all_reduce = DiffAllReduce.apply +diff_scatter = DiffScatter.apply +diff_copy = DiffCopy.apply + +#-------------------------------- Distributed algorithms --------------------------------# + +@torch.no_grad() +def spherical_kmeans(feats, num_clusters, num_iters=10): + k, n, c = num_clusters, *feats.size() + ones = feats.new_ones(n, dtype=torch.long) + + # distributed settings + rank = get_rank() + world_size = get_world_size() + + # init clusters + rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))] + clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k] + + # variables + new_clusters = feats.new_zeros(k, c) + counts = feats.new_zeros(k, dtype=torch.long) + + # iterative Expectation-Maximization + for step in range(num_iters + 1): + # Expectation step + simmat = torch.mm(feats, clusters.t()) + scores, assigns = simmat.max(dim=1) + if step == num_iters: + break + + # Maximization step + new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats) + all_reduce(new_clusters) + + counts.zero_() + counts.index_add_(0, assigns, ones) + all_reduce(counts) + + mask = (counts > 0) + clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1) + clusters = F.normalize(clusters, p=2, dim=1) + return clusters, assigns, scores + +@torch.no_grad() +def sinkhorn(Q, eps=0.5, num_iters=3): + # normalize Q + Q = torch.exp(Q / eps).t() + sum_Q = Q.sum() + all_reduce(sum_Q) + Q /= sum_Q + + # variables + n, m = Q.size() + u = Q.new_zeros(n) + r = Q.new_ones(n) / n + c = Q.new_ones(m) / (m * get_world_size()) + + # iterative update + cur_sum = Q.sum(dim=1) + all_reduce(cur_sum) + for i in range(num_iters): + u = cur_sum + Q *= (r / u).unsqueeze(1) + Q *= (c / Q.sum(dim=0)).unsqueeze(0) + cur_sum = Q.sum(dim=1) + all_reduce(cur_sum) + return (Q / Q.sum(dim=0, keepdim=True)).t().float() diff --git a/utils/logging.py b/utils/logging.py new file mode 100644 index 0000000..6b1740f --- /dev/null +++ b/utils/logging.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Logging.""" + +import builtins +import decimal +import functools +import logging +import os +import sys +from ..lib import simplejson +# from fvcore.common.file_io import PathManager + +from ..utils import distributed as du + + +def _suppress_print(): + """ + Suppresses printing from the current process. + """ + + def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): + pass + + builtins.print = print_pass + + +# @functools.lru_cache(maxsize=None) +# def _cached_log_stream(filename): +# return PathManager.open(filename, "a") + + +def setup_logging(cfg, log_file): + """ + Sets up the logging for multiple processes. Only enable the logging for the + master process, and suppress logging for the non-master processes. + """ + if du.is_master_proc(): + # Enable logging for the master process. + logging.root.handlers = [] + else: + # Suppress logging for non-master processes. 
+ _suppress_print() + + logger = logging.getLogger() + logger.setLevel(logging.INFO) + logger.propagate = False + plain_formatter = logging.Formatter( + "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", + datefmt="%m/%d %H:%M:%S", + ) + + if du.is_master_proc(): + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + ch.setFormatter(plain_formatter) + logger.addHandler(ch) + + if log_file is not None and du.is_master_proc(du.get_world_size()): + filename = os.path.join(cfg.OUTPUT_DIR, log_file) + fh = logging.FileHandler(filename) + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + +def get_logger(name): + """ + Retrieve the logger with the specified name or, if name is None, return a + logger which is the root logger of the hierarchy. + Args: + name (string): name of the logger. + """ + return logging.getLogger(name) + + +def log_json_stats(stats): + """ + Logs json stats. + Args: + stats (dict): a dictionary of statistical information to log. + """ + stats = { + k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v + for k, v in stats.items() + } + json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) + logger = get_logger(__name__) + logger.info("{:s}".format(json_stats)) diff --git a/utils/mp4_to_gif.py b/utils/mp4_to_gif.py new file mode 100644 index 0000000..5c85921 --- /dev/null +++ b/utils/mp4_to_gif.py @@ -0,0 +1,16 @@ +import os + + + +# source_mp4_dir = "outputs/UniAnimate_infer" +# target_gif_dir = "outputs/UniAnimate_infer_gif" + +source_mp4_dir = "outputs/UniAnimate_infer_long" +target_gif_dir = "outputs/UniAnimate_infer_long_gif" + +os.makedirs(target_gif_dir, exist_ok=True) +for video in os.listdir(source_mp4_dir): + video_dir = os.path.join(source_mp4_dir, video) + gif_dir = os.path.join(target_gif_dir, video.replace(".mp4", ".gif")) + cmd = f'ffmpeg -i {video_dir} {gif_dir}' + os.system(cmd) \ No newline at end of file diff --git a/utils/multi_port.py b/utils/multi_port.py new file mode 100644 index 0000000..1542056 --- /dev/null +++ b/utils/multi_port.py @@ -0,0 +1,9 @@ +import socket +from contextlib import closing + +def find_free_port(): + """ https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return str(s.getsockname()[1]) \ No newline at end of file diff --git a/utils/optim/__init__.py b/utils/optim/__init__.py new file mode 100644 index 0000000..ffb9c49 --- /dev/null +++ b/utils/optim/__init__.py @@ -0,0 +1,2 @@ +from .lr_scheduler import * +from .adafactor import * diff --git a/utils/optim/adafactor.py b/utils/optim/adafactor.py new file mode 100644 index 0000000..63fee95 --- /dev/null +++ b/utils/optim/adafactor.py @@ -0,0 +1,230 @@ +import math +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +__all__ = ['Adafactor'] + +class Adafactor(Optimizer): + """ + AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: + https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that + this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and + `warmup_init` options. 
To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + Arguments: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*): + The external learning rate. + eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)): + Regularization constants for square gradient and parameter scale respectively + clip_threshold (`float`, *optional*, defaults 1.0): + Threshold of root mean square of final gradient update + decay_rate (`float`, *optional*, defaults to -0.8): + Coefficient used to compute running averages of square + beta1 (`float`, *optional*): + Coefficient used for computing running averages of gradient + weight_decay (`float`, *optional*, defaults to 0): + Weight decay (L2 penalty) + scale_parameter (`bool`, *optional*, defaults to `True`): + If True, learning rate is scaled by root mean square + relative_step (`bool`, *optional*, defaults to `True`): + If True, time-dependent learning rate is computed instead of external learning rate + warmup_init (`bool`, *optional*, defaults to `False`): + Time-dependent learning rate computation depends on whether warm-up initialization is being used + This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. + Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): + - Training without LR warmup or clip_threshold is not recommended. + - use scheduled LR warm-up to fixed LR + - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) + - Disable relative updates + - Use scale_parameter=False + - Additional optimizer operations like gradient clipping should not be used alongside Adafactor + Example: + ```python + Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) + ``` + Others reported the following combination to work well: + ```python + Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + ``` + When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] + scheduler as following: + ```python + from transformers.optimization import Adafactor, AdafactorSchedule + optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + lr_scheduler = AdafactorSchedule(optimizer) + trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) + ``` + Usage: + ```python + # replace AdamW with Adafactor + optimizer = Adafactor( + model.parameters(), + lr=1e-3, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + relative_step=False, + scale_parameter=False, + warmup_init=False, + ) + ```""" + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + r"""require_version("torch>=1.5.0") # add_ with alpha + """ + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + 
relative_step=relative_step, + warmup_init=warmup_init, + ) + super().__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + # copy from fairseq's adafactor implementation: + # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """ + Performs a single optimization step + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / 
group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/utils/optim/lr_scheduler.py b/utils/optim/lr_scheduler.py new file mode 100644 index 0000000..20eba6d --- /dev/null +++ b/utils/optim/lr_scheduler.py @@ -0,0 +1,58 @@ +import math +from torch.optim.lr_scheduler import _LRScheduler + +__all__ = ['AnnealingLR'] + +class AnnealingLR(_LRScheduler): + + def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', min_lr=0.0, last_step=-1): + assert decay_mode in ['linear', 'cosine', 'none'] + self.optimizer = optimizer + self.base_lr = base_lr + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.decay_mode = decay_mode + self.min_lr = min_lr + self.current_step = last_step + 1 + self.step(self.current_step) + + def get_lr(self): + if self.warmup_steps > 0 and self.current_step <= self.warmup_steps: + return self.base_lr * self.current_step / self.warmup_steps + else: + ratio = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps) + ratio = min(1.0, max(0.0, ratio)) + if self.decay_mode == 'linear': + return self.base_lr * (1 - ratio) + elif self.decay_mode == 'cosine': + return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0 + else: + return self.base_lr + + def step(self, current_step=None): + if current_step is None: + current_step = self.current_step + 1 + self.current_step = current_step + new_lr = max(self.min_lr, self.get_lr()) + if isinstance(self.optimizer, list): + for o in self.optimizer: + for group in o.param_groups: + group['lr'] = new_lr + else: + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + def state_dict(self): + return { + 'base_lr': self.base_lr, + 'warmup_steps': self.warmup_steps, + 'total_steps': self.total_steps, + 'decay_mode': self.decay_mode, + 'current_step': self.current_step} + + def load_state_dict(self, state_dict): + self.base_lr = state_dict['base_lr'] + self.warmup_steps = state_dict['warmup_steps'] + self.total_steps = state_dict['total_steps'] + self.decay_mode = state_dict['decay_mode'] + self.current_step = state_dict['current_step'] diff --git a/utils/registry.py b/utils/registry.py new file mode 100644 index 0000000..0faf364 --- /dev/null +++ b/utils/registry.py @@ -0,0 +1,167 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. + +# Registry class & build_from_config function partially modified from +# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py +# Copyright 2018-2020 Open-MMLab. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
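A minimal sketch (not part of this diff) of how the AnnealingLR scheduler from utils/optim/lr_scheduler.py above appears intended to be used; the model, step counts and learning rates are placeholders, and the import assumes the package is reachable as utils.optim:

    import torch
    from utils.optim import AnnealingLR

    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    scheduler = AnnealingLR(optimizer, base_lr=5e-5, warmup_steps=1000,
                            total_steps=100000, decay_mode='cosine', min_lr=1e-7)
    for step in range(1, 100001):
        # ... forward / backward / optimizer.step() ...
        scheduler.step(step)    # linear warm-up, then cosine decay towards min_lr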
+ +import copy +import inspect +import warnings + + +def build_from_config(cfg, registry, **kwargs): + """ Default builder function. + + Args: + cfg (dict): A dict which contains parameters passes to target class or function. + Must contains key 'type', indicates the target class or function name. + registry (Registry): An registry to search target class or function. + kwargs (dict, optional): Other params not in config dict. + + Returns: + Target class object or object returned by invoking function. + + Raises: + TypeError: + KeyError: + Exception: + """ + if not isinstance(cfg, dict): + raise TypeError(f"config must be type dict, got {type(cfg)}") + if "type" not in cfg: + raise KeyError(f"config must contain key type, got {cfg}") + if not isinstance(registry, Registry): + raise TypeError(f"registry must be type Registry, got {type(registry)}") + + cfg = copy.deepcopy(cfg) + + req_type = cfg.pop("type") + req_type_entry = req_type + if isinstance(req_type, str): + req_type_entry = registry.get(req_type) + if req_type_entry is None: + try: + print(f"For Windows users, we explicitly import registry function {req_type} !!!") + from tools.inferences.inference_unianimate_entrance import inference_unianimate_entrance + from tools.inferences.inference_unianimate_long_entrance import inference_unianimate_long_entrance + # from tools.modules.diffusions.diffusion_ddim import DiffusionDDIM + # from tools.modules.diffusions.diffusion_ddim import DiffusionDDIMLong + # from tools.modules.autoencoder import AutoencoderKL + # from tools.modules.clip_embedder import FrozenOpenCLIPTextVisualEmbedder + # from tools.modules.unet.unet_unianimate import UNetSD_UniAnimate + + req_type_entry = eval(req_type) + except: + raise KeyError(f"{req_type} not found in {registry.name} registry") + + if kwargs is not None: + cfg.update(kwargs) + + if inspect.isclass(req_type_entry): + try: + return req_type_entry(**cfg) + except Exception as e: + raise Exception(f"Failed to init class {req_type_entry}, with {e}") + elif inspect.isfunction(req_type_entry): + try: + return req_type_entry(**cfg) + except Exception as e: + raise Exception(f"Failed to invoke function {req_type_entry}, with {e}") + else: + raise TypeError(f"type must be str or class, got {type(req_type_entry)}") + + +class Registry(object): + """ A registry maps key to classes or functions. + + Example: + >>> MODELS = Registry('MODELS') + >>> @MODELS.register_class() + >>> class ResNet(object): + >>> pass + >>> resnet = MODELS.build(dict(type="ResNet")) + >>> + >>> import torchvision + >>> @MODELS.register_function("InceptionV3") + >>> def get_inception_v3(pretrained=False, progress=True): + >>> return torchvision.models.inception_v3(pretrained=pretrained, progress=progress) + >>> inception_v3 = MODELS.build(dict(type='InceptionV3', pretrained=True)) + + Args: + name (str): Registry name. + build_func (func, None): Instance construct function. Default is build_from_config. + allow_types (tuple): Indicates how to construct the instance, by constructing class or invoking function. 
+ """ + + def __init__(self, name, build_func=None, allow_types=("class", "function")): + self.name = name + self.allow_types = allow_types + self.class_map = {} + self.func_map = {} + self.build_func = build_func or build_from_config + + def get(self, req_type): + return self.class_map.get(req_type) or self.func_map.get(req_type) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def register_class(self, name=None): + def _register(cls): + if not inspect.isclass(cls): + raise TypeError(f"Module must be type class, got {type(cls)}") + if "class" not in self.allow_types: + raise TypeError(f"Register {self.name} only allows type {self.allow_types}, got class") + module_name = name or cls.__name__ + if module_name in self.class_map: + warnings.warn(f"Class {module_name} already registered by {self.class_map[module_name]}, " + f"will be replaced by {cls}") + self.class_map[module_name] = cls + return cls + + return _register + + def register_function(self, name=None): + def _register(func): + if not inspect.isfunction(func): + raise TypeError(f"Registry must be type function, got {type(func)}") + if "function" not in self.allow_types: + raise TypeError(f"Registry {self.name} only allows type {self.allow_types}, got function") + func_name = name or func.__name__ + if func_name in self.class_map: + warnings.warn(f"Function {func_name} already registered by {self.func_map[func_name]}, " + f"will be replaced by {func}") + self.func_map[func_name] = func + return func + + return _register + + def _list(self): + keys = sorted(list(self.class_map.keys()) + list(self.func_map.keys())) + descriptions = [] + for key in keys: + if key in self.class_map: + descriptions.append(f"{key}: {self.class_map[key]}") + else: + descriptions.append( + f"{key}: ") + return "\n".join(descriptions) + + def __repr__(self): + description = self._list() + description = '\n'.join(['\t' + s for s in description.split('\n')]) + return f"{self.__class__.__name__} [{self.name}], \n" + description + + diff --git a/utils/registry_class.py b/utils/registry_class.py new file mode 100644 index 0000000..b35c169 --- /dev/null +++ b/utils/registry_class.py @@ -0,0 +1,19 @@ +from .registry import Registry, build_from_config + +def build_func(cfg, registry, **kwargs): + """ + Except for config, if passing a list of dataset config, then return the concat type of it + """ + return build_from_config(cfg, registry, **kwargs) + +AUTO_ENCODER = Registry("AUTO_ENCODER", build_func=build_func) +DATASETS = Registry("DATASETS", build_func=build_func) +DIFFUSION = Registry("DIFFUSION", build_func=build_func) +DISTRIBUTION = Registry("DISTRIBUTION", build_func=build_func) +EMBEDDER = Registry("EMBEDDER", build_func=build_func) +ENGINE = Registry("ENGINE", build_func=build_func) +INFER_ENGINE = Registry("INFER_ENGINE", build_func=build_func) +MODEL = Registry("MODEL", build_func=build_func) +PRETRAIN = Registry("PRETRAIN", build_func=build_func) +VISUAL = Registry("VISUAL", build_func=build_func) +EMBEDMANAGER = Registry("EMBEDMANAGER", build_func=build_func) diff --git a/utils/seed.py b/utils/seed.py new file mode 100644 index 0000000..d3656cd --- /dev/null +++ b/utils/seed.py @@ -0,0 +1,11 @@ +import torch +import random +import numpy as np + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True \ No newline at end of file diff --git a/utils/transforms.py b/utils/transforms.py new 
file mode 100644 index 0000000..923cb2b --- /dev/null +++ b/utils/transforms.py @@ -0,0 +1,353 @@ +import torch +import torchvision.transforms.functional as F +import random +import math +import numpy as np +from PIL import Image, ImageFilter + +__all__ = ['Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', 'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',\ + 'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize', "ResizeRandomCrop", "ExtractResizeRandomCrop", "ExtractResizeAssignCrop"] + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __getitem__(self, index): + if isinstance(index, slice): + return Compose(self.transforms[index]) + else: + return self.transforms[index] + + def __len__(self): + return len(self.transforms) + + def __call__(self, rgb): + for t in self.transforms: + rgb = t(rgb) + return rgb + +class Resize(object): + + def __init__(self, size=256): + if isinstance(size, int): + size = (size, size) + self.size = size + + def __call__(self, rgb): + if isinstance(rgb, list): + rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb] + else: + rgb = rgb.resize(self.size, Image.BILINEAR) + return rgb + +class Rescale(object): + + def __init__(self, size=256, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, rgb): + w, h = rgb[0].size + scale = self.size / min(w, h) + out_w, out_h = int(round(w * scale)), int(round(h * scale)) + rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb] + return rgb + +class CenterCrop(object): + + def __init__(self, size=224): + self.size = size + + def __call__(self, rgb): + w, h = rgb[0].size + assert min(w, h) >= self.size + x1 = (w - self.size) // 2 + y1 = (h - self.size) // 2 + rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb] + return rgb + +class ResizeRandomCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + # self.min_area = min_area + self.size_short = size_short + + def __call__(self, rgb): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] + scale = self.size_short / min(rgb[0].size) + rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] + out_w = self.size + out_h = self.size + w, h = rgb[0].size # (518, 292) + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + # rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + # # center crop + # x1 = (img[0].width - self.size) // 2 + # y1 = (img[0].height - self.size) // 2 + # img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] + return rgb + + + +class ExtractResizeRandomCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + # self.min_area = min_area + self.size_short = size_short + + def __call__(self, rgb): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] + scale = self.size_short / min(rgb[0].size) + rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] + out_w = self.size + out_h = self.size + w, h = rgb[0].size # (518, 292) + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - 
out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + wh = [x1, y1, x1 + out_w, y1 + out_h] + return rgb, wh + + +class ExtractResizeAssignCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + # self.min_area = min_area + self.size_short = size_short + + def __call__(self, rgb, wh): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] + scale = self.size_short / min(rgb[0].size) + rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] + + rgb = [u.crop(wh) for u in rgb] + rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + return rgb + +class CenterCropV2(object): + def __init__(self, size): + self.size = size + + def __call__(self, img): + # fast resize + while min(img[0].size) >= 2 * self.size: + img = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in img] + scale = self.size / min(img[0].size) + img = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in img] + + # center crop + x1 = (img[0].width - self.size) // 2 + y1 = (img[0].height - self.size) // 2 + img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] + return img + + +class CenterCropWide(object): + def __init__(self, size, interpolation=Image.BOX): + self.size = size + self.interpolation = interpolation + + def __call__(self, img): + if isinstance(img, list): + scale = min(img[0].size[0]/self.size[0], img[0].size[1]/self.size[1]) + img = [u.resize((round(u.width // scale), round(u.height // scale)), resample=self.interpolation) for u in img] + + # center crop + x1 = (img[0].width - self.size[0]) // 2 + y1 = (img[0].height - self.size[1]) // 2 + img = [u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) for u in img] + return img + else: + scale = min(img.size[0]/self.size[0], img.size[1]/self.size[1]) + img = img.resize((round(img.width // scale), round(img.height // scale)), resample=self.interpolation) + x1 = (img.width - self.size[0]) // 2 + y1 = (img.height - self.size[1]) // 2 + img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) + return img + + + +class RandomCrop(object): + + def __init__(self, size=224, min_area=0.4): + self.size = size + self.min_area = min_area + + def __call__(self, rgb): + + # consistent crop between rgb and m + w, h = rgb[0].size + area = w * h + out_w, out_h = float('inf'), float('inf') + while out_w > w or out_h > h: + target_area = random.uniform(self.min_area, 1.0) * area + aspect_ratio = random.uniform(3. / 4., 4. / 3.) + out_w = int(round(math.sqrt(target_area * aspect_ratio))) + out_h = int(round(math.sqrt(target_area / aspect_ratio))) + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + + return rgb + +class RandomCropV2(object): + + def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. 
/ 3.)): + if isinstance(size, (tuple, list)): + self.size = size + else: + self.size = (size, size) + self.min_area = min_area + self.ratio = ratio + + def _get_params(self, img): + width, height = img.size + area = height * width + + for _ in range(10): + target_area = random.uniform(self.min_area, 1.0) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if (in_ratio < min(self.ratio)): + w = width + h = int(round(w / min(self.ratio))) + elif (in_ratio > max(self.ratio)): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, rgb): + i, j, h, w = self._get_params(rgb[0]) + rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb] + return rgb + +class RandomHFlip(object): + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb] + return rgb + +class GaussianBlur(object): + + def __init__(self, sigmas=[0.1, 2.0], p=0.5): + self.sigmas = sigmas + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + sigma = random.uniform(*self.sigmas) + rgb = [u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb] + return rgb + +class ColorJitter(object): + + def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.5): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + brightness, contrast, saturation, hue = self._random_params() + transforms = [ + lambda f: F.adjust_brightness(f, brightness), + lambda f: F.adjust_contrast(f, contrast), + lambda f: F.adjust_saturation(f, saturation), + lambda f: F.adjust_hue(f, hue)] + random.shuffle(transforms) + for t in transforms: + rgb = [t(u) for u in rgb] + + return rgb + + def _random_params(self): + brightness = random.uniform( + max(0, 1 - self.brightness), 1 + self.brightness) + contrast = random.uniform( + max(0, 1 - self.contrast), 1 + self.contrast) + saturation = random.uniform( + max(0, 1 - self.saturation), 1 + self.saturation) + hue = random.uniform(-self.hue, self.hue) + return brightness, contrast, saturation, hue + +class RandomGray(object): + + def __init__(self, p=0.2): + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + rgb = [u.convert('L').convert('RGB') for u in rgb] + return rgb + +class ToTensor(object): + + def __call__(self, rgb): + if isinstance(rgb, list): + rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0) + else: + rgb = F.to_tensor(rgb) + + return rgb + +class Normalize(object): + + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, rgb): + rgb = rgb.clone() + rgb.clamp_(0, 1) + if not isinstance(self.mean, torch.Tensor): + self.mean = rgb.new_tensor(self.mean).view(-1) + if not isinstance(self.std, torch.Tensor): + self.std = rgb.new_tensor(self.std).view(-1) + if rgb.dim() == 4: + rgb.sub_(self.mean.view(1, -1, 1, 
1)).div_(self.std.view(1, -1, 1, 1)) + elif rgb.dim() == 3: + rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1)) + return rgb + diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..b73866c --- /dev/null +++ b/utils/util.py @@ -0,0 +1,16 @@ +import torch + +def to_device(batch, device, non_blocking=False): + if isinstance(batch, (list, tuple)): + return type(batch)([ + to_device(u, device, non_blocking) + for u in batch]) + elif isinstance(batch, dict): + return type(batch)([ + (k, to_device(v, device, non_blocking)) + for k, v in batch.items()]) + elif isinstance(batch, torch.Tensor) and batch.device != device: + batch = batch.to(device, non_blocking=non_blocking) + else: + return batch + return batch diff --git a/utils/video_op.py b/utils/video_op.py new file mode 100644 index 0000000..05df086 --- /dev/null +++ b/utils/video_op.py @@ -0,0 +1,359 @@ +import os +import os.path as osp +import sys +import cv2 +import glob +import math +import torch +import gzip +import copy +import time +import json +import pickle +import base64 +import imageio +import hashlib +import requests +import binascii +import zipfile +# import skvideo.io +import numpy as np +from io import BytesIO +import urllib.request +import torch.nn.functional as F +import torchvision.utils as tvutils +from multiprocessing.pool import ThreadPool as Pool +from einops import rearrange +from PIL import Image, ImageDraw, ImageFont + + +def gen_text_image(captions, text_size): + num_char = int(38 * (text_size / text_size)) + font_size = int(text_size / 20) + font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size) + text_image_list = [] + for text in captions: + txt_img = Image.new("RGB", (text_size, text_size), color="white") + draw = ImageDraw.Draw(txt_img) + lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char)) + draw.text((0, 0), lines, fill="black", font=font) + txt_img = np.array(txt_img) + text_image_list.append(txt_img) + text_images = np.stack(text_image_list, axis=0) + text_images = torch.from_numpy(text_images) + return text_images + +@torch.no_grad() +def save_video_refimg_and_text( + local_path, + ref_frame, + gen_video, + captions, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + text_size=256, + nrow=4, + save_fps=8, + retry=5): + ''' + gen_video: BxCxFxHxW + ''' + nrow = max(int(gen_video.size(0) / 2), 1) + vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + + text_images = gen_text_image(captions, text_size) # Tensor 8x256x256x3 + text_images = text_images.unsqueeze(1) # Tensor 8x1x256x256x3 + text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 8x16x256x256x3 + + ref_frame = ref_frame.unsqueeze(2) + ref_frame = ref_frame.mul_(vid_std).add_(vid_mean) + ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2) # 8x16x256x256x3 + ref_frame.clamp_(0, 1) + ref_frame = ref_frame * 255.0 + ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c') + + gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 + gen_video.clamp_(0, 1) + gen_video = gen_video * 255.0 + + images = rearrange(gen_video, 'b c f h w -> b f h w c') + images = torch.cat([ref_frame, images, text_images], dim=3) + + images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow) + images = [(img.numpy()).astype('uint8') for img in images] + + for _ in [None] * retry: + try: + if len(images) 
== 1: + local_path = local_path + '.png' + cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + else: + local_path = local_path + '.mp4' + frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path))) + os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True) + for fid, frame in enumerate(images): + tpth = os.path.join(frame_dir, '%04d.png' % (fid+1)) + cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' + os.system(cmd); os.system(f'rm -rf {frame_dir}') + # os.system(f'rm -rf {local_path}') + exception = None + break + except Exception as e: + exception = e + continue + + +@torch.no_grad() +def save_i2vgen_video( + local_path, + image_id, + gen_video, + captions, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + text_size=256, + retry=5, + save_fps = 8 +): + ''' + Save both the generated video and the input conditions. + ''' + vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + + text_images = gen_text_image(captions, text_size) # Tensor 1x256x256x3 + text_images = text_images.unsqueeze(1) # Tensor 1x1x256x256x3 + text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 1x16x256x256x3 + + image_id = image_id.unsqueeze(2) # B, C, F, H, W + image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2) # 1x3x32x256x448 + image_id = image_id.mul_(vid_std).add_(vid_mean) # 32x3x256x448 + image_id.clamp_(0, 1) + image_id = image_id * 255.0 + image_id = rearrange(image_id, 'b c f h w -> b f h w c') + + gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 + gen_video.clamp_(0, 1) + gen_video = gen_video * 255.0 + + images = rearrange(gen_video, 'b c f h w -> b f h w c') + images = torch.cat([image_id, images, text_images], dim=3) + images = images[0] + images = [(img.numpy()).astype('uint8') for img in images] + + exception = None + for _ in [None] * retry: + try: + frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path))) + os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True) + for fid, frame in enumerate(images): + tpth = os.path.join(frame_dir, '%04d.png' % (fid+1)) + cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' + os.system(cmd); os.system(f'rm -rf {frame_dir}') + break + except Exception as e: + exception = e + continue + + if exception is not None: + raise exception + + +@torch.no_grad() +def save_i2vgen_video_safe( + local_path, + gen_video, + captions, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + text_size=256, + retry=5, + save_fps = 8 +): + ''' + Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame. 
+ ''' + vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + + gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 + gen_video.clamp_(0, 1) + gen_video = gen_video * 255.0 + + images = rearrange(gen_video, 'b c f h w -> b f h w c') + images = images[0] + images = [(img.numpy()).astype('uint8') for img in images] + num_image = len(images) + exception = None + for _ in [None] * retry: + try: + if num_image == 1: + local_path = local_path + '.png' + cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + else: + writer = imageio.get_writer(local_path, fps=save_fps, codec='libx264', quality=8) + for fid, frame in enumerate(images): + if fid == num_image-1: # Fix known bugs. + ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size) + if ratio > 0.4: continue + writer.append_data(frame) + writer.close() + break + except Exception as e: + exception = e + continue + + if exception is not None: + raise exception + + +@torch.no_grad() +def save_t2vhigen_video_safe( + local_path, + gen_video, + captions, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + text_size=256, + retry=5, + save_fps = 8 +): + ''' + Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame. + ''' + vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + + gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 + gen_video.clamp_(0, 1) + gen_video = gen_video * 255.0 + + images = rearrange(gen_video, 'b c f h w -> b f h w c') + images = images[0] + images = [(img.numpy()).astype('uint8') for img in images] + num_image = len(images) + exception = None + for _ in [None] * retry: + try: + if num_image == 1: + local_path = local_path + '.png' + cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + else: + frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path))) + os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True) + for fid, frame in enumerate(images): + if fid == num_image-1: # Fix known bugs. 
+ ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size) + if ratio > 0.4: continue + tpth = os.path.join(frame_dir, '%04d.png' % (fid+1)) + cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' + os.system(cmd) + os.system(f'rm -rf {frame_dir}') + break + except Exception as e: + exception = e + continue + + if exception is not None: + raise exception + + + + +@torch.no_grad() +def save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_tensor, model_kwargs, source_imgs, + mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], nrow=8, retry=5, save_fps=8): + mean=torch.tensor(mean,device=video_tensor.device).view(1,-1,1,1,1)#ncfhw + std=torch.tensor(std,device=video_tensor.device).view(1,-1,1,1,1)#ncfhw + video_tensor = video_tensor.mul_(std).add_(mean) #### unnormalize back to [0,1] + video_tensor.clamp_(0, 1) + + b, c, n, h, w = video_tensor.shape + source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w)) + source_imgs = source_imgs.cpu() + + model_kwargs_channel3 = {} + for key, conditions in model_kwargs[0].items(): + + + if conditions.size(1) == 1: + conditions = torch.cat([conditions, conditions, conditions], dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + if conditions.size(1) == 2: + conditions = torch.cat([conditions, conditions[:,:1,]], dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 3: + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 4: # means it is a mask. + color = ((conditions[:, 0:3] + 1.)/2.) # .astype(np.float32) + alpha = conditions[:, 3:4] # .astype(np.float32) + conditions = color * alpha + 1.0 * (1.0 - alpha) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + model_kwargs_channel3[key] = conditions.cpu() if conditions.is_cuda else conditions + + # filename = rand_name(suffix='.gif') + for _ in [None] * retry: + try: + vid_gif = rearrange(video_tensor, '(i j) c f h w -> c f (i h) (j w)', i = nrow) + + # cons_list = [rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i = nrow) for _, con in model_kwargs_channel3.items()] + # vid_gif = torch.cat(cons_list + [vid_gif,], dim=3) #Uncomment this and previous line to compare output video with input pose frames + + vid_gif = vid_gif.permute(1,2,3,0) + + images = vid_gif * 255.0 + images = [(img.numpy()).astype('uint8') for img in images] + if len(images) == 1: + + local_path = local_path.replace('.mp4', '.png') + cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + # bucket.put_object_from_file(oss_key, local_path) + else: + + outputs = [] + for image_name in images: + x = Image.fromarray(image_name) + outputs.append(x) + from pathlib import Path + save_fmt = Path(local_path).suffix + + if save_fmt == ".mp4": + with imageio.get_writer(local_path, fps=save_fps) as writer: + for img in outputs: + img_array = np.array(img) # Convert PIL Image to numpy array + writer.append_data(img_array) + + elif save_fmt == ".gif": + outputs[0].save( + fp=local_path, + format="GIF", + append_images=outputs[1:], + save_all=True, + duration=(1 / save_fps * 1000), + loop=0, + ) + else: + raise ValueError("Unsupported file type. 
Use .mp4 or .gif.") + + # fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # fps = save_fps + # image = images[0] + # media_writer = cv2.VideoWriter(local_path, fourcc, fps, (image.shape[1],image.shape[0])) + # for image_name in images: + # im = image_name[:,:,::-1] + # media_writer.write(im) + # media_writer.release() + + + exception = None + break + except Exception as e: + exception = e + continue + if exception is not None: + print('save video to {} failed, error: {}'.format(local_path, exception), flush=True) +