diff --git a/utils/__init__.py b/utils/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 52a43aa..0000000
Binary files a/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index e0d4a3e..0000000
Binary files a/utils/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/assign_cfg.cpython-310.pyc b/utils/__pycache__/assign_cfg.cpython-310.pyc
deleted file mode 100644
index 60b8b22..0000000
Binary files a/utils/__pycache__/assign_cfg.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/assign_cfg.cpython-39.pyc b/utils/__pycache__/assign_cfg.cpython-39.pyc
deleted file mode 100644
index 116ba8f..0000000
Binary files a/utils/__pycache__/assign_cfg.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/config.cpython-310.pyc b/utils/__pycache__/config.cpython-310.pyc
deleted file mode 100644
index d1ff3f1..0000000
Binary files a/utils/__pycache__/config.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/config.cpython-39.pyc b/utils/__pycache__/config.cpython-39.pyc
deleted file mode 100644
index 0ec8010..0000000
Binary files a/utils/__pycache__/config.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/distributed.cpython-310.pyc b/utils/__pycache__/distributed.cpython-310.pyc
deleted file mode 100644
index 1148ab6..0000000
Binary files a/utils/__pycache__/distributed.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/distributed.cpython-39.pyc b/utils/__pycache__/distributed.cpython-39.pyc
deleted file mode 100644
index 2cdc5c9..0000000
Binary files a/utils/__pycache__/distributed.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/logging.cpython-310.pyc b/utils/__pycache__/logging.cpython-310.pyc
deleted file mode 100644
index de4b9d8..0000000
Binary files a/utils/__pycache__/logging.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/logging.cpython-39.pyc b/utils/__pycache__/logging.cpython-39.pyc
deleted file mode 100644
index 1e3b71d..0000000
Binary files a/utils/__pycache__/logging.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/multi_port.cpython-310.pyc b/utils/__pycache__/multi_port.cpython-310.pyc
deleted file mode 100644
index a389343..0000000
Binary files a/utils/__pycache__/multi_port.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/multi_port.cpython-39.pyc b/utils/__pycache__/multi_port.cpython-39.pyc
deleted file mode 100644
index bb7b57e..0000000
Binary files a/utils/__pycache__/multi_port.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/registry.cpython-310.pyc b/utils/__pycache__/registry.cpython-310.pyc
deleted file mode 100644
index 6537c09..0000000
Binary files a/utils/__pycache__/registry.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/registry.cpython-39.pyc b/utils/__pycache__/registry.cpython-39.pyc
deleted file mode 100644
index cc83bf6..0000000
Binary files a/utils/__pycache__/registry.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/registry_class.cpython-310.pyc b/utils/__pycache__/registry_class.cpython-310.pyc
deleted file mode 100644
index e55fc76..0000000
Binary files a/utils/__pycache__/registry_class.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/registry_class.cpython-39.pyc b/utils/__pycache__/registry_class.cpython-39.pyc
deleted file mode 100644
index 9e3460d..0000000
Binary files a/utils/__pycache__/registry_class.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/seed.cpython-310.pyc b/utils/__pycache__/seed.cpython-310.pyc
deleted file mode 100644
index b666274..0000000
Binary files a/utils/__pycache__/seed.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/seed.cpython-39.pyc b/utils/__pycache__/seed.cpython-39.pyc
deleted file mode 100644
index 5e0b322..0000000
Binary files a/utils/__pycache__/seed.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/transforms.cpython-310.pyc b/utils/__pycache__/transforms.cpython-310.pyc
deleted file mode 100644
index 4576b9a..0000000
Binary files a/utils/__pycache__/transforms.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/transforms.cpython-39.pyc b/utils/__pycache__/transforms.cpython-39.pyc
deleted file mode 100644
index b7bd188..0000000
Binary files a/utils/__pycache__/transforms.cpython-39.pyc and /dev/null differ
diff --git a/utils/__pycache__/video_op.cpython-310.pyc b/utils/__pycache__/video_op.cpython-310.pyc
deleted file mode 100644
index 0f2cac5..0000000
Binary files a/utils/__pycache__/video_op.cpython-310.pyc and /dev/null differ
diff --git a/utils/__pycache__/video_op.cpython-39.pyc b/utils/__pycache__/video_op.cpython-39.pyc
deleted file mode 100644
index bf39bf1..0000000
Binary files a/utils/__pycache__/video_op.cpython-39.pyc and /dev/null differ
diff --git a/utils/assign_cfg.py b/utils/assign_cfg.py
deleted file mode 100644
index 74e0c61..0000000
--- a/utils/assign_cfg.py
+++ /dev/null
@@ -1,158 +0,0 @@
-import os, yaml
-from copy import deepcopy, copy
-
-
-# def get prior and ldm config
-def assign_prior_mudule_cfg(cfg):
-    '''
-    '''
-    #
-    prior_cfg = deepcopy(cfg)
-    vldm_cfg = deepcopy(cfg)
-
-    with open(cfg.prior_cfg, 'r') as f:
-        _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
-        # _cfg_update = _cfg_update.cfg_dict
-        for k, v in _cfg_update.items():
-            if isinstance(v, dict) and k in cfg:
-                prior_cfg[k].update(v)
-            else:
-                prior_cfg[k] = v
-
-    with open(cfg.vldm_cfg, 'r') as f:
-        _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
-        # _cfg_update = _cfg_update.cfg_dict
-        for k, v in _cfg_update.items():
-            if isinstance(v, dict) and k in cfg:
-                vldm_cfg[k].update(v)
-            else:
-                vldm_cfg[k] = v
-
-    return prior_cfg, vldm_cfg
-
-
-# def get prior and ldm config
-def assign_vldm_vsr_mudule_cfg(cfg):
-    '''
-    '''
-    #
-    vldm_cfg = deepcopy(cfg)
-    vsr_cfg = deepcopy(cfg)
-
-    with open(cfg.vldm_cfg, 'r') as f:
-        _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
-        # _cfg_update = _cfg_update.cfg_dict
-        for k, v in _cfg_update.items():
-            if isinstance(v, dict) and k in cfg:
-                vldm_cfg[k].update(v)
-            else:
-                vldm_cfg[k] = v
-
-    with open(cfg.vsr_cfg, 'r') as f:
-        _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
-        # _cfg_update = _cfg_update.cfg_dict
-        for k, v in _cfg_update.items():
-            if isinstance(v, dict) and k in cfg:
-                vsr_cfg[k].update(v)
-            else:
-                vsr_cfg[k] = v
-
-    return vldm_cfg, vsr_cfg
-
-
-# def get prior and ldm config
-def assign_signle_cfg(cfg, _cfg_update, tname):
-    '''
-    '''
-    #
-    vldm_cfg = deepcopy(cfg)
-    if os.path.exists(_cfg_update[tname]):
-        with open(_cfg_update[tname], 'r') as f:
-            _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
-            # _cfg_update = _cfg_update.cfg_dict
-            for k, v in _cfg_update.items():
-                if isinstance(v, dict) and k in cfg:
-                    vldm_cfg[k].update(v)
-                else:
-                    vldm_cfg[k] = v
-    return vldm_cfg
\ No newline at end of file
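The deleted `assign_*_cfg` helpers above all repeat one pattern: read a YAML file and fold it into a deep copy of the base config, merging dict values one level deep and overwriting everything else. A minimal standalone sketch of that merge, assuming a plain-dict config and a hypothetical `override.yaml` (the helper name is illustrative, not from the deleted file):

```python
import yaml
from copy import deepcopy

def merge_yaml_into_cfg(cfg: dict, yaml_path: str) -> dict:
    """Fold key/values from yaml_path into a copy of cfg.

    Dict values are merged one level deep, as in the deleted
    assign_*_cfg helpers; everything else is overwritten.
    """
    merged = deepcopy(cfg)
    with open(yaml_path, 'r') as f:
        update = yaml.load(f.read(), Loader=yaml.SafeLoader)
    for k, v in update.items():
        if isinstance(v, dict) and k in merged:
            merged[k].update(v)   # shallow per-key merge
        else:
            merged[k] = v         # add or replace outright
    return merged

# hypothetical usage:
# base = {'solver': {'lr': 1e-4, 'steps': 1000}, 'seed': 8888}
# cfg = merge_yaml_into_cfg(base, 'override.yaml')
```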
diff --git a/utils/config.py b/utils/config.py
deleted file mode 100644
index d92b1df..0000000
--- a/utils/config.py
+++ /dev/null
@@ -1,488 +0,0 @@
-import os
-import yaml
-import json
-import copy
-import argparse
-
-from ..utils import logging
-# logger = logging.get_logger(__name__)
-
-class Config(object):
-    def __init__(self, load=True, cfg_dict=None, cfg_level=None):
-        self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "")
-
-        current_directory = os.path.dirname(os.path.abspath(__file__))
-        parent_directory = os.path.dirname(current_directory)
-        self.config_file_loc = os.path.join(parent_directory, 'configs/UniAnimate_infer.yaml')
-
-        if load:
-            self.args = self._parse_args()
-            # logger.info("Loading config from {}.".format(self.args.cfg_file))
-            self.need_initialization = True
-            cfg_base = self._load_yaml(self.args)  # self._initialize_cfg()
-            cfg_dict = self._load_yaml(self.args)
-            cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
-            cfg_dict = self._update_from_args(cfg_dict)
-            self.cfg_dict = cfg_dict
-            self._update_dict(cfg_dict)
-
-    def _parse_args(self):
-        parser = argparse.ArgumentParser(
-            description="Argparser for configuring the codebase"
-        )
-        parser.add_argument(
-            "--cfg",
-            dest="cfg_file",
-            help="Path to the configuration file",
-            default=self.config_file_loc
-        )
-        parser.add_argument(
-            "--init_method",
-            help="Initialization method, includes TCP or shared file-system",
-            default="tcp://localhost:9999",
-            type=str,
-        )
-        parser.add_argument(
-            '--debug',
-            action='store_true',
-            default=False,
-            help='Output debug information'
-        )
-        parser.add_argument(
-            '--windows-standalone-build',
-            action='store_true',
-            default=False,
-            help='Indicates if the build is a standalone build for Windows'
-        )
-        parser.add_argument(
-            "opts",
-            help="Other configurations",
-            default=None,
-            nargs=argparse.REMAINDER
-        )
-        return parser.parse_args()
-
-    def _path_join(self, path_list):
-        path = ""
-        for p in path_list:
-            path += p + '/'
-        return path[:-1]
-
-    def _update_from_args(self, cfg_dict):
-        args = self.args
-        for var in vars(args):
-            cfg_dict[var] = getattr(args, var)
-        return cfg_dict
-
-    def _initialize_cfg(self):
-        if self.need_initialization:
-            self.need_initialization = False
-            if os.path.exists('./configs/base.yaml'):
-                with open("./configs/base.yaml", 'r') as f:
-                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
-            else:
-                with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f:
-                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
-        return cfg
-
-    def _load_yaml(self, args, file_name=""):
-        assert args.cfg_file is not None
-        if not file_name == "":  # reading from base file
-            with open(file_name, 'r') as f:
-                cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
-        else:
-            if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]:
-                args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./")
-            with open(args.cfg_file, 'r') as f:
-                cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
-                file_name = args.cfg_file
-
-        if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys():
-            # return cfg if the base file is being accessed
-            cfg = self._merge_cfg_from_command_update(args, cfg)
-            return cfg
-
-        if "_BASE" in cfg.keys():
-            if cfg["_BASE"][1] == '.':
-                prev_count = cfg["_BASE"].count('..')
-                cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:])
-            else:
-                cfg_base_file = cfg["_BASE"].replace(
-                    "./",
-                    args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
-                )
-            cfg_base = self._load_yaml(args, cfg_base_file)
-            cfg = self._merge_cfg_from_base(cfg_base, cfg)
-        else:
-            if "_BASE_RUN" in cfg.keys():
-                if cfg["_BASE_RUN"][1] == '.':
-                    prev_count = cfg["_BASE_RUN"].count('..')
-                    cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:])
-                else:
-                    cfg_base_file = cfg["_BASE_RUN"].replace(
-                        "./",
-                        args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
-                    )
-                cfg_base = self._load_yaml(args, cfg_base_file)
-                cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True)
-            if "_BASE_MODEL" in cfg.keys():
-                if cfg["_BASE_MODEL"][1] == '.':
-                    prev_count = cfg["_BASE_MODEL"].count('..')
-                    cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:])
-                else:
-                    cfg_base_file = cfg["_BASE_MODEL"].replace(
-                        "./",
-                        args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
-                    )
-                cfg_base = self._load_yaml(args, cfg_base_file)
-                cfg = self._merge_cfg_from_base(cfg_base, cfg)
-        cfg = self._merge_cfg_from_command(args, cfg)
-        return cfg
-
-    def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
-        for k, v in cfg_new.items():
-            if k in cfg_base.keys():
-                if isinstance(v, dict):
-                    self._merge_cfg_from_base(cfg_base[k], v)
-                else:
-                    cfg_base[k] = v
-            else:
-                if "BASE" not in k or preserve_base:
-                    cfg_base[k] = v
-        return cfg_base
-
-    def _merge_cfg_from_command_update(self, args, cfg):
-        if len(args.opts) == 0:
-            return cfg
-
-        assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
-            args.opts, len(args.opts)
-        )
-        keys = args.opts[0::2]
-        vals = args.opts[1::2]
-
-        for key, val in zip(keys, vals):
-            cfg[key] = val
-
-        return cfg
-
-    def _merge_cfg_from_command(self, args, cfg):
-        assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
-            args.opts, len(args.opts)
-        )
-        keys = args.opts[0::2]
-        vals = args.opts[1::2]
-
-        # maximum supported depth 3
-        for idx, key in enumerate(keys):
-            key_split = key.split('.')
-            assert len(key_split) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format(
-                len(key_split)
-            )
-            assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format(
-                key_split[0]
-            )
-            if len(key_split) == 2:
-                assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
-                    key
-                )
-            elif len(key_split) == 3:
-                assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
-                    key
-                )
-                assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format(
-                    key
-                )
-            elif len(key_split) == 4:
-                assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
-                    key
-                )
-                assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format(
-                    key
-                )
-                assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existant key: {}.'.format(
-                    key
-                )
-            if len(key_split) == 1:
-                cfg[key_split[0]] = vals[idx]
-            elif len(key_split) == 2:
-                cfg[key_split[0]][key_split[1]] = vals[idx]
-            elif len(key_split) == 3:
-                cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
-            elif len(key_split) == 4:
-                cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx]
-        return cfg
-
-    def _update_dict(self, cfg_dict):
-        def recur(key, elem):
-            if type(elem) is dict:
-                return key, Config(load=False, cfg_dict=elem, cfg_level=key)
-            else:
-                if type(elem) is str and elem[1:3] == "e-":
-                    elem = float(elem)
-                return key, elem
-        dic = dict(recur(k, v) for k, v in cfg_dict.items())
-        self.__dict__.update(dic)
-
-    def get_args(self):
-        return self.args
-
-    def __repr__(self):
-        return "{}\n".format(self.dump())
-
-    def dump(self):
-        return json.dumps(self.cfg_dict, indent=2)
-
-    def deep_copy(self):
-        return copy.deepcopy(self)
-
-# if __name__ == '__main__':
-#     # debug
-#     cfg = Config(load=True)
-# print(cfg.DATA)
\ No newline at end of file
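The deleted `_merge_cfg_from_command` applied `KEY VALUE` override pairs from `opts`, walking dotted keys one hand-written depth at a time (up to four levels). The same behaviour fits in a loop; a minimal sketch under that assumption (`set_by_dotted_key` is a hypothetical name, not from the deleted file):

```python
def set_by_dotted_key(cfg: dict, key: str, val, max_depth: int = 4) -> None:
    """Apply an override like 'solver.lr' -> 1e-5 to a nested dict,
    validating every intermediate key as the deleted method did."""
    parts = key.split('.')
    assert len(parts) <= max_depth, f'Key depth error: {len(parts)} > {max_depth}'
    node = cfg
    for part in parts[:-1]:
        assert part in node, f'Non-existent key: {key}.'
        node = node[part]
    assert parts[-1] in node, f'Non-existent key: {key}.'
    node[parts[-1]] = val

# usage mirroring `python infer.py --cfg some.yaml solver.lr 1e-5`:
cfg = {'solver': {'lr': 1e-4}, 'seed': 8888}
set_by_dotted_key(cfg, 'solver.lr', 1e-5)
assert cfg['solver']['lr'] == 1e-5
```

Note the related quirk in `_update_dict`: YAML's SafeLoader parses strings like `1e-4` as `str` under YAML 1.1 rules, which is why the deleted code sniffs for `"e-"` and coerces such values to `float`.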
diff --git a/utils/distributed.py b/utils/distributed.py
deleted file mode 100644
index 284dbdb..0000000
--- a/utils/distributed.py
+++ /dev/null
@@ -1,863 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-
-import torch
-import torch.nn.functional as F
-import torch.distributed as dist
-import functools
-import pickle
-import numpy as np
-from collections import OrderedDict
-from torch.autograd import Function
-
-__all__ = ['is_dist_initialized',
-           'get_world_size',
-           'get_rank',
-           'new_group',
-           'destroy_process_group',
-           'barrier',
-           'broadcast',
-           'all_reduce',
-           'reduce',
-           'gather',
-           'all_gather',
-           'reduce_dict',
-           'get_global_gloo_group',
-           'generalized_all_gather',
-           'generalized_gather',
-           'scatter',
-           'reduce_scatter',
-           'send',
-           'recv',
-           'isend',
-           'irecv',
-           'shared_random_seed',
-           'diff_all_gather',
-           'diff_all_reduce',
-           'diff_scatter',
-           'diff_copy',
-           'spherical_kmeans',
-           'sinkhorn']
-
-#-------------------------------- Distributed operations --------------------------------#
-
-def is_dist_initialized():
-    return dist.is_available() and dist.is_initialized()
-
-def get_world_size(group=None):
-    return dist.get_world_size(group) if is_dist_initialized() else 1
-
-def get_rank(group=None):
-    return dist.get_rank(group) if is_dist_initialized() else 0
-
-def new_group(ranks=None, **kwargs):
-    if is_dist_initialized():
-        return dist.new_group(ranks, **kwargs)
-    return None
-
-def destroy_process_group():
-    if is_dist_initialized():
-        dist.destroy_process_group()
-
-def barrier(group=None, **kwargs):
-    if get_world_size(group) > 1:
-        dist.barrier(group, **kwargs)
-
-def broadcast(tensor, src, group=None, **kwargs):
-    if get_world_size(group) > 1:
-        return dist.broadcast(tensor, src, group, **kwargs)
-
-def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
-    if get_world_size(group) > 1:
-        return dist.all_reduce(tensor, op, group, **kwargs)
-
-def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
-    if get_world_size(group) > 1:
-        return dist.reduce(tensor, dst, op, group, **kwargs)
-
-def gather(tensor, dst=0, group=None, **kwargs):
-    rank = get_rank()  # global rank
-    world_size = get_world_size(group)
-    if world_size == 1:
-        return [tensor]
-    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] if rank == dst else None
-    dist.gather(tensor, tensor_list, dst, group, **kwargs)
-    return tensor_list
-
-def all_gather(tensor, uniform_size=True, group=None, **kwargs):
-    world_size = get_world_size(group)
-    if world_size == 1:
-        return [tensor]
-    assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()'
-
-    if uniform_size:
-        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
-        dist.all_gather(tensor_list, tensor, group, **kwargs)
-        return tensor_list
-    else:
-        # collect tensor shapes across GPUs
-        shape = tuple(tensor.shape)
-        shape_list = generalized_all_gather(shape, group)
-
-        # flatten the tensor
-        tensor = tensor.reshape(-1)
-        size = int(np.prod(shape))
-        size_list = [int(np.prod(u)) for u in shape_list]
-        max_size = max(size_list)
-
-        # pad to maximum size
-        if size != max_size:
-            padding = tensor.new_zeros(max_size - size)
-            tensor = torch.cat([tensor, padding], dim=0)
-
-        # all_gather
-        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
-        dist.all_gather(tensor_list, tensor, group, **kwargs)
-
-        # reshape tensors
-        tensor_list = [t[:n].view(s) for t, n, s in zip(
-            tensor_list, size_list, shape_list)]
-        return tensor_list
-
-@torch.no_grad()
-def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
-    assert reduction in ['mean', 'sum']
-    world_size = get_world_size(group)
-    if world_size == 1:
-        return input_dict
-
-    # ensure that the orders of keys are consistent across processes
-    if isinstance(input_dict, OrderedDict):
-        keys = list(input_dict.keys)
-    else:
-        keys = sorted(input_dict.keys())
-    vals = [input_dict[key] for key in keys]
-    vals = torch.stack(vals, dim=0)
-    dist.reduce(vals, dst=0, group=group, **kwargs)
-    if dist.get_rank(group) == 0 and reduction == 'mean':
-        vals /= world_size
-    dist.broadcast(vals, src=0, group=group, **kwargs)
-    reduced_dict = type(input_dict)([
-        (key, val) for key, val in zip(keys, vals)])
-    return reduced_dict
-
-@functools.lru_cache()
-def get_global_gloo_group():
-    backend = dist.get_backend()
-    assert backend in ['gloo', 'nccl']
-    if backend == 'nccl':
-        return dist.new_group(backend='gloo')
-    else:
-        return dist.group.WORLD
-
-def _serialize_to_tensor(data, group):
-    backend = dist.get_backend(group)
-    assert backend in ['gloo', 'nccl']
-    device = torch.device('cpu' if backend == 'gloo' else 'cuda')
-
-    buffer = pickle.dumps(data)
-    if len(buffer) > 1024 ** 3:
-        logger = logging.getLogger(__name__)
-        logger.warning(
-            'Rank {} trying to all-gather {:.2f} GB of data on device'
-            '{}'.format(get_rank(), len(buffer) / (1024 ** 3), device))
-    storage = torch.ByteStorage.from_buffer(buffer)
-    tensor = torch.ByteTensor(storage).to(device=device)
-    return tensor
-
-def _pad_to_largest_tensor(tensor, group):
-    world_size = dist.get_world_size(group=group)
-    assert world_size >= 1, \
-        'gather/all_gather must be called from ranks within' \
-        'the give group!'
-    local_size = torch.tensor(
-        [tensor.numel()], dtype=torch.int64, device=tensor.device)
-    size_list = [torch.zeros(
-        [1], dtype=torch.int64, device=tensor.device)
-        for _ in range(world_size)]
-
-    # gather tensors and compute the maximum size
-    dist.all_gather(size_list, local_size, group=group)
-    size_list = [int(size.item()) for size in size_list]
-    max_size = max(size_list)
-
-    # pad tensors to the same size
-    if local_size != max_size:
-        padding = torch.zeros(
-            (max_size - local_size, ),
-            dtype=torch.uint8, device=tensor.device)
-        tensor = torch.cat((tensor, padding), dim=0)
-    return size_list, tensor
-
-def generalized_all_gather(data, group=None):
-    if get_world_size(group) == 1:
-        return [data]
-    if group is None:
-        group = get_global_gloo_group()
-
-    tensor = _serialize_to_tensor(data, group)
-    size_list, tensor = _pad_to_largest_tensor(tensor, group)
-    max_size = max(size_list)
-
-    # receiving tensors from all ranks
-    tensor_list = [torch.empty(
-        (max_size, ), dtype=torch.uint8, device=tensor.device)
-        for _ in size_list]
-    dist.all_gather(tensor_list, tensor, group=group)
-
-    data_list = []
-    for size, tensor in zip(size_list, tensor_list):
-        buffer = tensor.cpu().numpy().tobytes()[:size]
-        data_list.append(pickle.loads(buffer))
-    return data_list
-
-def generalized_gather(data, dst=0, group=None):
-    world_size = get_world_size(group)
-    if world_size == 1:
-        return [data]
-    if group is None:
-        group = get_global_gloo_group()
-    rank = dist.get_rank()  # global rank
-
-    tensor = _serialize_to_tensor(data, group)
-    size_list, tensor = _pad_to_largest_tensor(tensor, group)
-
-    # receiving tensors from all ranks to dst
-    if rank == dst:
-        max_size = max(size_list)
-        tensor_list = [torch.empty(
-            (max_size, ), dtype=torch.uint8, device=tensor.device)
-            for _ in size_list]
-        dist.gather(tensor, tensor_list, dst=dst, group=group)
-
-        data_list = []
-        for size, tensor in zip(size_list, tensor_list):
-            buffer = tensor.cpu().numpy().tobytes()[:size]
-            data_list.append(pickle.loads(buffer))
-        return data_list
-    else:
-        dist.gather(tensor, [], dst=dst, group=group)
-        return []
-
-def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
-    r"""NOTE: only supports CPU tensor communication.
-    """
-    if get_world_size(group) > 1:
-        return dist.scatter(data, scatter_list, src, group, **kwargs)
-
-def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs):
-    if get_world_size(group) > 1:
-        return dist.reduce_scatter(output, input_list, op, group, **kwargs)
-
-def send(tensor, dst, group=None, **kwargs):
-    if get_world_size(group) > 1:
-        assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()'
-        return dist.send(tensor, dst, group, **kwargs)
-
-def recv(tensor, src=None, group=None, **kwargs):
-    if get_world_size(group) > 1:
-        assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()'
-        return dist.recv(tensor, src, group, **kwargs)
-
-def isend(tensor, dst, group=None, **kwargs):
-    if get_world_size(group) > 1:
-        assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()'
-        return dist.isend(tensor, dst, group, **kwargs)
-
-def irecv(tensor, src=None, group=None, **kwargs):
-    if get_world_size(group) > 1:
-        assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()'
-        return dist.irecv(tensor, src, group, **kwargs)
-
-def shared_random_seed(group=None):
-    seed = np.random.randint(2 ** 31)
-    all_seeds = generalized_all_gather(seed, group)
-    return all_seeds[0]
-
-#-------------------------------- Differentiable operations --------------------------------#
-
-def _all_gather(x):
-    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
-        return x
-    rank = dist.get_rank()
-    world_size = dist.get_world_size()
-    tensors = [torch.empty_like(x) for _ in range(world_size)]
-    tensors[rank] = x
-    dist.all_gather(tensors, x)
-    return torch.cat(tensors, dim=0).contiguous()
-
-def _all_reduce(x):
-    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
-        return x
-    dist.all_reduce(x)
-    return x
-
-def _split(x):
-    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
-        return x
-    rank = dist.get_rank()
-    world_size = dist.get_world_size()
-    return x.chunk(world_size, dim=0)[rank].contiguous()
-
-class DiffAllGather(Function):
-    r"""Differentiable all-gather.
-    """
-    @staticmethod
-    def symbolic(graph, input):
-        return _all_gather(input)
-
-    @staticmethod
-    def forward(ctx, input):
-        return _all_gather(input)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _split(grad_output)
-
-class DiffAllReduce(Function):
-    r"""Differentiable all-reducd.
-    """
-    @staticmethod
-    def symbolic(graph, input):
-        return _all_reduce(input)
-
-    @staticmethod
-    def forward(ctx, input):
-        return _all_reduce(input)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output
-
-class DiffScatter(Function):
-    r"""Differentiable scatter.
-    """
-    @staticmethod
-    def symbolic(graph, input):
-        return _split(input)
-
-    @staticmethod
-    def symbolic(ctx, input):
-        return _split(input)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _all_gather(grad_output)
-
-class DiffCopy(Function):
-    r"""Differentiable copy that reduces all gradients during backward.
-    """
-    @staticmethod
-    def symbolic(graph, input):
-        return input
-
-    @staticmethod
-    def forward(ctx, input):
-        return input
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _all_reduce(grad_output)
-
-diff_all_gather = DiffAllGather.apply
-diff_all_reduce = DiffAllReduce.apply
-diff_scatter = DiffScatter.apply
-diff_copy = DiffCopy.apply
-
-#-------------------------------- Distributed algorithms --------------------------------#
-
-@torch.no_grad()
-def spherical_kmeans(feats, num_clusters, num_iters=10):
-    k, n, c = num_clusters, *feats.size()
-    ones = feats.new_ones(n, dtype=torch.long)
-
-    # distributed settings
-    rank = get_rank()
-    world_size = get_world_size()
-
-    # init clusters
-    rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
-    clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]
-
-    # variables
-    new_clusters = feats.new_zeros(k, c)
-    counts = feats.new_zeros(k, dtype=torch.long)
-
-    # iterative Expectation-Maximization
-    for step in range(num_iters + 1):
-        # Expectation step
-        simmat = torch.mm(feats, clusters.t())
-        scores, assigns = simmat.max(dim=1)
-        if step == num_iters:
-            break
-
-        # Maximization step
-        new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats)
-        all_reduce(new_clusters)
-
-        counts.zero_()
-        counts.index_add_(0, assigns, ones)
-        all_reduce(counts)
-
-        mask = (counts > 0)
-        clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
-        clusters = F.normalize(clusters, p=2, dim=1)
-    return clusters, assigns, scores
-
-@torch.no_grad()
-def sinkhorn(Q, eps=0.5, num_iters=3):
-    # normalize Q
-    Q = torch.exp(Q / eps).t()
-    sum_Q = Q.sum()
-    all_reduce(sum_Q)
-    Q /= sum_Q
-
-    # variables
-    n, m = Q.size()
-    u = Q.new_zeros(n)
-    r = Q.new_ones(n) / n
-    c = Q.new_ones(m) / (m * get_world_size())
-
-    # iterative update
-    cur_sum = Q.sum(dim=1)
-    all_reduce(cur_sum)
-    for i in range(num_iters):
-        u = cur_sum
-        Q *= (r / u).unsqueeze(1)
-        Q *= (c / Q.sum(dim=0)).unsqueeze(0)
-        cur_sum = Q.sum(dim=1)
-        all_reduce(cur_sum)
-    return (Q / Q.sum(dim=0, keepdim=True)).t().float()
diff --git a/utils/logging.py b/utils/logging.py
deleted file mode 100644
index f5d0758..0000000
--- a/utils/logging.py
+++ /dev/null
@@ -1,183 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-
-"""Logging."""
-
-import builtins
-import decimal
-import functools
-import logging
-import os
-import sys
-from ..lib import simplejson
-# from fvcore.common.file_io import PathManager
-
-from ..utils import distributed as du
-
-
-def _suppress_print():
-    """
-    Suppresses printing from the current process.
-    """
-
-    def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
-        pass
-
-    builtins.print = print_pass
-
-
-# @functools.lru_cache(maxsize=None)
-# def _cached_log_stream(filename):
-#     return PathManager.open(filename, "a")
-
-
-def setup_logging(cfg, log_file):
-    """
-    Sets up the logging for multiple processes. Only enable the logging for the
-    master process, and suppress logging for the non-master processes.
-    """
-    if du.is_master_proc():
-        # Enable logging for the master process.
-        logging.root.handlers = []
-    else:
-        # Suppress logging for non-master processes.
-        _suppress_print()
-
-    logger = logging.getLogger()
-    logger.setLevel(logging.INFO)
-    logger.propagate = False
-    plain_formatter = logging.Formatter(
-        "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
-        datefmt="%m/%d %H:%M:%S",
-    )
-
-    if du.is_master_proc():
-        ch = logging.StreamHandler(stream=sys.stdout)
-        ch.setLevel(logging.DEBUG)
-        ch.setFormatter(plain_formatter)
-        logger.addHandler(ch)
-
-    if log_file is not None and du.is_master_proc(du.get_world_size()):
-        filename = os.path.join(cfg.OUTPUT_DIR, log_file)
-        fh = logging.FileHandler(filename)
-        fh.setLevel(logging.DEBUG)
-        fh.setFormatter(plain_formatter)
-        logger.addHandler(fh)
-
-
-def get_logger(name):
-    """
-    Retrieve the logger with the specified name or, if name is None, return a
-    logger which is the root logger of the hierarchy.
-    Args:
-        name (string): name of the logger.
-    """
-    return logging.getLogger(name)
-
-
-def log_json_stats(stats):
-    """
-    Logs json stats.
-    Args:
-        stats (dict): a dictionary of statistical information to log.
-    """
-    stats = {
-        k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
-        for k, v in stats.items()
-    }
-    json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
-    logger = get_logger(__name__)
-    logger.info("{:s}".format(json_stats))
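The deleted `log_json_stats` routes float stats through `decimal.Decimal` plus `simplejson` so the logged JSON keeps fixed six-decimal formatting. A small sketch of the same idea with only the standard library, under the assumption that exact decimal rendering is not required:

```python
import json

def format_stats(stats: dict) -> str:
    """Render float stats with ~6-decimal precision, mimicking the
    deleted log_json_stats without the simplejson dependency."""
    rounded = {
        k: float("{:.6f}".format(v)) if isinstance(v, float) else v
        for k, v in stats.items()
    }
    return json.dumps(rounded, sort_keys=True)

print(format_stats({"loss": 0.123456789, "epoch": 3}))
# {"epoch": 3, "loss": 0.123457}
```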
- """ - stats = { - k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v - for k, v in stats.items() - } - json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) - logger = get_logger(__name__) - logger.info("{:s}".format(json_stats)) ->>>>>>> 626e7afc02230297b6f553675ea1c32c29971314 diff --git a/utils/mp4_to_gif.py b/utils/mp4_to_gif.py deleted file mode 100644 index 2be4c04..0000000 --- a/utils/mp4_to_gif.py +++ /dev/null @@ -1,34 +0,0 @@ -<<<<<<< HEAD -import os - - - -# source_mp4_dir = "outputs/UniAnimate_infer" -# target_gif_dir = "outputs/UniAnimate_infer_gif" - -source_mp4_dir = "outputs/UniAnimate_infer_long" -target_gif_dir = "outputs/UniAnimate_infer_long_gif" - -os.makedirs(target_gif_dir, exist_ok=True) -for video in os.listdir(source_mp4_dir): - video_dir = os.path.join(source_mp4_dir, video) - gif_dir = os.path.join(target_gif_dir, video.replace(".mp4", ".gif")) - cmd = f'ffmpeg -i {video_dir} {gif_dir}' -======= -import os - - - -# source_mp4_dir = "outputs/UniAnimate_infer" -# target_gif_dir = "outputs/UniAnimate_infer_gif" - -source_mp4_dir = "outputs/UniAnimate_infer_long" -target_gif_dir = "outputs/UniAnimate_infer_long_gif" - -os.makedirs(target_gif_dir, exist_ok=True) -for video in os.listdir(source_mp4_dir): - video_dir = os.path.join(source_mp4_dir, video) - gif_dir = os.path.join(target_gif_dir, video.replace(".mp4", ".gif")) - cmd = f'ffmpeg -i {video_dir} {gif_dir}' ->>>>>>> 626e7afc02230297b6f553675ea1c32c29971314 - os.system(cmd) \ No newline at end of file diff --git a/utils/multi_port.py b/utils/multi_port.py deleted file mode 100644 index b39be00..0000000 --- a/utils/multi_port.py +++ /dev/null @@ -1,20 +0,0 @@ -<<<<<<< HEAD -import socket -from contextlib import closing - -def find_free_port(): - """ https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """ - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(('', 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) -======= -import socket -from contextlib import closing - -def find_free_port(): - """ https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """ - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(('', 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) ->>>>>>> 626e7afc02230297b6f553675ea1c32c29971314 - return str(s.getsockname()[1]) \ No newline at end of file diff --git a/utils/optim/__init__.py b/utils/optim/__init__.py deleted file mode 100644 index a37e510..0000000 --- a/utils/optim/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -<<<<<<< HEAD -from .lr_scheduler import * -from .adafactor import * -======= -from .lr_scheduler import * -from .adafactor import * ->>>>>>> 626e7afc02230297b6f553675ea1c32c29971314 diff --git a/utils/optim/adafactor.py b/utils/optim/adafactor.py deleted file mode 100644 index 8369553..0000000 --- a/utils/optim/adafactor.py +++ /dev/null @@ -1,463 +0,0 @@ -<<<<<<< HEAD -import math -import torch -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR - -__all__ = ['Adafactor'] - -class Adafactor(Optimizer): - """ - AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: - https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py - Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that - this optimizer internally adjusts the learning rate depending 
on the `scale_parameter`, `relative_step` and - `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and - `relative_step=False`. - Arguments: - params (`Iterable[nn.parameter.Parameter]`): - Iterable of parameters to optimize or dictionaries defining parameter groups. - lr (`float`, *optional*): - The external learning rate. - eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)): - Regularization constants for square gradient and parameter scale respectively - clip_threshold (`float`, *optional*, defaults 1.0): - Threshold of root mean square of final gradient update - decay_rate (`float`, *optional*, defaults to -0.8): - Coefficient used to compute running averages of square - beta1 (`float`, *optional*): - Coefficient used for computing running averages of gradient - weight_decay (`float`, *optional*, defaults to 0): - Weight decay (L2 penalty) - scale_parameter (`bool`, *optional*, defaults to `True`): - If True, learning rate is scaled by root mean square - relative_step (`bool`, *optional*, defaults to `True`): - If True, time-dependent learning rate is computed instead of external learning rate - warmup_init (`bool`, *optional*, defaults to `False`): - Time-dependent learning rate computation depends on whether warm-up initialization is being used - This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. - Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): - - Training without LR warmup or clip_threshold is not recommended. - - use scheduled LR warm-up to fixed LR - - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) - - Disable relative updates - - Use scale_parameter=False - - Additional optimizer operations like gradient clipping should not be used alongside Adafactor - Example: - ```python - Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) - ``` - Others reported the following combination to work well: - ```python - Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) - ``` - When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] - scheduler as following: - ```python - from transformers.optimization import Adafactor, AdafactorSchedule - optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) - lr_scheduler = AdafactorSchedule(optimizer) - trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) - ``` - Usage: - ```python - # replace AdamW with Adafactor - optimizer = Adafactor( - model.parameters(), - lr=1e-3, - eps=(1e-30, 1e-3), - clip_threshold=1.0, - decay_rate=-0.8, - beta1=None, - weight_decay=0.0, - relative_step=False, - scale_parameter=False, - warmup_init=False, - ) - ```""" - - def __init__( - self, - params, - lr=None, - eps=(1e-30, 1e-3), - clip_threshold=1.0, - decay_rate=-0.8, - beta1=None, - weight_decay=0.0, - scale_parameter=True, - relative_step=True, - warmup_init=False, - ): - r"""require_version("torch>=1.5.0") # add_ with alpha - """ - if lr is not None and relative_step: - raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") - if warmup_init and not relative_step: - raise ValueError("`warmup_init=True` requires `relative_step=True`") - - defaults = dict( - lr=lr, - eps=eps, - clip_threshold=clip_threshold, - decay_rate=decay_rate, - beta1=beta1, - 
weight_decay=weight_decay, - scale_parameter=scale_parameter, - relative_step=relative_step, - warmup_init=warmup_init, - ) - super().__init__(params, defaults) - - @staticmethod - def _get_lr(param_group, param_state): - rel_step_sz = param_group["lr"] - if param_group["relative_step"]: - min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 - rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) - param_scale = 1.0 - if param_group["scale_parameter"]: - param_scale = max(param_group["eps"][1], param_state["RMS"]) - return param_scale * rel_step_sz - - @staticmethod - def _get_options(param_group, param_shape): - factored = len(param_shape) >= 2 - use_first_moment = param_group["beta1"] is not None - return factored, use_first_moment - - @staticmethod - def _rms(tensor): - return tensor.norm(2) / (tensor.numel() ** 0.5) - - @staticmethod - def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): - # copy from fairseq's adafactor implementation: - # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 - r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) - c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() - return torch.mul(r_factor, c_factor) - - def step(self, closure=None): - """ - Performs a single optimization step - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - grad = p.grad.data - if grad.dtype in {torch.float16, torch.bfloat16}: - grad = grad.float() - if grad.is_sparse: - raise RuntimeError("Adafactor does not support sparse gradients.") - - state = self.state[p] - grad_shape = grad.shape - - factored, use_first_moment = self._get_options(group, grad_shape) - # State Initialization - if len(state) == 0: - state["step"] = 0 - - if use_first_moment: - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(grad) - if factored: - state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) - state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) - else: - state["exp_avg_sq"] = torch.zeros_like(grad) - - state["RMS"] = 0 - else: - if use_first_moment: - state["exp_avg"] = state["exp_avg"].to(grad) - if factored: - state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) - state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) - else: - state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) - - p_data_fp32 = p.data - if p.data.dtype in {torch.float16, torch.bfloat16}: - p_data_fp32 = p_data_fp32.float() - - state["step"] += 1 - state["RMS"] = self._rms(p_data_fp32) - lr = self._get_lr(group, state) - - beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) - update = (grad**2) + group["eps"][0] - if factored: - exp_avg_sq_row = state["exp_avg_sq_row"] - exp_avg_sq_col = state["exp_avg_sq_col"] - - exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) - exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) - - # Approximation of exponential moving average of square of gradient - update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) - update.mul_(grad) - else: - exp_avg_sq = state["exp_avg_sq"] - - exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) - update = exp_avg_sq.rsqrt().mul_(grad) - - 
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) - update.mul_(lr) - - if use_first_moment: - exp_avg = state["exp_avg"] - exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) - update = exp_avg - - if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - - p_data_fp32.add_(-update) - - if p.data.dtype in {torch.float16, torch.bfloat16}: - p.data.copy_(p_data_fp32) - - return loss
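For reference, a minimal sketch of the external-LR configuration recommended in the docstring above, assuming `Adafactor` is importable from this `utils.optim` package; the linear model and data are placeholders:

```python
# Sketch: Adafactor as an AdamW replacement with a fixed external LR.
# Per the docstring above, an external lr requires relative_step=False
# (and typically scale_parameter=False).
import torch
from utils.optim import Adafactor  # assumes the package layout above

model = torch.nn.Linear(16, 4)
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()       # factored second-moment update, RMS-clipped
optimizer.zero_grad()
```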
- """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - grad = p.grad.data - if grad.dtype in {torch.float16, torch.bfloat16}: - grad = grad.float() - if grad.is_sparse: - raise RuntimeError("Adafactor does not support sparse gradients.") - - state = self.state[p] - grad_shape = grad.shape - - factored, use_first_moment = self._get_options(group, grad_shape) - # State Initialization - if len(state) == 0: - state["step"] = 0 - - if use_first_moment: - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(grad) - if factored: - state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) - state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) - else: - state["exp_avg_sq"] = torch.zeros_like(grad) - - state["RMS"] = 0 - else: - if use_first_moment: - state["exp_avg"] = state["exp_avg"].to(grad) - if factored: - state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) - state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) - else: - state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) - - p_data_fp32 = p.data - if p.data.dtype in {torch.float16, torch.bfloat16}: - p_data_fp32 = p_data_fp32.float() - - state["step"] += 1 - state["RMS"] = self._rms(p_data_fp32) - lr = self._get_lr(group, state) - - beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) - update = (grad**2) + group["eps"][0] - if factored: - exp_avg_sq_row = state["exp_avg_sq_row"] - exp_avg_sq_col = state["exp_avg_sq_col"] - - exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) - exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) - - # Approximation of exponential moving average of square of gradient - update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) - update.mul_(grad) - else: - exp_avg_sq = state["exp_avg_sq"] - - exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) - update = exp_avg_sq.rsqrt().mul_(grad) - - update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) - update.mul_(lr) - - if use_first_moment: - exp_avg = state["exp_avg"] - exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) - update = exp_avg - - if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - - p_data_fp32.add_(-update) - - if p.data.dtype in {torch.float16, torch.bfloat16}: - p.data.copy_(p_data_fp32) - - return loss ->>>>>>> 626e7afc02230297b6f553675ea1c32c29971314 diff --git a/utils/optim/lr_scheduler.py b/utils/optim/lr_scheduler.py deleted file mode 100644 index 2349b9a..0000000 --- a/utils/optim/lr_scheduler.py +++ /dev/null @@ -1,119 +0,0 @@ -<<<<<<< HEAD -import math -from torch.optim.lr_scheduler import _LRScheduler - -__all__ = ['AnnealingLR'] - -class AnnealingLR(_LRScheduler): - - def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', min_lr=0.0, last_step=-1): - assert decay_mode in ['linear', 'cosine', 'none'] - self.optimizer = optimizer - self.base_lr = base_lr - self.warmup_steps = warmup_steps - self.total_steps = total_steps - self.decay_mode = decay_mode - self.min_lr = min_lr - self.current_step = last_step + 1 - self.step(self.current_step) - - def get_lr(self): - if self.warmup_steps > 0 and self.current_step <= self.warmup_steps: - return self.base_lr * self.current_step / self.warmup_steps - else: - ratio = (self.current_step - self.warmup_steps) / (self.total_steps - 
diff --git a/utils/registry.py b/utils/registry.py deleted file mode 100644 index fdaec06..0000000 --- a/utils/registry.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. - -# Registry class & build_from_config function partially modified from -# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py -# Copyright 2018-2020 Open-MMLab. All rights reserved.
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import inspect -import warnings - - -def build_from_config(cfg, registry, **kwargs): - """ Default builder function. - - Args: - cfg (dict): A dict which contains parameters passes to target class or function. - Must contains key 'type', indicates the target class or function name. - registry (Registry): An registry to search target class or function. - kwargs (dict, optional): Other params not in config dict. - - Returns: - Target class object or object returned by invoking function. - - Raises: - TypeError: - KeyError: - Exception: - """ - if not isinstance(cfg, dict): - raise TypeError(f"config must be type dict, got {type(cfg)}") - if "type" not in cfg: - raise KeyError(f"config must contain key type, got {cfg}") - if not isinstance(registry, Registry): - raise TypeError(f"registry must be type Registry, got {type(registry)}") - - cfg = copy.deepcopy(cfg) - - req_type = cfg.pop("type") - req_type_entry = req_type - if isinstance(req_type, str): - req_type_entry = registry.get(req_type) - if req_type_entry is None: - try: - print(f"For Windows users, we explicitly import registry function {req_type} !!!") - from tools.inferences.inference_unianimate_entrance import inference_unianimate_entrance - from tools.inferences.inference_unianimate_long_entrance import inference_unianimate_long_entrance - # from tools.modules.diffusions.diffusion_ddim import DiffusionDDIM - # from tools.modules.diffusions.diffusion_ddim import DiffusionDDIMLong - # from tools.modules.autoencoder import AutoencoderKL - # from tools.modules.clip_embedder import FrozenOpenCLIPTextVisualEmbedder - # from tools.modules.unet.unet_unianimate import UNetSD_UniAnimate - - req_type_entry = eval(req_type) - except: - raise KeyError(f"{req_type} not found in {registry.name} registry") - - if kwargs is not None: - cfg.update(kwargs) - - if inspect.isclass(req_type_entry): - try: - return req_type_entry(**cfg) - except Exception as e: - raise Exception(f"Failed to init class {req_type_entry}, with {e}") - elif inspect.isfunction(req_type_entry): - try: - return req_type_entry(**cfg) - except Exception as e: - raise Exception(f"Failed to invoke function {req_type_entry}, with {e}") - else: - raise TypeError(f"type must be str or class, got {type(req_type_entry)}") - - -class Registry(object): - """ A registry maps key to classes or functions. - - Example: - >>> MODELS = Registry('MODELS') - >>> @MODELS.register_class() - >>> class ResNet(object): - >>> pass - >>> resnet = MODELS.build(dict(type="ResNet")) - >>> - >>> import torchvision - >>> @MODELS.register_function("InceptionV3") - >>> def get_inception_v3(pretrained=False, progress=True): - >>> return torchvision.models.inception_v3(pretrained=pretrained, progress=progress) - >>> inception_v3 = MODELS.build(dict(type='InceptionV3', pretrained=True)) - - Args: - name (str): Registry name. - build_func (func, None): Instance construct function. Default is build_from_config. 
- allow_types (tuple): Indicates how to construct the instance, by constructing class or invoking function. - """ - - def __init__(self, name, build_func=None, allow_types=("class", "function")): - self.name = name - self.allow_types = allow_types - self.class_map = {} - self.func_map = {} - self.build_func = build_func or build_from_config - - def get(self, req_type): - return self.class_map.get(req_type) or self.func_map.get(req_type) - - def build(self, *args, **kwargs): - return self.build_func(*args, **kwargs, registry=self) - - def register_class(self, name=None): - def _register(cls): - if not inspect.isclass(cls): - raise TypeError(f"Module must be type class, got {type(cls)}") - if "class" not in self.allow_types: - raise TypeError(f"Register {self.name} only allows type {self.allow_types}, got class") - module_name = name or cls.__name__ - if module_name in self.class_map: - warnings.warn(f"Class {module_name} already registered by {self.class_map[module_name]}, " - f"will be replaced by {cls}") - self.class_map[module_name] = cls - return cls - - return _register - - def register_function(self, name=None): - def _register(func): - if not inspect.isfunction(func): - raise TypeError(f"Registry must be type function, got {type(func)}") - if "function" not in self.allow_types: - raise TypeError(f"Registry {self.name} only allows type {self.allow_types}, got function") - func_name = name or func.__name__ - if func_name in self.class_map: - warnings.warn(f"Function {func_name} already registered by {self.func_map[func_name]}, " - f"will be replaced by {func}") - self.func_map[func_name] = func - return func - - return _register - - def _list(self): - keys = sorted(list(self.class_map.keys()) + list(self.func_map.keys())) - descriptions = [] - for key in keys: - if key in self.class_map: - descriptions.append(f"{key}: {self.class_map[key]}") - else: - descriptions.append( - f"{key}: ") - return "\n".join(descriptions) - - def __repr__(self): - description = self._list() - description = '\n'.join(['\t' + s for s in description.split('\n')]) - return f"{self.__class__.__name__} [{self.name}], \n" + description - -
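The registry above enables config-driven construction: a dict with a `type` key is resolved to a registered class or function, and the remaining keys become constructor arguments. A small sketch of the pattern (names are illustrative):

```python
# Sketch: register a class, then build it from a plain config dict.
# Extra kwargs passed to build() are merged into the config first.
from utils.registry import Registry

MODELS = Registry("MODELS")

@MODELS.register_class()
class TinyModel:
    def __init__(self, dim=32):
        self.dim = dim

model = MODELS.build(dict(type="TinyModel"), dim=64)
assert model.dim == 64
```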
- """ - - def __init__(self, name, build_func=None, allow_types=("class", "function")): - self.name = name - self.allow_types = allow_types - self.class_map = {} - self.func_map = {} - self.build_func = build_func or build_from_config - - def get(self, req_type): - return self.class_map.get(req_type) or self.func_map.get(req_type) - - def build(self, *args, **kwargs): - return self.build_func(*args, **kwargs, registry=self) - - def register_class(self, name=None): - def _register(cls): - if not inspect.isclass(cls): - raise TypeError(f"Module must be type class, got {type(cls)}") - if "class" not in self.allow_types: - raise TypeError(f"Register {self.name} only allows type {self.allow_types}, got class") - module_name = name or cls.__name__ - if module_name in self.class_map: - warnings.warn(f"Class {module_name} already registered by {self.class_map[module_name]}, " - f"will be replaced by {cls}") - self.class_map[module_name] = cls - return cls - - return _register - - def register_function(self, name=None): - def _register(func): - if not inspect.isfunction(func): - raise TypeError(f"Registry must be type function, got {type(func)}") - if "function" not in self.allow_types: - raise TypeError(f"Registry {self.name} only allows type {self.allow_types}, got function") - func_name = name or func.__name__ - if func_name in self.class_map: - warnings.warn(f"Function {func_name} already registered by {self.func_map[func_name]}, " - f"will be replaced by {func}") - self.func_map[func_name] = func - return func - - return _register - - def _list(self): - keys = sorted(list(self.class_map.keys()) + list(self.func_map.keys())) - descriptions = [] - for key in keys: - if key in self.class_map: - descriptions.append(f"{key}: {self.class_map[key]}") - else: - descriptions.append( - f"{key}: ") - return "\n".join(descriptions) - - def __repr__(self): - description = self._list() - description = '\n'.join(['\t' + s for s in description.split('\n')]) - return f"{self.__class__.__name__} [{self.name}], \n" + description - - ->>>>>>> 626e7afc02230297b6f553675ea1c32c29971314 diff --git a/utils/registry_class.py b/utils/registry_class.py deleted file mode 100644 index 3b52145..0000000 --- a/utils/registry_class.py +++ /dev/null @@ -1,41 +0,0 @@ -<<<<<<< HEAD -from .registry import Registry, build_from_config - -def build_func(cfg, registry, **kwargs): - """ - Except for config, if passing a list of dataset config, then return the concat type of it - """ - return build_from_config(cfg, registry, **kwargs) - -AUTO_ENCODER = Registry("AUTO_ENCODER", build_func=build_func) -DATASETS = Registry("DATASETS", build_func=build_func) -DIFFUSION = Registry("DIFFUSION", build_func=build_func) -DISTRIBUTION = Registry("DISTRIBUTION", build_func=build_func) -EMBEDDER = Registry("EMBEDDER", build_func=build_func) -ENGINE = Registry("ENGINE", build_func=build_func) -INFER_ENGINE = Registry("INFER_ENGINE", build_func=build_func) -MODEL = Registry("MODEL", build_func=build_func) -PRETRAIN = Registry("PRETRAIN", build_func=build_func) -VISUAL = Registry("VISUAL", build_func=build_func) -EMBEDMANAGER = Registry("EMBEDMANAGER", build_func=build_func) -======= -from .registry import Registry, build_from_config - -def build_func(cfg, registry, **kwargs): - """ - Except for config, if passing a list of dataset config, then return the concat type of it - """ - return build_from_config(cfg, registry, **kwargs) - -AUTO_ENCODER = Registry("AUTO_ENCODER", build_func=build_func) -DATASETS = Registry("DATASETS", build_func=build_func) 
diff --git a/utils/seed.py b/utils/seed.py deleted file mode 100644 index 93967ef..0000000 --- a/utils/seed.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch -import random -import numpy as np - - -def setup_seed(seed): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) - torch.backends.cudnn.deterministic = True \ No newline at end of file
diff --git a/utils/transforms.py b/utils/transforms.py deleted file mode 100644 index a397e86..0000000 --- a/utils/transforms.py +++ /dev/null @@ -1,709 +0,0 @@ -import torch -import torchvision.transforms.functional as F -import random -import math -import numpy as np -from PIL import Image, ImageFilter - -__all__ = ['Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', 'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',\ - 'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize', "ResizeRandomCrop", "ExtractResizeRandomCrop", "ExtractResizeAssignCrop"] - - -class Compose(object): - - def __init__(self, transforms): - self.transforms = transforms - - def __getitem__(self, index): - if isinstance(index, slice): - return Compose(self.transforms[index]) - else: - return self.transforms[index] - - def __len__(self): - return len(self.transforms) - - def __call__(self, rgb): - for t in self.transforms: - rgb = t(rgb) - return rgb - -class Resize(object): - - def __init__(self, size=256): - if isinstance(size, int): - size = (size, size) - self.size = size - - def __call__(self, rgb): - if isinstance(rgb, list): - rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb] - else: - rgb = rgb.resize(self.size, Image.BILINEAR) - return rgb - -class Rescale(object): - - def __init__(self, size=256, interpolation=Image.BILINEAR): - self.size = size - self.interpolation = interpolation - - def __call__(self, rgb): - w, h = rgb[0].size - scale = self.size / min(w, h) - out_w, out_h = int(round(w * scale)), int(round(h * scale)) - rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb] - return rgb - -class CenterCrop(object): - - def __init__(self, size=224): - self.size = size - - def __call__(self, rgb): - w, h = rgb[0].size - assert min(w, h) >= self.size - x1 = (w - self.size) // 2 - y1 = (h - self.size) // 2 - rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb] - return rgb - -class ResizeRandomCrop(object): - - def __init__(self, size=256, size_short=292): - self.size = size - # self.min_area = min_area - self.size_short = size_short - - def __call__(self, rgb): - - # consistent crop between rgb and m - while min(rgb[0].size) >= 2 * self.size_short: - rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] -
scale = self.size_short / min(rgb[0].size) - rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] - out_w = self.size - out_h = self.size - w, h = rgb[0].size # (518, 292) - x1 = random.randint(0, w - out_w) - y1 = random.randint(0, h - out_h) - - rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] - # rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] - # # center crop - # x1 = (img[0].width - self.size) // 2 - # y1 = (img[0].height - self.size) // 2 - # img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] - return rgb - - - -class ExtractResizeRandomCrop(object): - - def __init__(self, size=256, size_short=292): - self.size = size - # self.min_area = min_area - self.size_short = size_short - - def __call__(self, rgb): - - # consistent crop between rgb and m - while min(rgb[0].size) >= 2 * self.size_short: - rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] - scale = self.size_short / min(rgb[0].size) - rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] - out_w = self.size - out_h = self.size - w, h = rgb[0].size # (518, 292) - x1 = random.randint(0, w - out_w) - y1 = random.randint(0, h - out_h) - - rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] - wh = [x1, y1, x1 + out_w, y1 + out_h] - return rgb, wh - - -class ExtractResizeAssignCrop(object): - - def __init__(self, size=256, size_short=292): - self.size = size - # self.min_area = min_area - self.size_short = size_short - - def __call__(self, rgb, wh): - - # consistent crop between rgb and m - while min(rgb[0].size) >= 2 * self.size_short: - rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] - scale = self.size_short / min(rgb[0].size) - rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] - - rgb = [u.crop(wh) for u in rgb] - rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] - return rgb - -class CenterCropV2(object): - def __init__(self, size): - self.size = size - - def __call__(self, img): - # fast resize - while min(img[0].size) >= 2 * self.size: - img = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in img] - scale = self.size / min(img[0].size) - img = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in img] - - # center crop - x1 = (img[0].width - self.size) // 2 - y1 = (img[0].height - self.size) // 2 - img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] - return img - - -class CenterCropWide(object): - def __init__(self, size, interpolation=Image.BOX): - self.size = size - self.interpolation = interpolation - - def __call__(self, img): - if isinstance(img, list): - scale = min(img[0].size[0]/self.size[0], img[0].size[1]/self.size[1]) - img = [u.resize((round(u.width // scale), round(u.height // scale)), resample=self.interpolation) for u in img] - - # center crop - x1 = (img[0].width - self.size[0]) // 2 - y1 = (img[0].height - self.size[1]) // 2 - img = [u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) for u in img] - return img - else: - scale = min(img.size[0]/self.size[0], img.size[1]/self.size[1]) - img = img.resize((round(img.width // scale), round(img.height // scale)), resample=self.interpolation) - x1 = (img.width - self.size[0]) // 2 - y1 = (img.height - self.size[1]) // 2 - img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) - 
return img - - - -class RandomCrop(object): - - def __init__(self, size=224, min_area=0.4): - self.size = size - self.min_area = min_area - - def __call__(self, rgb): - - # consistent crop between rgb and m - w, h = rgb[0].size - area = w * h - out_w, out_h = float('inf'), float('inf') - while out_w > w or out_h > h: - target_area = random.uniform(self.min_area, 1.0) * area - aspect_ratio = random.uniform(3. / 4., 4. / 3.) - out_w = int(round(math.sqrt(target_area * aspect_ratio))) - out_h = int(round(math.sqrt(target_area / aspect_ratio))) - x1 = random.randint(0, w - out_w) - y1 = random.randint(0, h - out_h) - - rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] - rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] - - return rgb - -class RandomCropV2(object): - - def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)): - if isinstance(size, (tuple, list)): - self.size = size - else: - self.size = (size, size) - self.min_area = min_area - self.ratio = ratio - - def _get_params(self, img): - width, height = img.size - area = height * width - - for _ in range(10): - target_area = random.uniform(self.min_area, 1.0) * area - log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) - aspect_ratio = math.exp(random.uniform(*log_ratio)) - - w = int(round(math.sqrt(target_area * aspect_ratio))) - h = int(round(math.sqrt(target_area / aspect_ratio))) - - if 0 < w <= width and 0 < h <= height: - i = random.randint(0, height - h) - j = random.randint(0, width - w) - return i, j, h, w - - # Fallback to central crop - in_ratio = float(width) / float(height) - if (in_ratio < min(self.ratio)): - w = width - h = int(round(w / min(self.ratio))) - elif (in_ratio > max(self.ratio)): - h = height - w = int(round(h * max(self.ratio))) - else: # whole image - w = width - h = height - i = (height - h) // 2 - j = (width - w) // 2 - return i, j, h, w - - def __call__(self, rgb): - i, j, h, w = self._get_params(rgb[0]) - rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb] - return rgb - -class RandomHFlip(object): - - def __init__(self, p=0.5): - self.p = p - - def __call__(self, rgb): - if random.random() < self.p: - rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb] - return rgb - -class GaussianBlur(object): - - def __init__(self, sigmas=[0.1, 2.0], p=0.5): - self.sigmas = sigmas - self.p = p - - def __call__(self, rgb): - if random.random() < self.p: - sigma = random.uniform(*self.sigmas) - rgb = [u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb] - return rgb - -class ColorJitter(object): - - def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.5): - self.brightness = brightness - self.contrast = contrast - self.saturation = saturation - self.hue = hue - self.p = p - - def __call__(self, rgb): - if random.random() < self.p: - brightness, contrast, saturation, hue = self._random_params() - transforms = [ - lambda f: F.adjust_brightness(f, brightness), - lambda f: F.adjust_contrast(f, contrast), - lambda f: F.adjust_saturation(f, saturation), - lambda f: F.adjust_hue(f, hue)] - random.shuffle(transforms) - for t in transforms: - rgb = [t(u) for u in rgb] - - return rgb - - def _random_params(self): - brightness = random.uniform( - max(0, 1 - self.brightness), 1 + self.brightness) - contrast = random.uniform( - max(0, 1 - self.contrast), 1 + self.contrast) - saturation = random.uniform( - max(0, 1 - self.saturation), 1 + self.saturation) - hue = random.uniform(-self.hue, self.hue) - return brightness, contrast, 
saturation, hue - -class RandomGray(object): - - def __init__(self, p=0.2): - self.p = p - - def __call__(self, rgb): - if random.random() < self.p: - rgb = [u.convert('L').convert('RGB') for u in rgb] - return rgb - -class ToTensor(object): - - def __call__(self, rgb): - if isinstance(rgb, list): - rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0) - else: - rgb = F.to_tensor(rgb) - - return rgb - -class Normalize(object): - - def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): - self.mean = mean - self.std = std - - def __call__(self, rgb): - rgb = rgb.clone() - rgb.clamp_(0, 1) - if not isinstance(self.mean, torch.Tensor): - self.mean = rgb.new_tensor(self.mean).view(-1) - if not isinstance(self.std, torch.Tensor): - self.std = rgb.new_tensor(self.std).view(-1) - if rgb.dim() == 4: - rgb.sub_(self.mean.view(1, -1, 1, 1)).div_(self.std.view(1, -1, 1, 1)) - elif rgb.dim() == 3: - rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1)) - return rgb - -
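Unlike torchvision's single-image transforms, most of the transforms above operate on a *list* of PIL frames (one video clip), so a single set of sampled parameters applies to every frame. A sketch of a typical clip pipeline, with illustrative sizes:

```python
# Sketch: deterministic clip preprocessing with the classes above.
# ToTensor stacks the frame list into an FxCxHxW tensor; Normalize then
# operates on the 4-D tensor in place.
from PIL import Image
from utils.transforms import Compose, CenterCropV2, ToTensor, Normalize

pipeline = Compose([
    CenterCropV2(size=256),
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

frames = [Image.new("RGB", (640, 360)) for _ in range(16)]
clip = pipeline(frames)  # tensor of shape (16, 3, 256, 256)
```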
diff --git a/utils/util.py b/utils/util.py deleted file mode 100644 index e93d55b..0000000 --- a/utils/util.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch - -def to_device(batch, device, non_blocking=False): - if isinstance(batch, (list, tuple)): - return type(batch)([ - to_device(u, device, non_blocking) - for u in batch]) - elif isinstance(batch, dict): - return type(batch)([ - (k, to_device(v, device, non_blocking)) - for k, v in batch.items()]) - elif isinstance(batch, torch.Tensor) and batch.device != device: - batch = batch.to(device, non_blocking=non_blocking) - else: - return batch - return batch
diff --git a/utils/video_op.py b/utils/video_op.py deleted file mode 100644 index 03cacaa..0000000 --- a/utils/video_op.py +++ /dev/null @@ -1,721 +0,0 @@ -<<<<<<< HEAD -import os -import os.path as osp -import sys -import cv2 -import glob -import math -import torch -import gzip -import copy -import time -import json -import pickle -import base64 -import imageio -import hashlib -import requests -import binascii -import zipfile -# import skvideo.io -import numpy as np -from io import BytesIO -import urllib.request -import torch.nn.functional as F -import torchvision.utils as tvutils -from multiprocessing.pool import ThreadPool as Pool -from einops import rearrange -from PIL import Image, ImageDraw, ImageFont - - -def gen_text_image(captions, text_size): - num_char = int(38 * (text_size / text_size)) - font_size = int(text_size / 20) - font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size) - text_image_list = [] - for text in captions: - txt_img = Image.new("RGB", (text_size, text_size), color="white") - draw = ImageDraw.Draw(txt_img) - lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char)) - draw.text((0, 0), lines, fill="black", font=font) - txt_img = np.array(txt_img) - text_image_list.append(txt_img) - text_images = np.stack(text_image_list, axis=0) - text_images = torch.from_numpy(text_images) - return text_images - -@torch.no_grad() -def save_video_refimg_and_text( - local_path, - ref_frame, - gen_video, - captions, - mean=[0.5, 0.5, 0.5], - std=[0.5, 0.5, 0.5], - text_size=256, - nrow=4, - save_fps=8, - retry=5): - ''' - gen_video: BxCxFxHxW - ''' - nrow = max(int(gen_video.size(0) / 2), 1) - vid_mean = torch.tensor(mean,
diff --git a/utils/video_op.py b/utils/video_op.py
deleted file mode 100644
index 03cacaa..0000000
--- a/utils/video_op.py
+++ /dev/null
@@ -1,359 +0,0 @@
-import os
-import os.path as osp
-import sys
-import cv2
-import glob
-import math
-import torch
-import gzip
-import copy
-import time
-import json
-import pickle
-import base64
-import imageio
-import hashlib
-import requests
-import binascii
-import zipfile
-# import skvideo.io
-import numpy as np
-from io import BytesIO
-import urllib.request
-import torch.nn.functional as F
-import torchvision.utils as tvutils
-from multiprocessing.pool import ThreadPool as Pool
-from einops import rearrange
-from PIL import Image, ImageDraw, ImageFont
-
-
-def gen_text_image(captions, text_size):
-    num_char = int(38 * (text_size / text_size))
-    font_size = int(text_size / 20)
-    font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size)
-    text_image_list = []
-    for text in captions:
-        txt_img = Image.new("RGB", (text_size, text_size), color="white")
-        draw = ImageDraw.Draw(txt_img)
-        lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char))
-        draw.text((0, 0), lines, fill="black", font=font)
-        txt_img = np.array(txt_img)
-        text_image_list.append(txt_img)
-    text_images = np.stack(text_image_list, axis=0)
-    text_images = torch.from_numpy(text_images)
-    return text_images
-
-@torch.no_grad()
-def save_video_refimg_and_text(
-        local_path,
-        ref_frame,
-        gen_video,
-        captions,
-        mean=[0.5, 0.5, 0.5],
-        std=[0.5, 0.5, 0.5],
-        text_size=256,
-        nrow=4,
-        save_fps=8,
-        retry=5):
-    '''
-    gen_video: BxCxFxHxW
-    '''
-    nrow = max(int(gen_video.size(0) / 2), 1)
-    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
-    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
-
-    text_images = gen_text_image(captions, text_size) # Tensor 8x256x256x3
-    text_images = text_images.unsqueeze(1) # Tensor 8x1x256x256x3
-    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 8x16x256x256x3
-
-    ref_frame = ref_frame.unsqueeze(2)
-    ref_frame = ref_frame.mul_(vid_std).add_(vid_mean)
-    ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2) # 8x16x256x256x3
-    ref_frame.clamp_(0, 1)
-    ref_frame = ref_frame * 255.0
-    ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c')
-
-    gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
-    gen_video.clamp_(0, 1)
-    gen_video = gen_video * 255.0
-
-    images = rearrange(gen_video, 'b c f h w -> b f h w c')
-    images = torch.cat([ref_frame, images, text_images], dim=3)
-
-    images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow)
-    images = [(img.numpy()).astype('uint8') for img in images]
-
-    for _ in [None] * retry:
-        try:
-            if len(images) == 1:
-                local_path = local_path + '.png'
-                cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
-            else:
-                local_path = local_path + '.mp4'
-                frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
-                os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
-                for fid, frame in enumerate(images):
-                    tpth = os.path.join(frame_dir, '%04d.png' % (fid+1))
-                    cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
-                cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
-                os.system(cmd); os.system(f'rm -rf {frame_dir}')
-                # os.system(f'rm -rf {local_path}')
-            exception = None
-            break
-        except Exception as e:
-            exception = e
-            continue
-
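Editor's note: every saver in this module shares the same denormalize-and-tile recipe: broadcast `mean`/`std` over the `(B, C, F, H, W)` video, clamp to `[0, 1]`, scale to uint8, and tile the batch into a grid with `einops.rearrange`. A self-contained sketch of just that step, with illustrative shapes (the function name and defaults are mine):

```python
# Standalone sketch of the denormalize-and-tile step used by the save_* helpers.
import torch
from einops import rearrange

def video_to_uint8_grid(video, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), nrow=2):
    # video: (B, C, F, H, W), normalized with the given mean/std
    m = torch.tensor(mean, device=video.device).view(1, -1, 1, 1, 1)
    s = torch.tensor(std, device=video.device).view(1, -1, 1, 1, 1)
    video = (video * s + m).clamp_(0, 1) * 255.0        # back to [0, 255]
    # Tile the batch into an (nrow x B/nrow) grid; one tiled image per frame.
    grid = rearrange(video, '(r j) c f h w -> f (r h) (j w) c', r=nrow)
    return [img.cpu().numpy().astype('uint8') for img in grid]

frames = video_to_uint8_grid(torch.rand(4, 3, 16, 64, 64) * 2 - 1)
print(len(frames), frames[0].shape)  # 16 tiled frames of shape (128, 128, 3)
```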
-
-@torch.no_grad()
-def save_i2vgen_video(
-    local_path,
-    image_id,
-    gen_video,
-    captions,
-    mean=[0.5, 0.5, 0.5],
-    std=[0.5, 0.5, 0.5],
-    text_size=256,
-    retry=5,
-    save_fps = 8
-):
-    '''
-    Save both the generated video and the input conditions.
-    '''
-    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
-    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
-
-    text_images = gen_text_image(captions, text_size) # Tensor 1x256x256x3
-    text_images = text_images.unsqueeze(1) # Tensor 1x1x256x256x3
-    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 1x16x256x256x3
-
-    image_id = image_id.unsqueeze(2) # B, C, F, H, W
-    image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2) # 1x3x32x256x448
-    image_id = image_id.mul_(vid_std).add_(vid_mean) # 32x3x256x448
-    image_id.clamp_(0, 1)
-    image_id = image_id * 255.0
-    image_id = rearrange(image_id, 'b c f h w -> b f h w c')
-
-    gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
-    gen_video.clamp_(0, 1)
-    gen_video = gen_video * 255.0
-
-    images = rearrange(gen_video, 'b c f h w -> b f h w c')
-    images = torch.cat([image_id, images, text_images], dim=3)
-    images = images[0]
-    images = [(img.numpy()).astype('uint8') for img in images]
-
-    exception = None
-    for _ in [None] * retry:
-        try:
-            frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
-            os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
-            for fid, frame in enumerate(images):
-                tpth = os.path.join(frame_dir, '%04d.png' % (fid+1))
-                cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
-            cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
-            os.system(cmd); os.system(f'rm -rf {frame_dir}')
-            break
-        except Exception as e:
-            exception = e
-            continue
-
-    if exception is not None:
-        raise exception
-
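Editor's note: the savers above dump numbered PNGs and shell out to `ffmpeg` via `os.system`, which swallows encoder failures. A hedged sketch of the same dump-then-encode pattern using `subprocess`/`shutil` instead, so failures raise; this substitution is mine, not the module's code, and it assumes an `ffmpeg` binary with libx264 on `PATH`:

```python
# Sketch of the PNG-dump + ffmpeg encode pattern, with explicit error handling.
import os
import shutil
import subprocess
import cv2

def encode_frames_to_mp4(frames, local_path, save_fps=8):
    # frames: list of HxWx3 uint8 RGB arrays
    frame_dir = local_path + '_frames'
    shutil.rmtree(frame_dir, ignore_errors=True)
    os.makedirs(frame_dir, exist_ok=True)
    for fid, frame in enumerate(frames):
        # OpenCV writes BGR, hence the channel reversal.
        cv2.imwrite(os.path.join(frame_dir, '%04d.png' % (fid + 1)), frame[:, :, ::-1])
    subprocess.run(
        ['ffmpeg', '-y', '-f', 'image2', '-loglevel', 'quiet',
         '-framerate', str(save_fps), '-i', os.path.join(frame_dir, '%04d.png'),
         '-vcodec', 'libx264', '-crf', '17', '-pix_fmt', 'yuv420p', local_path],
        check=True)  # raises CalledProcessError instead of failing silently
    shutil.rmtree(frame_dir, ignore_errors=True)
```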
-
-@torch.no_grad()
-def save_i2vgen_video_safe(
-    local_path,
-    gen_video,
-    captions,
-    mean=[0.5, 0.5, 0.5],
-    std=[0.5, 0.5, 0.5],
-    text_size=256,
-    retry=5,
-    save_fps = 8
-):
-    '''
-    Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame.
-    '''
-    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
-    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
-
-    gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
-    gen_video.clamp_(0, 1)
-    gen_video = gen_video * 255.0
-
-    images = rearrange(gen_video, 'b c f h w -> b f h w c')
-    images = images[0]
-    images = [(img.numpy()).astype('uint8') for img in images]
-    num_image = len(images)
-    exception = None
-    for _ in [None] * retry:
-        try:
-            if num_image == 1:
-                local_path = local_path + '.png'
-                cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
-            else:
-                writer = imageio.get_writer(local_path, fps=save_fps, codec='libx264', quality=8)
-                for fid, frame in enumerate(images):
-                    if fid == num_image-1: # Fix known bugs.
-                        ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size)
-                        if ratio > 0.4: continue
-                    writer.append_data(frame)
-                writer.close()
-            break
-        except Exception as e:
-            exception = e
-            continue
-
-    if exception is not None:
-        raise exception
-
-
-@torch.no_grad()
-def save_t2vhigen_video_safe(
-    local_path,
-    gen_video,
-    captions,
-    mean=[0.5, 0.5, 0.5],
-    std=[0.5, 0.5, 0.5],
-    text_size=256,
-    retry=5,
-    save_fps = 8
-):
-    '''
-    Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame.
-    '''
-    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
-    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
-
-    gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
-    gen_video.clamp_(0, 1)
-    gen_video = gen_video * 255.0
-
-    images = rearrange(gen_video, 'b c f h w -> b f h w c')
-    images = images[0]
-    images = [(img.numpy()).astype('uint8') for img in images]
-    num_image = len(images)
-    exception = None
-    for _ in [None] * retry:
-        try:
-            if num_image == 1:
-                local_path = local_path + '.png'
-                cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
-            else:
-                frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
-                os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
-                for fid, frame in enumerate(images):
-                    if fid == num_image-1: # Fix known bugs.
-                        ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size)
-                        if ratio > 0.4: continue
-                    tpth = os.path.join(frame_dir, '%04d.png' % (fid+1))
-                    cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
-                cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
-                os.system(cmd)
-                os.system(f'rm -rf {frame_dir}')
-            break
-        except Exception as e:
-            exception = e
-            continue
-
-    if exception is not None:
-        raise exception
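Editor's note: both `*_safe` savers drop the final frame when roughly 40% or more of its pixels fall in a narrow mid-gray band (values 117 to 137), a heuristic for a known corrupted-last-frame failure mode. A self-contained restatement of that check; the thresholds are copied from the code above, while the function name is mine:

```python
# Standalone restatement of the last-frame anomaly heuristic.
import numpy as np

def is_gray_anomaly(frame, lo=117, hi=137, max_ratio=0.4):
    # frame: HxWx3 uint8 array. Flags frames dominated by a narrow
    # mid-gray band, the signature of the corrupted final frames.
    ratio = np.sum((frame >= lo) & (frame <= hi)) / frame.size
    return ratio > max_ratio

good = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
bad = np.full((64, 64, 3), 127, dtype=np.uint8)
print(is_gray_anomaly(good), is_gray_anomaly(bad))  # False True
```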
-
-
-
-@torch.no_grad()
-def save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_tensor, model_kwargs, source_imgs,
-                                    mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], nrow=8, retry=5, save_fps=8):
-    mean = torch.tensor(mean, device=video_tensor.device).view(1,-1,1,1,1) #ncfhw
-    std = torch.tensor(std, device=video_tensor.device).view(1,-1,1,1,1) #ncfhw
-    video_tensor = video_tensor.mul_(std).add_(mean) #### unnormalize back to [0,1]
-    video_tensor.clamp_(0, 1)
-
-    b, c, n, h, w = video_tensor.shape
-    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
-    source_imgs = source_imgs.cpu()
-
-    model_kwargs_channel3 = {}
-    for key, conditions in model_kwargs[0].items():
-        if conditions.size(1) == 1:
-            conditions = torch.cat([conditions, conditions, conditions], dim=1)
-            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
-        if conditions.size(1) == 2:
-            conditions = torch.cat([conditions, conditions[:,:1,]], dim=1)
-            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
-        elif conditions.size(1) == 3:
-            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
-        elif conditions.size(1) == 4: # means it is a mask.
-            color = ((conditions[:, 0:3] + 1.)/2.) # .astype(np.float32)
-            alpha = conditions[:, 3:4] # .astype(np.float32)
-            conditions = color * alpha + 1.0 * (1.0 - alpha)
-            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
-        model_kwargs_channel3[key] = conditions.cpu() if conditions.is_cuda else conditions
-
-    # filename = rand_name(suffix='.gif')
-    for _ in [None] * retry:
-        try:
-            vid_gif = rearrange(video_tensor, '(i j) c f h w -> c f (i h) (j w)', i = nrow)
-
-            # cons_list = [rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i = nrow) for _, con in model_kwargs_channel3.items()]
-            # vid_gif = torch.cat(cons_list + [vid_gif,], dim=3) # Uncomment this and the previous line to compare the output video with the input pose frames
-
-            vid_gif = vid_gif.permute(1,2,3,0)
-
-            images = vid_gif * 255.0
-            images = [(img.numpy()).astype('uint8') for img in images]
-            if len(images) == 1:
-                local_path = local_path.replace('.mp4', '.png')
-                cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
-                # bucket.put_object_from_file(oss_key, local_path)
-            else:
-                outputs = []
-                for image_name in images:
-                    x = Image.fromarray(image_name)
-                    outputs.append(x)
-                from pathlib import Path
-                save_fmt = Path(local_path).suffix
-
-                if save_fmt == ".mp4":
-                    with imageio.get_writer(local_path, fps=save_fps) as writer:
-                        for img in outputs:
-                            img_array = np.array(img) # Convert PIL Image to numpy array
-                            writer.append_data(img_array)
-
-                elif save_fmt == ".gif":
-                    outputs[0].save(
-                        fp=local_path,
-                        format="GIF",
-                        append_images=outputs[1:],
-                        save_all=True,
-                        duration=(1 / save_fps * 1000),
-                        loop=0,
-                    )
-                else:
-                    raise ValueError("Unsupported file type. Use .mp4 or .gif.")
-
-            # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-            # fps = save_fps
-            # image = images[0]
-            # media_writer = cv2.VideoWriter(local_path, fourcc, fps, (image.shape[1],image.shape[0]))
-            # for image_name in images:
-            #     im = image_name[:,:,::-1]
-            #     media_writer.write(im)
-            # media_writer.release()
-
-            exception = None
-            break
-        except Exception as e:
-            exception = e
-            continue
-    if exception is not None:
-        print('save video to {} failed, error: {}'.format(local_path, exception), flush=True)
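Editor's note: the 4-channel branch above treats a condition map as RGBA, remapping RGB from `[-1, 1]` to `[0, 1]` and alpha-compositing over white so masked-out regions render white in the preview grid; note it assumes the alpha channel is already in `[0, 1]`. A self-contained sketch with a synthetic tensor:

```python
# Sketch of the RGBA-over-white compositing applied to 4-channel conditions.
import torch

cond = torch.cat([torch.rand(1, 3, 16, 32, 32) * 2 - 1,   # RGB in [-1, 1]
                  torch.rand(1, 1, 16, 32, 32)], dim=1)   # alpha in [0, 1]
color = (cond[:, 0:3] + 1.) / 2.                  # map RGB back to [0, 1]
alpha = cond[:, 3:4]                              # used as-is, per the code above
composited = color * alpha + 1.0 * (1.0 - alpha)  # alpha-blend over white
print(composited.min().item() >= 0, composited.max().item() <= 1)  # True True
```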