Skip to content

Commit

Permalink
Resolved merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
Isi-dev authored Aug 5, 2024
1 parent 93da5d2 commit 8f2e748
Show file tree
Hide file tree
Showing 25 changed files with 2,081 additions and 0 deletions.
Binary file added utils/__pycache__/assign_cfg.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/config.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/distributed.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/logging.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/multi_port.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/registry.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/registry_class.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/seed.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/transforms.cpython-310.pyc
Binary file not shown.
Binary file added utils/__pycache__/video_op.cpython-310.pyc
Binary file not shown.
78 changes: 78 additions & 0 deletions utils/assign_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os, yaml
from copy import deepcopy, copy


# def get prior and ldm config
def assign_prior_mudule_cfg(cfg):
'''
'''
#
prior_cfg = deepcopy(cfg)
vldm_cfg = deepcopy(cfg)

with open(cfg.prior_cfg, 'r') as f:
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
# _cfg_update = _cfg_update.cfg_dict
for k, v in _cfg_update.items():
if isinstance(v, dict) and k in cfg:
prior_cfg[k].update(v)
else:
prior_cfg[k] = v

with open(cfg.vldm_cfg, 'r') as f:
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
# _cfg_update = _cfg_update.cfg_dict
for k, v in _cfg_update.items():
if isinstance(v, dict) and k in cfg:
vldm_cfg[k].update(v)
else:
vldm_cfg[k] = v

return prior_cfg, vldm_cfg


# def get prior and ldm config
def assign_vldm_vsr_mudule_cfg(cfg):
'''
'''
#
vldm_cfg = deepcopy(cfg)
vsr_cfg = deepcopy(cfg)

with open(cfg.vldm_cfg, 'r') as f:
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
# _cfg_update = _cfg_update.cfg_dict
for k, v in _cfg_update.items():
if isinstance(v, dict) and k in cfg:
vldm_cfg[k].update(v)
else:
vldm_cfg[k] = v

with open(cfg.vsr_cfg, 'r') as f:
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
# _cfg_update = _cfg_update.cfg_dict
for k, v in _cfg_update.items():
if isinstance(v, dict) and k in cfg:
vsr_cfg[k].update(v)
else:
vsr_cfg[k] = v

return vldm_cfg, vsr_cfg


# def get prior and ldm config
def assign_signle_cfg(cfg, _cfg_update, tname):
'''
'''
#
vldm_cfg = deepcopy(cfg)
if os.path.exists(_cfg_update[tname]):
with open(_cfg_update[tname], 'r') as f:
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
# _cfg_update = _cfg_update.cfg_dict
for k, v in _cfg_update.items():
if isinstance(v, dict) and k in cfg:
vldm_cfg[k].update(v)
else:
vldm_cfg[k] = v
return vldm_cfg
243 changes: 243 additions & 0 deletions utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
import os
import yaml
import json
import copy
import argparse

from ..utils import logging
# logger = logging.get_logger(__name__)

class Config(object):
def __init__(self, load=True, cfg_dict=None, cfg_level=None):
self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "")

current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.dirname(current_directory)
self.config_file_loc = os.path.join(parent_directory, 'configs/UniAnimate_infer.yaml')

if load:
self.args = self._parse_args()
# logger.info("Loading config from {}.".format(self.args.cfg_file))
self.need_initialization = True
cfg_base = self._load_yaml(self.args) # self._initialize_cfg()
cfg_dict = self._load_yaml(self.args)
cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
cfg_dict = self._update_from_args(cfg_dict)
self.cfg_dict = cfg_dict
self._update_dict(cfg_dict)

def _parse_args(self):
parser = argparse.ArgumentParser(
description="Argparser for configuring the codebase"
)
parser.add_argument(
"--cfg",
dest="cfg_file",
help="Path to the configuration file",
default= self.config_file_loc
)
parser.add_argument(
"--init_method",
help="Initialization method, includes TCP or shared file-system",
default="tcp://localhost:9999",
type=str,
)
parser.add_argument(
'--debug',
action='store_true',
default=False,
help='Output debug information'
)
parser.add_argument(
'--windows-standalone-build',
action='store_true',
default=False,
help='Indicates if the build is a standalone build for Windows'
)
parser.add_argument(
"opts",
help="Other configurations",
default=None,
nargs=argparse.REMAINDER
)
return parser.parse_args()


def _path_join(self, path_list):
path = ""
for p in path_list:
path+= p + '/'
return path[:-1]

def _update_from_args(self, cfg_dict):
args = self.args
for var in vars(args):
cfg_dict[var] = getattr(args, var)
return cfg_dict

def _initialize_cfg(self):
if self.need_initialization:
self.need_initialization = False
if os.path.exists('./configs/base.yaml'):
with open("./configs/base.yaml", 'r') as f:
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
else:
with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f:
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
return cfg

def _load_yaml(self, args, file_name=""):
assert args.cfg_file is not None
if not file_name == "": # reading from base file
with open(file_name, 'r') as f:
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
else:
if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]:
args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./")
with open(args.cfg_file, 'r') as f:
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
file_name = args.cfg_file

if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys():
# return cfg if the base file is being accessed
cfg = self._merge_cfg_from_command_update(args, cfg)
return cfg

if "_BASE" in cfg.keys():
if cfg["_BASE"][1] == '.':
prev_count = cfg["_BASE"].count('..')
cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:])
else:
cfg_base_file = cfg["_BASE"].replace(
"./",
args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
)
cfg_base = self._load_yaml(args, cfg_base_file)
cfg = self._merge_cfg_from_base(cfg_base, cfg)
else:
if "_BASE_RUN" in cfg.keys():
if cfg["_BASE_RUN"][1] == '.':
prev_count = cfg["_BASE_RUN"].count('..')
cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:])
else:
cfg_base_file = cfg["_BASE_RUN"].replace(
"./",
args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
)
cfg_base = self._load_yaml(args, cfg_base_file)
cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True)
if "_BASE_MODEL" in cfg.keys():
if cfg["_BASE_MODEL"][1] == '.':
prev_count = cfg["_BASE_MODEL"].count('..')
cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:])
else:
cfg_base_file = cfg["_BASE_MODEL"].replace(
"./",
args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
)
cfg_base = self._load_yaml(args, cfg_base_file)
cfg = self._merge_cfg_from_base(cfg_base, cfg)
cfg = self._merge_cfg_from_command(args, cfg)
return cfg

def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
for k,v in cfg_new.items():
if k in cfg_base.keys():
if isinstance(v, dict):
self._merge_cfg_from_base(cfg_base[k], v)
else:
cfg_base[k] = v
else:
if "BASE" not in k or preserve_base:
cfg_base[k] = v
return cfg_base

def _merge_cfg_from_command_update(self, args, cfg):
if len(args.opts) == 0:
return cfg

assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
args.opts, len(args.opts)
)
keys = args.opts[0::2]
vals = args.opts[1::2]

for key, val in zip(keys, vals):
cfg[key] = val

return cfg

def _merge_cfg_from_command(self, args, cfg):
assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
args.opts, len(args.opts)
)
keys = args.opts[0::2]
vals = args.opts[1::2]

# maximum supported depth 3
for idx, key in enumerate(keys):
key_split = key.split('.')
assert len(key_split) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format(
len(key_split)
)
assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format(
key_split[0]
)
if len(key_split) == 2:
assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
key
)
elif len(key_split) == 3:
assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
key
)
assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format(
key
)
elif len(key_split) == 4:
assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
key
)
assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format(
key
)
assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existant key: {}.'.format(
key
)
if len(key_split) == 1:
cfg[key_split[0]] = vals[idx]
elif len(key_split) == 2:
cfg[key_split[0]][key_split[1]] = vals[idx]
elif len(key_split) == 3:
cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
elif len(key_split) == 4:
cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx]
return cfg

def _update_dict(self, cfg_dict):
def recur(key, elem):
if type(elem) is dict:
return key, Config(load=False, cfg_dict=elem, cfg_level=key)
else:
if type(elem) is str and elem[1:3]=="e-":
elem = float(elem)
return key, elem
dic = dict(recur(k, v) for k, v in cfg_dict.items())
self.__dict__.update(dic)

def get_args(self):
return self.args

def __repr__(self):
return "{}\n".format(self.dump())

def dump(self):
return json.dumps(self.cfg_dict, indent=2)

def deep_copy(self):
return copy.deepcopy(self)

# if __name__ == '__main__':
# # debug
# cfg = Config(load=True)
# print(cfg.DATA)
Loading

0 comments on commit 8f2e748

Please sign in to comment.