forked from Isi-dev/ComfyUI-UniAnimate-W
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
2,081 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import os, yaml | ||
from copy import deepcopy, copy | ||
|
||
|
||
# def get prior and ldm config | ||
def assign_prior_mudule_cfg(cfg): | ||
''' | ||
''' | ||
# | ||
prior_cfg = deepcopy(cfg) | ||
vldm_cfg = deepcopy(cfg) | ||
|
||
with open(cfg.prior_cfg, 'r') as f: | ||
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) | ||
# _cfg_update = _cfg_update.cfg_dict | ||
for k, v in _cfg_update.items(): | ||
if isinstance(v, dict) and k in cfg: | ||
prior_cfg[k].update(v) | ||
else: | ||
prior_cfg[k] = v | ||
|
||
with open(cfg.vldm_cfg, 'r') as f: | ||
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) | ||
# _cfg_update = _cfg_update.cfg_dict | ||
for k, v in _cfg_update.items(): | ||
if isinstance(v, dict) and k in cfg: | ||
vldm_cfg[k].update(v) | ||
else: | ||
vldm_cfg[k] = v | ||
|
||
return prior_cfg, vldm_cfg | ||
|
||
|
||
# def get prior and ldm config | ||
def assign_vldm_vsr_mudule_cfg(cfg): | ||
''' | ||
''' | ||
# | ||
vldm_cfg = deepcopy(cfg) | ||
vsr_cfg = deepcopy(cfg) | ||
|
||
with open(cfg.vldm_cfg, 'r') as f: | ||
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) | ||
# _cfg_update = _cfg_update.cfg_dict | ||
for k, v in _cfg_update.items(): | ||
if isinstance(v, dict) and k in cfg: | ||
vldm_cfg[k].update(v) | ||
else: | ||
vldm_cfg[k] = v | ||
|
||
with open(cfg.vsr_cfg, 'r') as f: | ||
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) | ||
# _cfg_update = _cfg_update.cfg_dict | ||
for k, v in _cfg_update.items(): | ||
if isinstance(v, dict) and k in cfg: | ||
vsr_cfg[k].update(v) | ||
else: | ||
vsr_cfg[k] = v | ||
|
||
return vldm_cfg, vsr_cfg | ||
|
||
|
||
# def get prior and ldm config | ||
def assign_signle_cfg(cfg, _cfg_update, tname): | ||
''' | ||
''' | ||
# | ||
vldm_cfg = deepcopy(cfg) | ||
if os.path.exists(_cfg_update[tname]): | ||
with open(_cfg_update[tname], 'r') as f: | ||
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) | ||
# _cfg_update = _cfg_update.cfg_dict | ||
for k, v in _cfg_update.items(): | ||
if isinstance(v, dict) and k in cfg: | ||
vldm_cfg[k].update(v) | ||
else: | ||
vldm_cfg[k] = v | ||
return vldm_cfg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
import os | ||
import yaml | ||
import json | ||
import copy | ||
import argparse | ||
|
||
from ..utils import logging | ||
# logger = logging.get_logger(__name__) | ||
|
||
class Config(object): | ||
def __init__(self, load=True, cfg_dict=None, cfg_level=None): | ||
self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "") | ||
|
||
current_directory = os.path.dirname(os.path.abspath(__file__)) | ||
parent_directory = os.path.dirname(current_directory) | ||
self.config_file_loc = os.path.join(parent_directory, 'configs/UniAnimate_infer.yaml') | ||
|
||
if load: | ||
self.args = self._parse_args() | ||
# logger.info("Loading config from {}.".format(self.args.cfg_file)) | ||
self.need_initialization = True | ||
cfg_base = self._load_yaml(self.args) # self._initialize_cfg() | ||
cfg_dict = self._load_yaml(self.args) | ||
cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict) | ||
cfg_dict = self._update_from_args(cfg_dict) | ||
self.cfg_dict = cfg_dict | ||
self._update_dict(cfg_dict) | ||
|
||
def _parse_args(self): | ||
parser = argparse.ArgumentParser( | ||
description="Argparser for configuring the codebase" | ||
) | ||
parser.add_argument( | ||
"--cfg", | ||
dest="cfg_file", | ||
help="Path to the configuration file", | ||
default= self.config_file_loc | ||
) | ||
parser.add_argument( | ||
"--init_method", | ||
help="Initialization method, includes TCP or shared file-system", | ||
default="tcp://localhost:9999", | ||
type=str, | ||
) | ||
parser.add_argument( | ||
'--debug', | ||
action='store_true', | ||
default=False, | ||
help='Output debug information' | ||
) | ||
parser.add_argument( | ||
'--windows-standalone-build', | ||
action='store_true', | ||
default=False, | ||
help='Indicates if the build is a standalone build for Windows' | ||
) | ||
parser.add_argument( | ||
"opts", | ||
help="Other configurations", | ||
default=None, | ||
nargs=argparse.REMAINDER | ||
) | ||
return parser.parse_args() | ||
|
||
|
||
def _path_join(self, path_list): | ||
path = "" | ||
for p in path_list: | ||
path+= p + '/' | ||
return path[:-1] | ||
|
||
def _update_from_args(self, cfg_dict): | ||
args = self.args | ||
for var in vars(args): | ||
cfg_dict[var] = getattr(args, var) | ||
return cfg_dict | ||
|
||
def _initialize_cfg(self): | ||
if self.need_initialization: | ||
self.need_initialization = False | ||
if os.path.exists('./configs/base.yaml'): | ||
with open("./configs/base.yaml", 'r') as f: | ||
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) | ||
else: | ||
with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f: | ||
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) | ||
return cfg | ||
|
||
def _load_yaml(self, args, file_name=""): | ||
assert args.cfg_file is not None | ||
if not file_name == "": # reading from base file | ||
with open(file_name, 'r') as f: | ||
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) | ||
else: | ||
if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]: | ||
args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./") | ||
with open(args.cfg_file, 'r') as f: | ||
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) | ||
file_name = args.cfg_file | ||
|
||
if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys(): | ||
# return cfg if the base file is being accessed | ||
cfg = self._merge_cfg_from_command_update(args, cfg) | ||
return cfg | ||
|
||
if "_BASE" in cfg.keys(): | ||
if cfg["_BASE"][1] == '.': | ||
prev_count = cfg["_BASE"].count('..') | ||
cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:]) | ||
else: | ||
cfg_base_file = cfg["_BASE"].replace( | ||
"./", | ||
args.cfg_file.replace(args.cfg_file.split('/')[-1], "") | ||
) | ||
cfg_base = self._load_yaml(args, cfg_base_file) | ||
cfg = self._merge_cfg_from_base(cfg_base, cfg) | ||
else: | ||
if "_BASE_RUN" in cfg.keys(): | ||
if cfg["_BASE_RUN"][1] == '.': | ||
prev_count = cfg["_BASE_RUN"].count('..') | ||
cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:]) | ||
else: | ||
cfg_base_file = cfg["_BASE_RUN"].replace( | ||
"./", | ||
args.cfg_file.replace(args.cfg_file.split('/')[-1], "") | ||
) | ||
cfg_base = self._load_yaml(args, cfg_base_file) | ||
cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True) | ||
if "_BASE_MODEL" in cfg.keys(): | ||
if cfg["_BASE_MODEL"][1] == '.': | ||
prev_count = cfg["_BASE_MODEL"].count('..') | ||
cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:]) | ||
else: | ||
cfg_base_file = cfg["_BASE_MODEL"].replace( | ||
"./", | ||
args.cfg_file.replace(args.cfg_file.split('/')[-1], "") | ||
) | ||
cfg_base = self._load_yaml(args, cfg_base_file) | ||
cfg = self._merge_cfg_from_base(cfg_base, cfg) | ||
cfg = self._merge_cfg_from_command(args, cfg) | ||
return cfg | ||
|
||
def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False): | ||
for k,v in cfg_new.items(): | ||
if k in cfg_base.keys(): | ||
if isinstance(v, dict): | ||
self._merge_cfg_from_base(cfg_base[k], v) | ||
else: | ||
cfg_base[k] = v | ||
else: | ||
if "BASE" not in k or preserve_base: | ||
cfg_base[k] = v | ||
return cfg_base | ||
|
||
def _merge_cfg_from_command_update(self, args, cfg): | ||
if len(args.opts) == 0: | ||
return cfg | ||
|
||
assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format( | ||
args.opts, len(args.opts) | ||
) | ||
keys = args.opts[0::2] | ||
vals = args.opts[1::2] | ||
|
||
for key, val in zip(keys, vals): | ||
cfg[key] = val | ||
|
||
return cfg | ||
|
||
def _merge_cfg_from_command(self, args, cfg): | ||
assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format( | ||
args.opts, len(args.opts) | ||
) | ||
keys = args.opts[0::2] | ||
vals = args.opts[1::2] | ||
|
||
# maximum supported depth 3 | ||
for idx, key in enumerate(keys): | ||
key_split = key.split('.') | ||
assert len(key_split) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format( | ||
len(key_split) | ||
) | ||
assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format( | ||
key_split[0] | ||
) | ||
if len(key_split) == 2: | ||
assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( | ||
key | ||
) | ||
elif len(key_split) == 3: | ||
assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( | ||
key | ||
) | ||
assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( | ||
key | ||
) | ||
elif len(key_split) == 4: | ||
assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( | ||
key | ||
) | ||
assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( | ||
key | ||
) | ||
assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existant key: {}.'.format( | ||
key | ||
) | ||
if len(key_split) == 1: | ||
cfg[key_split[0]] = vals[idx] | ||
elif len(key_split) == 2: | ||
cfg[key_split[0]][key_split[1]] = vals[idx] | ||
elif len(key_split) == 3: | ||
cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx] | ||
elif len(key_split) == 4: | ||
cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx] | ||
return cfg | ||
|
||
def _update_dict(self, cfg_dict): | ||
def recur(key, elem): | ||
if type(elem) is dict: | ||
return key, Config(load=False, cfg_dict=elem, cfg_level=key) | ||
else: | ||
if type(elem) is str and elem[1:3]=="e-": | ||
elem = float(elem) | ||
return key, elem | ||
dic = dict(recur(k, v) for k, v in cfg_dict.items()) | ||
self.__dict__.update(dic) | ||
|
||
def get_args(self): | ||
return self.args | ||
|
||
def __repr__(self): | ||
return "{}\n".format(self.dump()) | ||
|
||
def dump(self): | ||
return json.dumps(self.cfg_dict, indent=2) | ||
|
||
def deep_copy(self): | ||
return copy.deepcopy(self) | ||
|
||
# if __name__ == '__main__': | ||
# # debug | ||
# cfg = Config(load=True) | ||
# print(cfg.DATA) |
Oops, something went wrong.