From 3f1058ff498823eff008395544aa9d1b9c55a1af Mon Sep 17 00:00:00 2001
From: Qiao Qiao <68757394+qiaouchicago@users.noreply.github.com>
Date: Mon, 17 Oct 2022 09:42:07 -0500
Subject: [PATCH] DEV-1483: add black and reformat (#389)

reformat all using black
---
 .pre-commit-config.yaml                     |   5 +
 .secrets.baseline                           |   4 +-
 bin/update_related_case_caches.py           |  40 +-
 docs/bin/schemata_to_graphviz.py            |  24 +-
 gdcdatamodel/__main__.py                    |  51 ++-
 gdcdatamodel/gdc_postgres_admin.py          | 142 ++++---
 gdcdatamodel/models/__init__.py             |   4 +-
 gdcdatamodel/models/caching.py              |  71 ++--
 gdcdatamodel/models/indexes.py              |  27 +-
 gdcdatamodel/models/utils.py                |   4 +-
 gdcdatamodel/models/versioned_nodes.py      |  30 +-
 gdcdatamodel/models/versioning.py           |   6 +-
 gdcdatamodel/query.py                       |  65 ++-
 gdcdatamodel/validators/graph_validators.py | 111 +++---
 gdcdatamodel/validators/json_validators.py  |  28 +-
 gdcdatamodel/viz/__init__.py                |  10 +-
 migrations/async_transactions.py            |  16 +-
 migrations/index_secondary_keys.py          |  13 +-
 migrations/notifications.py                 |   4 +-
 migrations/set_null_edge_columns.py         |   9 +-
 migrations/update_case_cache_append_only.py |  33 +-
 migrations/update_legacy_states.py          |  90 ++---
 test/conftest.py                            |  40 +-
 test/helpers.py                             |   4 +-
 test/models.py                              |   1 -
 test/test_admin_script.py                   |  97 +++--
 test/test_cache_related_cases.py            | 109 +++---
 test/test_datamodel.py                      |  68 ++--
 test/test_dictionary_loadiing.py            |  28 +-
 test/test_gdc_postgres_admin.py             | 164 +++++---
 test/test_indexes.py                        |   8 +-
 test/test_node_tagging.py                   |  72 +++-
 test/test_update_case_cache.py              |  20 +-
 test/test_validators.py                     | 413 +++++++++++++------
 test/test_versioned_nodes.py                |  55 +--
 35 files changed, 1083 insertions(+), 783 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7fd0f026..f5177be2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,6 +20,11 @@ repos:
       args: [--autofix]
     - id: trailing-whitespace
       args: [--markdown-linebreak-ext=md]
+  - repo: https://github.com/psf/black
+    rev: 20.8b1 # for pre-commit 1.21.0 in jenkins
+    hooks:
+      - id: black
+        additional_dependencies: [click==8.0.4]
   - repo: https://github.com/pycqa/isort
     rev: 5.6.4 # last version to support pre-commit 1.21.0 in Jenkins
     hooks:
diff --git a/.secrets.baseline b/.secrets.baseline
index 9e872771..e18f4983 100644
--- a/.secrets.baseline
+++ b/.secrets.baseline
@@ -3,7 +3,7 @@
     "files": "^.secrets.baseline$",
     "lines": null
   },
-  "generated_at": "2022-10-11T18:55:31Z",
+  "generated_at": "2022-10-14T21:28:31Z",
   "plugins_used": [
     {
       "name": "AWSKeyDetector"
     }
@@ -69,7 +69,7 @@
         "hashed_secret": "5d0fa74acf95d1d6bebd0d37f76a94e77d604fd9",
         "is_secret": false,
         "is_verified": false,
-        "line_number": 33,
+        "line_number": 36,
         "type": "Basic Auth Credentials"
       }
     ]
diff --git a/bin/update_related_case_caches.py b/bin/update_related_case_caches.py
index 9cd660d9..52ea94c6 100644
--- a/bin/update_related_case_caches.py
+++ b/bin/update_related_case_caches.py
@@ -25,8 +25,11 @@ def recursive_update_related_case_caches(node, case, visited_ids=set()):

     """

-    logger.info("{}: | case: {} | project: {}".format(
-        node, case, node._props.get('project_id', '?')))
+    logger.info(
+        "{}: | case: {} | project: {}".format(
+            node, case, node._props.get("project_id", "?")
+        )
+    )

     visited_ids.add(node.node_id)

@@ -34,10 +37,10 @@
         if edge.src is None:
             continue

-        if edge.__class__.__name__.endswith('RelatesToCase'):
+        if edge.__class__.__name__.endswith("RelatesToCase"):
             continue

-        if not hasattr(edge.src, '_related_cases'):
+        if not hasattr(edge.src, "_related_cases"):
             continue

         original = set(edge.src._related_cases)
@@ -61,14 +64,23 @@ def update_project_related_case_cache(project):

 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("-H", "--host", type=str, action="store",
-                        required=True, help="psql-server host")
-    parser.add_argument("-U", "--user", type=str, action="store",
-                        required=True, help="psql test user")
-    parser.add_argument("-D", "--database", type=str, action="store",
-                        required=True, help="psql test database")
-    parser.add_argument("-P", "--password", type=str, action="store",
-                        help="psql test password")
+    parser.add_argument(
+        "-H", "--host", type=str, action="store", required=True, help="psql-server host"
+    )
+    parser.add_argument(
+        "-U", "--user", type=str, action="store", required=True, help="psql test user"
+    )
+    parser.add_argument(
+        "-D",
+        "--database",
+        type=str,
+        action="store",
+        required=True,
+        help="psql test database",
+    )
+    parser.add_argument(
+        "-P", "--password", type=str, action="store", help="psql test password"
+    )

     args = parser.parse_args()
     prompt = "Password for {}:".format(args.user)
@@ -76,12 +88,12 @@ def main():

     g = PsqlGraphDriver(args.host, args.user, password, args.database)
     with g.session_scope():
-        projects = g.nodes(md.Project).not_props(state='legacy').all()
+        projects = g.nodes(md.Project).not_props(state="legacy").all()
         for p in projects:
             update_project_related_case_cache(p)
     print("Done.")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/docs/bin/schemata_to_graphviz.py b/docs/bin/schemata_to_graphviz.py
index e5bb2407..f6866000 100644
--- a/docs/bin/schemata_to_graphviz.py
+++ b/docs/bin/schemata_to_graphviz.py
@@ -6,19 +6,21 @@


 def build_visualization():
-    print('Building schema documentation...')
+    print("Building schema documentation...")

     # Load directory tree info
     bin_dir = os.path.dirname(os.path.realpath(__file__))
-    root_dir = os.path.join(os.path.abspath(
-        os.path.join(bin_dir, os.pardir, os.pardir)))
+    root_dir = os.path.join(
+        os.path.abspath(os.path.join(bin_dir, os.pardir, os.pardir))
+    )

     # Create graph
     dot = Digraph(
-        comment="High level graph representation of GDC data model", format='pdf')
-    dot.graph_attr['rankdir'] = 'RL'
-    dot.node_attr['fillcolor'] = 'lightblue'
-    dot.node_attr['style'] = 'filled'
+        comment="High level graph representation of GDC data model", format="pdf"
+    )
+    dot.graph_attr["rankdir"] = "RL"
+    dot.node_attr["fillcolor"] = "lightblue"
+    dot.node_attr["style"] = "filled"

     # Add nodes
     for node in m.Node.get_subclasses():
@@ -28,7 +30,7 @@ def build_visualization():

     # Add edges
     for edge in m.Edge.get_subclasses():
-        if edge.__dst_class__ == 'Case' and edge.label == 'relates_to':
+        if edge.__dst_class__ == "Case" and edge.label == "relates_to":
             # Skip case cache edges
             continue

@@ -36,10 +38,10 @@ def build_visualization():
         dst = m.Node.get_subclass_named(edge.__dst_class__)
         dot.edge(src.get_label(), dst.get_label(), edge.get_label())

-    gv_path = os.path.join(root_dir, 'docs', 'viz', 'gdc_data_model.gv')
+    gv_path = os.path.join(root_dir, "docs", "viz", "gdc_data_model.gv")
     dot.render(gv_path)
-    print('graphviz output to {}'.format(gv_path))
+    print("graphviz output to {}".format(gv_path))


-if __name__ == '__main__':
+if __name__ == "__main__":
     build_visualization()
diff --git a/gdcdatamodel/__main__.py b/gdcdatamodel/__main__.py
index f3a9e387..7e1c4011 100644
--- a/gdcdatamodel/__main__.py
+++ b/gdcdatamodel/__main__.py
@@ -1,7 +1,6 @@
 import argparse
 import getpass

-import psqlgraph
 from models import *  # noqa
 from models.versioned_nodes import VersionedNode  # noqa
 from psqlgraph import *  # noqa
@@ -9,12 +8,18 @@

 try:
     import IPython
+
     ipython = True
 except Exception as e:
-    print(('{}, using standard interactive console. '
-           'If you install IPython, then it will automatically '
-           'be used for this repl.').format(e))
+    print(
+        (
+            "{}, using standard interactive console. "
+            "If you install IPython, then it will automatically "
+            "be used for this repl."
+        ).format(e)
+    )
     import code
+
     ipython = False

@@ -31,17 +36,32 @@
 """


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('-d', '--database', default='test', type=str,
-                        help='name of the database to connect to')
-    parser.add_argument('-i', '--host', default='localhost', type=str,
-                        help='host of the postgres server')
-    parser.add_argument('-u', '--user', default='test', type=str,
-                        help='user to connect to postgres as')
-    parser.add_argument('-p', '--password', default=None, type=str,
-                        help='password for given user. If no '
-                             'password given, one will be prompted.')
+    parser.add_argument(
+        "-d",
+        "--database",
+        default="test",
+        type=str,
+        help="name of the database to connect to",
+    )
+    parser.add_argument(
+        "-i",
+        "--host",
+        default="localhost",
+        type=str,
+        help="host of the postgres server",
+    )
+    parser.add_argument(
+        "-u", "--user", default="test", type=str, help="user to connect to postgres as"
+    )
+    parser.add_argument(
+        "-p",
+        "--password",
+        default=None,
+        type=str,
+        help="password for given user. If no " "password given, one will be prompted.",
+    )

     args = parser.parse_args()

@@ -49,8 +69,7 @@
     if args.password is None:
         args.password = getpass.getpass()

-    g = psqlgraph.PsqlGraphDriver(
-        args.host, args.user, args.password, args.database)
+    g = psqlgraph.PsqlGraphDriver(args.host, args.user, args.password, args.database)

     with g.session_scope() as s:
         rb = s.rollback
diff --git a/gdcdatamodel/gdc_postgres_admin.py b/gdcdatamodel/gdc_postgres_admin.py
index 21f40ad8..71216d4c 100644
--- a/gdcdatamodel/gdc_postgres_admin.py
+++ b/gdcdatamodel/gdc_postgres_admin.py
@@ -75,7 +75,7 @@ def execute_for_all_graph_tables(engine, sql, namespace=None, **kwargs):
     edge_cls = ext.get_abstract_edge(namespace)

     for cls in node_cls.get_subclasses() + edge_cls.get_subclasses():
-        _kwargs = dict(kwargs, **{'table': cls.__tablename__})
+        _kwargs = dict(kwargs, **{"table": cls.__tablename__})
         statement = sql.format(**_kwargs)
         execute(engine, statement)

@@ -100,13 +100,13 @@ def create_graph_tables(engine, timeout, namespace=None):
     """
     create a table
     """
-    logger.info('Creating tables (timeout: %d)', timeout)
+    logger.info("Creating tables (timeout: %d)", timeout)

     connection = engine.connect()
     trans = connection.begin()

     logger.info("Setting lock_timeout to %d", timeout)
-    timeout_str = '{}s'.format(int(timeout+1))
+    timeout_str = "{}s".format(int(timeout + 1))
     connection.execute("SET LOCAL lock_timeout = %s;", timeout_str)

     orm_base = ext.get_orm_base(namespace) if namespace else ORMBase
@@ -123,25 +123,25 @@ def create_tables(engine, delay, retries, namespace=None):

     """

-    logger.info('Running table creator named %s', app_name)
+    logger.info("Running table creator named %s", app_name)

     try:
         return create_graph_tables(engine, delay, namespace=namespace)
     except OperationalError as e:
-        if 'timeout' in str(e):
-            logger.warning('Attempt timed out')
+        if "timeout" in str(e):
+            logger.warning("Attempt timed out")
         else:
            raise

        if retries <= 0:
-            raise RuntimeError('Max retries exceeded')
+            raise RuntimeError("Max retries exceeded")

         logger.info(
-            'Trying again in {} seconds ({} retries remaining)'
-            .format(delay, retries))
+            "Trying again in {} seconds ({} retries remaining)".format(delay, retries)
+        )
         time.sleep(delay)
-        create_tables(engine, delay, retries-1, namespace=namespace)
+        create_tables(engine, delay, retries - 1, namespace=namespace)


 def subcommand_create(args):
@@ -153,10 +153,7 @@ def subcommand_create(args):
     logger.info("Running subcommand 'create'")
     engine = get_engine(args.host, args.user, args.password, args.database)
     kwargs = dict(
-        engine=engine,
-        delay=args.delay,
-        retries=args.retries,
-        namespace=args.namespace
+        engine=engine, delay=args.delay, retries=args.retries, namespace=args.namespace
     )

     return create_tables(**kwargs)
@@ -172,15 +169,15 @@ def subcommand_grant(args):
     logger.info("Running subcommand 'grant'")
     engine = get_engine(args.host, args.user, args.password, args.database)

-    assert args.read or args.write, 'No premission types/users specified.'
+    assert args.read or args.write, "No premission types/users specified."

     if args.read:
-        users_read = [u for u in args.read.split(',') if u]
+        users_read = [u for u in args.read.split(",") if u]
         for user in users_read:
             grant_read_permissions_to_graph(engine, user, args.namespace)

     if args.write:
-        users_write = [u for u in args.write.split(',') if u]
+        users_write = [u for u in args.write.split(",") if u]
         for user in users_write:
             grant_write_permissions_to_graph(engine, user, args.namespace)

@@ -196,73 +193,104 @@ def subcommand_revoke(args):
     engine = get_engine(args.host, args.user, args.password, args.database)

     if args.read:
-        users_read = [u for u in args.read.split(',') if u]
+        users_read = [u for u in args.read.split(",") if u]
         for user in users_read:
             revoke_read_permissions_to_graph(engine, user, args.namespace)

     if args.write:
-        users_write = [u for u in args.write.split(',') if u]
+        users_write = [u for u in args.write.split(",") if u]
         for user in users_write:
             revoke_write_permissions_to_graph(engine, user, args.namespace)


 def add_base_args(subparser):
-    subparser.add_argument("-H", "--host", type=str, action="store",
-                           required=True, help="psql-server host")
-    subparser.add_argument("-U", "--user", type=str, action="store",
-                           required=True, help="psql test user")
-    subparser.add_argument("-D", "--database", type=str, action="store",
-                           required=True, help="psql test database")
-    subparser.add_argument("-P", "--password", type=str, action="store",
-                           default='', help="psql test password")
-    subparser.add_argument("-N", "--namespace", type=lambda x: x if x else None,
-                           help="psqlgraph model namespace")
+    subparser.add_argument(
+        "-H", "--host", type=str, action="store", required=True, help="psql-server host"
+    )
+    subparser.add_argument(
+        "-U", "--user", type=str, action="store", required=True, help="psql test user"
+    )
+    subparser.add_argument(
+        "-D",
+        "--database",
+        type=str,
+        action="store",
+        required=True,
+        help="psql test database",
+    )
+    subparser.add_argument(
+        "-P",
+        "--password",
+        type=str,
+        action="store",
+        default="",
+        help="psql test password",
+    )
+    subparser.add_argument(
+        "-N",
+        "--namespace",
+        type=lambda x: x if x else None,
+        help="psqlgraph model namespace",
+    )

     return subparser


 def add_subcommand_create(subparsers):
-    parser = add_base_args(subparsers.add_parser(
-        'graph-create',
-        help=subcommand_create.__doc__
-    ))
+    parser = add_base_args(
+        subparsers.add_parser("graph-create", help=subcommand_create.__doc__)
+    )
     parser.add_argument(
-        "--delay", type=int, action="store", default=60,
-        help="How many seconds to wait for blocking processes to finish before retrying."
+        "--delay",
+        type=int,
+        action="store",
+        default=60,
+        help="How many seconds to wait for blocking processes to finish before retrying.",
     )
     parser.add_argument(
-        "--retries", type=int, action="store", default=10,
-        help="If blocked by important process, how many times to retry after waiting `delay` seconds."
+        "--retries",
+        type=int,
+        action="store",
+        default=10,
+        help="If blocked by important process, how many times to retry after waiting `delay` seconds.",
     )


 def add_subcommand_grant(subparsers):
-    parser = add_base_args(subparsers.add_parser(
-        'graph-grant',
-        help=subcommand_grant.__doc__
-    ))
+    parser = add_base_args(
+        subparsers.add_parser("graph-grant", help=subcommand_grant.__doc__)
+    )
     parser.add_argument(
-        "--read", type=str, action="store",
-        help="Users to grant read access to (comma separated)."
+        "--read",
+        type=str,
+        action="store",
+        help="Users to grant read access to (comma separated).",
    )
     parser.add_argument(
-        "--write", type=str, action="store",
-        help="Users to grant read/write access to (comma separated)."
+        "--write",
+        type=str,
+        action="store",
+        help="Users to grant read/write access to (comma separated).",
     )


 def add_subcommand_revoke(subparsers):
-    parser = add_base_args(subparsers.add_parser(
-        'graph-revoke',
-        help=subcommand_revoke.__doc__
-    ))
+    parser = add_base_args(
+        subparsers.add_parser("graph-revoke", help=subcommand_revoke.__doc__)
+    )
     parser.add_argument(
-        "--read", type=str, action="store",
-        help="Users to revoke read access from (comma separated)."
+        "--read",
+        type=str,
+        action="store",
+        help="Users to revoke read access from (comma separated).",
     )
     parser.add_argument(
-        "--write", type=str, action="store",
-        help=("Users to revoke write access from (comma separated). "
-              "NOTE: The user will still have read privs!!")
+        "--write",
+        type=str,
+        action="store",
+        help=(
+            "Users to revoke write access from (comma separated). "
+            "NOTE: The user will still have read privs!!"
+        ),
     )


@@ -284,9 +312,9 @@ def main(args=None):
     logger.info("[ NAMESPACE : %-10s ]", args.namespace or "default")

     return_value = {
-        'graph-create': subcommand_create,
-        'graph-grant': subcommand_grant,
-        'graph-revoke': subcommand_revoke,
+        "graph-create": subcommand_create,
+        "graph-grant": subcommand_grant,
+        "graph-revoke": subcommand_revoke,
     }[args.subcommand](args)

     logger.info("Done.")
diff --git a/gdcdatamodel/models/__init__.py b/gdcdatamodel/models/__init__.py
index 10aa9e03..e0be12ec 100644
--- a/gdcdatamodel/models/__init__.py
+++ b/gdcdatamodel/models/__init__.py
@@ -1,4 +1,3 @@
-
 """gdcdatamodel.models
 ----------------------------------

@@ -738,7 +737,7 @@ def load_edges(dictionary, node_cls=Node, edge_cls=Edge, package_namespace=None)
             src_cls._pg_links[link["name"]] = {
                 "edge_out": edge_name,
                 "dst_type": node_cls.get_subclass(link["target_type"]),
-                "backref": link["backref"]
+                "backref": link["backref"],
             }

     for src_cls in node_cls.get_subclasses():
@@ -795,6 +794,7 @@ def inject_pg_edges(node_cls):
         { <name>: {'backref': <backref name>, 'type': <edge type>} }

     """
+
     def cls_inject_forward_edges(cls):
         """We should have already added the links that go OUT from
         this class, so let's add them to `_pg_edges`
diff --git a/gdcdatamodel/models/caching.py b/gdcdatamodel/models/caching.py
index 6c3308c9..92b13e84 100644
--- a/gdcdatamodel/models/caching.py
+++ b/gdcdatamodel/models/caching.py
@@ -16,17 +16,17 @@

 import logging

-logger = logging.getLogger('gdcdatamodel')
+logger = logging.getLogger("gdcdatamodel")

 #: This variable contains the link name for the case shortcut
 #: association proxy
-RELATED_CASES_LINK_NAME = '_related_cases'
+RELATED_CASES_LINK_NAME = "_related_cases"

 #: This variable specifies the categories for which we won't create
 #  short cut : edges to case
 NOT_RELATED_CASES_CATEGORIES = {
-    'administrative',
-    'TBD',
+    "administrative",
+    "TBD",
 }

@@ -53,7 +53,7 @@ def get_related_case_edge_cls_name(node):

     """

-    return '{}RelatesToCase'.format(node.__class__.__name__)
+    return "{}RelatesToCase".format(node.__class__.__name__)


 def get_edge_src(edge):
@@ -69,18 +69,19 @@ def get_edge_src(edge):
         src = edge.src
     elif edge.src_id is not None:
         src_class = node_cls.get_subclass_named(edge.__src_class__)
-        src = (edge.get_session().query(src_class)
-               .filter(src_class.node_id == edge.src_id)
-               .first())
+        src = (
+            edge.get_session()
+            .query(src_class)
+            .filter(src_class.node_id == edge.src_id)
+            .first()
+        )
     else:
         src = None
     return src


 def get_edge_dst(edge, allow_query=False):
-    """Return the edge's destination or None.
-
-    """
+    """Return the edge's destination or None."""
     node_cls = edge.get_node_class()

     if edge.dst:
@@ -88,9 +89,12 @@ def get_edge_dst(edge, allow_query=False):

     elif edge.dst_id is not None and allow_query:
         dst_class = node_cls.get_subclass_named(edge.__dst_class__)
-        dst = (edge.get_session().query(dst_class)
-               .filter(dst_class.node_id == edge.dst_id)
-               .first())
+        dst = (
+            edge.get_session()
+            .query(dst_class)
+            .filter(dst_class.node_id == edge.dst_id)
+            .first()
+        )
     else:
         dst = None

@@ -120,9 +124,7 @@ def related_cases_from_parents(node):

     """

-    skip_edges_named = [
-        get_related_case_edge_cls_name(node)
-    ]
+    skip_edges_named = [get_related_case_edge_cls_name(node)]

     # Make sure the edges haven't been expunged
     edges_out = [e for e in node.edges_out if e in node.get_session()]
@@ -142,17 +144,15 @@ def related_cases_from_parents(node):
             continue
         node_cls = edge.get_node_class()
         dst_class = node_cls.get_subclass_named(edge.__dst_class__)
-        if dst_class.label == 'case' and edge.dst:
+        if dst_class.label == "case" and edge.dst:
             cases.add(edge.dst)

     return list(filter(None, cases))


-def cache_related_cases_recursive(node,
-                                  session,
-                                  flush_context,
-                                  instances,
-                                  visited_nodes=None):
+def cache_related_cases_recursive(
+    node, session, flush_context, instances, visited_nodes=None
+):
     """Update the related case cache on source node and its children
     recursively iff the this update changes the related case source
     node's shortcut edges.
@@ -219,21 +219,17 @@ def update_cache_edges(node, session, correct_cases):

     # Get information about the existing edges
     edge_name = get_related_case_edge_cls_name(node)
-    existing_edges = getattr(node, '_{}_out'.format(edge_name))
+    existing_edges = getattr(node, "_{}_out".format(edge_name))

     # Remove edges that should no longer exist
     cases_disconnected = [
-        edge.dst
-        for edge in existing_edges
-        if edge.dst_id not in correct_cases
+        edge.dst for edge in existing_edges if edge.dst_id not in correct_cases
     ]

     for case in cases_disconnected:
         assoc_proxy.remove(case)

-    existing_edge_dst_case_ids = {
-        edge.dst_id for edge in existing_edges
-    }
+    existing_edge_dst_case_ids = {edge.dst_id for edge in existing_edges}

     cases_connected = [
         case
@@ -245,10 +241,7 @@ def update_cache_edges(node, session, correct_cases):
         assoc_proxy.append(case)


-def cache_related_cases_on_insert(target,
-                                  session,
-                                  flush_context,
-                                  instances):
+def cache_related_cases_on_insert(target, session, flush_context, instances):
     """Hook on updated edges.  Update the related case cache on source
     node and its children iff the this update changes the related case
     source node's cache.
@@ -281,10 +274,7 @@ def cache_related_cases_on_insert(target,
     )


-def cache_related_cases_on_update(target,
-                                  session,
-                                  flush_context,
-                                  instances):
+def cache_related_cases_on_update(target, session, flush_context, instances):
     """Hook on deleted edges.  Update the related case cache on source
     node and its children.

@@ -309,10 +299,7 @@ def cache_related_cases_on_update(target,
     )


-def cache_related_cases_on_delete(target,
-                                  session,
-                                  flush_context,
-                                  instances):
+def cache_related_cases_on_delete(target, session, flush_context, instances):
     """Hook on deleted edges.  Update the related case cache on source
     node and its children.
diff --git a/gdcdatamodel/models/indexes.py b/gdcdatamodel/models/indexes.py
index eb500cba..9b633445 100644
--- a/gdcdatamodel/models/indexes.py
+++ b/gdcdatamodel/models/indexes.py
@@ -1,4 +1,3 @@
-
 """gdcdatamodel.models.indexes
 ----------------------------------

@@ -33,20 +32,20 @@ def index_name(cls, description):

     """

-    name = 'index_{}_{}'.format(cls.__tablename__, description)
+    name = "index_{}_{}".format(cls.__tablename__, description)

     # If the name is too long, prepend it with the first 8 hex of it's hash
     # truncate the each part of the name
     if len(name) > 40:
         oldname = index_name
-        logger.debug('Edge tablename {} too long, shortening'.format(oldname))
-        name = 'index_{}_{}_{}'.format(
+        logger.debug("Edge tablename {} too long, shortening".format(oldname))
+        name = "index_{}_{}_{}".format(
             hashlib.md5(py3_to_bytes(cls.__tablename__)).hexdigest()[:8],
-            ''.join([a[:4] for a in cls.get_label().split('_')])[:20],
-            '_'.join([a[:8] for a in description.split('_')])[:25],
+            "".join([a[:4] for a in cls.get_label().split("_")])[:20],
+            "_".join([a[:8] for a in description.split("_")])[:25],
         )
-        logger.debug('Shortening {} -> {}'.format(oldname, index_name))
+        logger.debug("Shortening {} -> {}".format(oldname, index_name))

     return name

@@ -62,7 +61,7 @@ def get_secondary_key_indexes(cls):

     """

     #: use text_pattern_ops, allows LIKE statements not starting with %
-    index_op = 'text_pattern_ops'
+    index_op = "text_pattern_ops"

     secondary_keys = {key for pair in cls.__pg_secondary_keys for key in pair}
     key_indexes = (
@@ -70,15 +69,17 @@ def get_secondary_key_indexes(cls):
             index_name(cls, key),
             cls._props[key].astext.label(key),
             postgresql_ops={key: index_op},
-        ) for key in secondary_keys
+        )
+        for key in secondary_keys
     )

     lower_key_indexes = (
         Index(
-            index_name(cls, key+'_lower'),
-            func.lower(cls._props[key].astext).label(key+'_lower'),
-            postgresql_ops={key+'_lower': index_op},
-        ) for key in secondary_keys
+            index_name(cls, key + "_lower"),
+            func.lower(cls._props[key].astext).label(key + "_lower"),
+            postgresql_ops={key + "_lower": index_op},
+        )
+        for key in secondary_keys
     )

     return tuple(key_indexes) + tuple(lower_key_indexes)
diff --git a/gdcdatamodel/models/utils.py b/gdcdatamodel/models/utils.py
index 75cab86a..f77f23d0 100644
--- a/gdcdatamodel/models/utils.py
+++ b/gdcdatamodel/models/utils.py
@@ -7,11 +7,13 @@ def decorator(f):
         @wraps(f)
         def wrapper(*args, **kwargs):
             return f(*args, **kwargs)
+
         return f
+
     return decorator


 def py3_to_bytes(bytes_or_str):
     if sys.version_info[0] > 2 and isinstance(bytes_or_str, str):
-        return bytes_or_str.encode('utf-8')
+        return bytes_or_str.encode("utf-8")
     return bytes_or_str
diff --git a/gdcdatamodel/models/versioned_nodes.py b/gdcdatamodel/models/versioned_nodes.py
index 24bcfea9..03df64f5 100644
--- a/gdcdatamodel/models/versioned_nodes.py
+++ b/gdcdatamodel/models/versioned_nodes.py
@@ -9,21 +9,18 @@

 class VersionedNode(Base):

-    __tablename__ = 'versioned_nodes'
+    __tablename__ = "versioned_nodes"
     __table_args__ = (
-        Index('submitted_node_id_idx', 'node_id'),
-        Index('submitted_node_gdc_versions_idx', 'node_id'),
+        Index("submitted_node_id_idx", "node_id"),
+        Index("submitted_node_gdc_versions_idx", "node_id"),
     )

     def __repr__(self):
-        return ("<VersionedNode({}, label={}, node_id={})>"
-                .format(self.key, self.label, self.node_id))
+        return "<VersionedNode({}, label={}, node_id={})>".format(
+            self.key, self.label, self.node_id
+        )

-    key = Column(
-        BigInteger,
-        primary_key=True,
-        nullable=False
-    )
+    key = Column(BigInteger, primary_key=True, nullable=False)

     label = Column(
         Text,
@@ -52,7 +49,7 @@ def __repr__(self):
     versioned = Column(
         DateTime(timezone=True),
         nullable=False,
-        server_default=text('now()'),
+        server_default=text("now()"),
     )

     acl = Column(
@@ -79,14 +76,13 @@ def clone(node):
         return VersionedNode(
             label=copy(node.label),
             node_id=copy(node.node_id),
-            project_id=copy(node._props.get('project_id')),
+            project_id=copy(node._props.get("project_id")),
             created=copy(node.created),
             acl=copy(node.acl),
             system_annotations=copy(node.system_annotations),
             properties=copy(node.properties),
-            neighbors=copy([
-                edge.dst_id for edge in node.edges_out
-            ] + [
-                edge.src_id for edge in node.edges_in
-            ])
+            neighbors=copy(
+                [edge.dst_id for edge in node.edges_out]
+                + [edge.src_id for edge in node.edges_in]
+            ),
         )
diff --git a/gdcdatamodel/models/versioning.py b/gdcdatamodel/models/versioning.py
index d382c658..00935126 100644
--- a/gdcdatamodel/models/versioning.py
+++ b/gdcdatamodel/models/versioning.py
@@ -102,10 +102,10 @@ def _constraints(self):
     def is_taggable(self, node):
         """Returns true if node supports tagging else False.
         Ideally, instances that return false will not
-            have tag and version number set on them
+        have tag and version number set on them

-            Returns:
-                bool: True for nodes that can be tagged
+        Returns:
+            bool: True for nodes that can be tagged
         """
         return not any(criteria.match(node) for criteria in self._constraints())
diff --git a/gdcdatamodel/query.py b/gdcdatamodel/query.py
index 3a4f9ba6..0c4507ed 100644
--- a/gdcdatamodel/query.py
+++ b/gdcdatamodel/query.py
@@ -1,14 +1,30 @@
 from psqlgraph import Edge, Node

 traversals = {}
-terminal_nodes = ['annotations', 'centers', 'archives', 'tissue_source_sites',
-                  'files', 'related_files', 'describing_files',
-                  'clinical_metadata_files', 'experiment_metadata_files', 'run_metadata_files',
-                  'analysis_metadata_files', 'biospecimen_metadata_files', 'aligned_reads_metrics',
-                  'read_group_metrics', 'pathology_reports', 'simple_germline_variations',
-                  'aligned_reads_indexes', 'mirna_expressions', 'exon_expressions',
-                  'simple_somatic_mutations', 'gene_expressions', 'aggregated_somatic_mutations',
-                  ]
+terminal_nodes = [
+    "annotations",
+    "centers",
+    "archives",
+    "tissue_source_sites",
+    "files",
+    "related_files",
+    "describing_files",
+    "clinical_metadata_files",
+    "experiment_metadata_files",
+    "run_metadata_files",
+    "analysis_metadata_files",
+    "biospecimen_metadata_files",
+    "aligned_reads_metrics",
+    "read_group_metrics",
+    "pathology_reports",
+    "simple_germline_variations",
+    "aligned_reads_indexes",
+    "mirna_expressions",
+    "exon_expressions",
+    "simple_somatic_mutations",
+    "gene_expressions",
+    "aggregated_somatic_mutations",
+]


 def construct_traversals(root, node, visited, path):
@@ -18,27 +34,38 @@ def construct_traversals(root, node, visited, path):
         and neighbor not in visited
         and neighbor != node
         # no traveling THROUGH terminal nodes
-        and (path[-1] not in terminal_nodes
-             if path else neighbor.label not in terminal_nodes)
-        and (not path[-1].startswith('_related')
-             if path else not neighbor.label.startswith('_related')))
+        and (
+            path[-1] not in terminal_nodes
+            if path
+            else neighbor.label not in terminal_nodes
+        )
+        and (
+            not path[-1].startswith("_related")
+            if path
+            else not neighbor.label.startswith("_related")
+        )
+    )

     for edge in Edge._get_edges_with_src(node.__name__):
-        neighbor = [n for n in Node.get_subclasses()
-                    if n.__name__ == edge.__dst_class__][0]
+        neighbor = [
+            n for n in Node.get_subclasses() if n.__name__ == edge.__dst_class__
+        ][0]
         if recurse(neighbor):
             construct_traversals(
-                root, neighbor, visited+[node], path+[edge.__src_dst_assoc__])
+                root, neighbor, visited + [node], path + [edge.__src_dst_assoc__]
+            )

     for edge in Edge._get_edges_with_dst(node.__name__):
-        neighbor = [n for n in Node.get_subclasses()
-                    if n.__name__ == edge.__src_class__][0]
+        neighbor = [
+            n for n in Node.get_subclasses() if n.__name__ == edge.__src_class__
+        ][0]
         if recurse(neighbor):
             construct_traversals(
-                root, neighbor, visited+[node], path+[edge.__dst_src_assoc__])
+                root, neighbor, visited + [node], path + [edge.__dst_src_assoc__]
+            )

     traversals[root][node.label] = traversals[root].get(node.label) or set()
-    traversals[root][node.label].add('.'.join(path))
+    traversals[root][node.label].add(".".join(path))


 def construct_traversals_for_all_nodes():
diff --git a/gdcdatamodel/validators/graph_validators.py b/gdcdatamodel/validators/graph_validators.py
index 2df33bb4..093e220e 100644
--- a/gdcdatamodel/validators/graph_validators.py
+++ b/gdcdatamodel/validators/graph_validators.py
@@ -2,16 +2,17 @@


 class GDCGraphValidator(object):
-    '''
+    """
     Validator that validates entities' relationship with existing nodes in
     database.
-    '''
+    """
+
     def __init__(self):
         self.schemas = gdcdictionary
         self.required_validators = {
-            'links_validator': GDCLinksValidator(),
-            'uniqueKeys_validator': GDCUniqueKeysValidator()
+            "links_validator": GDCLinksValidator(),
+            "uniqueKeys_validator": GDCUniqueKeysValidator(),
         }
         self.optional_validators = {}

@@ -21,20 +22,19 @@ def record_errors(self, graph, entities):
         for entity in entities:
             schema = self.schemas.schema[entity.node.label]
-            validators = schema.get('validators')
+            validators = schema.get("validators")
             if validators:
                 for validator_name in validators:
                     self.optional_validators[validator_name].validate()


 class GDCLinksValidator(object):
-
     def validate(self, entities, graph=None):
         for entity in entities:
-            for link in gdcdictionary.schema[entity.node.label]['links']:
-                if 'name' in link:
+            for link in gdcdictionary.schema[entity.node.label]["links"]:
+                if "name" in link:
                     self.validate_edge(link, entity)
-                elif 'subgroup' in link:
+                elif "subgroup" in link:
                     self.validate_edge_group(link, entity)

     def validate_edge_group(self, schema, entity):
@@ -42,90 +42,97 @@ def validate_edge_group(self, schema, entity):
         schema_links = []
         num_of_edges = 0

-        for group in schema['subgroup']:
-            if 'subgroup' in schema['subgroup']:
+        for group in schema["subgroup"]:
+            if "subgroup" in schema["subgroup"]:
                 # nested subgroup
                 result = self.validate_edge_group(group, entity)

-            if 'name' in group:
+            if "name" in group:
                 result = self.validate_edge(group, entity)

-            if result['length'] > 0:
+            if result["length"] > 0:
                 submitted_links.append(result)
-                num_of_edges += result['length']
+                num_of_edges += result["length"]

-            schema_links.append(result['name'])
+            schema_links.append(result["name"])

-        if schema.get('required') is True and len(submitted_links) == 0:
-            names = ", ".join(
-                schema_links[:-2] + [" or ".join(schema_links[-2:])])
+        if schema.get("required") is True and len(submitted_links) == 0:
+            names = ", ".join(schema_links[:-2] + [" or ".join(schema_links[-2:])])
             entity.record_error(
-                "Entity is missing a required link to {}"
-                .format(names), keys=schema_links)
+                "Entity is missing a required link to {}".format(names),
+                keys=schema_links,
+            )

         if schema.get("exclusive") is True and len(submitted_links) > 1:
-            names = ", ".join(
-                schema_links[:-2] + [" and ".join(schema_links[-2:])])
+            names = ", ".join(schema_links[:-2] + [" and ".join(schema_links[-2:])])
             entity.record_error(
-                "Links to {} are exclusive. More than one was provided: {}"
-                .format(schema_links, entity.node.edges_out), keys=schema_links)
+                "Links to {} are exclusive. More than one was provided: {}".format(
+                    schema_links, entity.node.edges_out
+                ),
+                keys=schema_links,
+            )
             for edge in entity.node.edges_out:
-                entity.record_error('{}'.format(edge.dst.submitter_id))
+                entity.record_error("{}".format(edge.dst.submitter_id))

-        result = {'length': num_of_edges, 'name': ", ".join(schema_links)}
+        result = {"length": num_of_edges, "name": ", ".join(schema_links)}

     def validate_edge(self, link_sub_schema, entity):
-        association = link_sub_schema['name']
+        association = link_sub_schema["name"]
         node = entity.node
         targets = node[association]
-        result = {'length': len(targets), 'name': association}
+        result = {"length": len(targets), "name": association}
         if len(targets) > 0:
-            multi = link_sub_schema['multiplicity']
+            multi = link_sub_schema["multiplicity"]

-            if multi in ['many_to_one', 'one_to_one']:
+            if multi in ["many_to_one", "one_to_one"]:
                 if len(targets) > 1:
                     entity.record_error(
-                        "'{}' link has to be {}"
-                        .format(association, multi),
-                        keys=[association])
+                        "'{}' link has to be {}".format(association, multi),
+                        keys=[association],
+                    )

-            if multi in ['one_to_many', 'one_to_one']:
+            if multi in ["one_to_many", "one_to_one"]:
                 for target in targets:
-                    if len(target[link_sub_schema['backref']]) > 1:
+                    if len(target[link_sub_schema["backref"]]) > 1:
                         entity.record_error(
-                            "'{}' link has to be {}, target node {} already has {}"
-                            .format(association, multi,
-                                    target.label, link_sub_schema['backref']),
-                            keys=[association])
+                            "'{}' link has to be {}, target node {} already has {}".format(
+                                association,
+                                multi,
+                                target.label,
+                                link_sub_schema["backref"],
+                            ),
+                            keys=[association],
+                        )

-            if multi == 'many_to_many':
+            if multi == "many_to_many":
                 pass
         else:
-            if link_sub_schema.get('required') is True:
+            if link_sub_schema.get("required") is True:
                 entity.record_error(
-                    "Entity is missing required link to {}"
-                    .format(association),
-                    keys=[association])
+                    "Entity is missing required link to {}".format(association),
+                    keys=[association],
+                )
         return result


 class GDCUniqueKeysValidator(object):
-
     def validate(self, entities, graph=None):
         for entity in entities:
             schema = gdcdictionary.schema[entity.node.label]
             node = entity.node
-            for keys in schema['uniqueKeys']:
+            for keys in schema["uniqueKeys"]:
                 props = {}
-                if keys == ['id']:
+                if keys == ["id"]:
                     continue
                 for key in keys:
-                    prop = schema['properties'][key].get('systemAlias')
+                    prop = schema["properties"][key].get("systemAlias")
                     if prop:
                         props[prop] = node[prop]
                     else:
                         props[key] = node[key]
                 if graph.nodes().props(props).count() > 1:
-                    entity.record_error(
-                        '{} with {} already exists in the GDC'
-                        .format(node.label, props), keys=list(props.keys())
-                    )
+                    entity.record_error(
+                        "{} with {} already exists in the GDC".format(
+                            node.label, props
+                        ),
+                        keys=list(props.keys()),
+                    )
diff --git a/gdcdatamodel/validators/json_validators.py b/gdcdatamodel/validators/json_validators.py
index 13923170..506713a7 100644
--- a/gdcdatamodel/validators/json_validators.py
+++ b/gdcdatamodel/validators/json_validators.py
@@ -3,8 +3,10 @@
 from gdcdictionary import gdcdictionary
 from jsonschema import Draft4Validator

-missing_prop_re = re.compile("\'([a-zA-Z_-]+)\' is a required property")
-extra_prop_re = re.compile("Additional properties are not allowed \(u\'([a-zA-Z_-]+)\' was unexpected\)")
+missing_prop_re = re.compile("'([a-zA-Z_-]+)' is a required property")
+extra_prop_re = re.compile(
+    "Additional properties are not allowed \(u'([a-zA-Z_-]+)' was unexpected\)"
+)


 def get_keys(error_msg):
@@ -27,29 +29,33 @@ def __init__(self):
     def iter_errors(self, doc):
         # Note whenever gdcdictionary use a newer version of jsonschema
         # we need to update the Validator
-        validator = Draft4Validator(self.schemas.schema[doc['type']])
+        validator = Draft4Validator(self.schemas.schema[doc["type"]])
         return validator.iter_errors(doc)

     def record_errors(self, entities):
         for entity in entities:
             json_doc = entity.doc
-            if 'type' not in json_doc:
-                entity.record_error(
-                    "'type' is a required property", keys=['type'])
+            if "type" not in json_doc:
+                entity.record_error("'type' is a required property", keys=["type"])
                 break
-            if json_doc['type'] not in self.schemas.schema:
+            if json_doc["type"] not in self.schemas.schema:
                 entity.record_error(
-                    "specified type: {} is not in the current data model"
-                    .format(json_doc['type']), keys=['type'])
+                    "specified type: {} is not in the current data model".format(
+                        json_doc["type"]
+                    ),
+                    keys=["type"],
+                )
                 break
             for error in self.iter_errors(json_doc):
                 # the key will be property.sub property for nested properties
                 errors = [str(e) for e in error.path if error.path]
-                keys = ['.'.join(errors)] if errors else []
+                keys = [".".join(errors)] if errors else []
                 if not keys:
                     keys = get_keys(error.message)
                 message = error.message
                 if error.context:
-                    message += ': {}'.format(' and '.join([c.message for c in error.context]))
+                    message += ": {}".format(
+                        " and ".join([c.message for c in error.context])
+                    )
                 entity.record_error(message, keys=keys)


 # additional validators go here
diff --git a/gdcdatamodel/viz/__init__.py b/gdcdatamodel/viz/__init__.py
index ceeaa4be..e7c6eac4 100644
--- a/gdcdatamodel/viz/__init__.py
+++ b/gdcdatamodel/viz/__init__.py
@@ -8,18 +8,18 @@ def create_graphviz(nodes, include_case_cache_edges=False):

     """
     dot = Digraph()
-    dot.graph_attr['rankdir'] = 'RL'
+    dot.graph_attr["rankdir"] = "RL"

     edges_added = set()
     nodes = {node.node_id: node for node in nodes}

     def is_edge_drawn(edge, neighbor):
-        is_case_cache_edge = 'RelatesToCase' in edge.__class__.__name__
+        is_case_cache_edge = "RelatesToCase" in edge.__class__.__name__

         return (
-            (include_case_cache_edges or not is_case_cache_edge) and
-            edge not in edges_added and
-            neighbor in nodes
+            (include_case_cache_edges or not is_case_cache_edge)
+            and edge not in edges_added
+            and neighbor in nodes
         )

     for node in nodes.values():
diff --git a/migrations/async_transactions.py b/migrations/async_transactions.py
index 16ad58f9..eaee33b3 100644
--- a/migrations/async_transactions.py
+++ b/migrations/async_transactions.py
@@ -19,9 +19,10 @@


 def up_transaction(connection):
-    logger.info('Migrating async-transactions: up')
+    logger.info("Migrating async-transactions: up")

-    connection.execute("""
+    connection.execute(
+        """
         ALTER TABLE transaction_logs ADD COLUMN state TEXT;
         ALTER TABLE transaction_logs ADD COLUMN committed_by INTEGER;
         ALTER TABLE transaction_logs ADD COLUMN is_dry_run BOOLEAN;
@@ -37,13 +38,15 @@ def up_transaction(connection):

         UPDATE transaction_logs SET state = 'SUCCEEDED' WHERE state IS NULL;
         UPDATE transaction_logs SET is_dry_run = FALSE WHERE is_dry_run IS NULL;
-    """)
+    """
+    )


 def down_transaction(connection):
-    logger.info('Migrating async-transactions: down')
+    logger.info("Migrating async-transactions: down")

-    connection.execute("""
+    connection.execute(
+        """
         DROP INDEX transaction_logs_program_idx;
         DROP INDEX transaction_logs_project_idx;
         DROP INDEX transaction_logs_is_dry_run_idx;
@@ -56,7 +59,8 @@ def down_transaction(connection):
         ALTER TABLE transaction_logs DROP COLUMN closed;
         ALTER TABLE transaction_logs DROP COLUMN committed_by;
         ALTER TABLE transaction_logs DROP COLUMN is_dry_run;
-    """)
+    """
+    )


 def up(connection):
diff --git a/migrations/index_secondary_keys.py b/migrations/index_secondary_keys.py
index aa60fc1c..d6852f14 100644
--- a/migrations/index_secondary_keys.py
+++ b/migrations/index_secondary_keys.py
@@ -24,26 +24,27 @@


 TX_LOG_PROJECT_ID_IDX = Index(
-    'transaction_logs_project_id_idx',
-    TransactionLog.program+'_'+TransactionLog.project)
+    "transaction_logs_project_id_idx",
+    TransactionLog.program + "_" + TransactionLog.project,
+)


 def up_transaction(connection):
-    logger.info('Migrating async-transactions: up')
+    logger.info("Migrating async-transactions: up")

     for cls in Node.get_subclasses():
         for index in get_secondary_key_indexes(cls):
-            logger.info('Creating %s', index.name)
+            logger.info("Creating %s", index.name)
             index.create(connection)

     TX_LOG_PROJECT_ID_IDX.create(connection)


 def down_transaction(connection):
-    logger.info('Migrating async-transactions: down')
+    logger.info("Migrating async-transactions: down")

     for cls in Node.get_subclasses():
         for index in get_secondary_key_indexes(cls):
-            logger.info('Dropping %s', index.name)
+            logger.info("Dropping %s", index.name)
             index.drop(connection)

     TX_LOG_PROJECT_ID_IDX.drop(connection)
diff --git a/migrations/notifications.py b/migrations/notifications.py
index 74857dfb..6a757267 100644
--- a/migrations/notifications.py
+++ b/migrations/notifications.py
@@ -14,7 +14,7 @@


 def up(connection):
-    logger.info('Migrating notifications: up')
+    logger.info("Migrating notifications: up")

     models.notifications.Base.metadata.create_all(connection)
     models.redaction.Base.metadata.create_all(connection)
@@ -22,7 +22,7 @@ def up(connection):


 def down(connection):
-    logger.info('Migrating notifications: down')
+    logger.info("Migrating notifications: down")

     models.notifications.Base.metadata.drop_all(connection)
     models.redaction.Base.metadata.drop_all(connection)
diff --git a/migrations/set_null_edge_columns.py b/migrations/set_null_edge_columns.py
index 0e97d176..b6d2881d 100644
--- a/migrations/set_null_edge_columns.py
+++ b/migrations/set_null_edge_columns.py
@@ -7,7 +7,7 @@
 CACHE_EDGES = {
     Node.get_subclass_named(edge.__src_class__): edge
     for edge in Edge.get_subclasses()
-    if 'RelatesToCase' in edge.__name__
+    if "RelatesToCase" in edge.__name__
 }

@@ -32,9 +32,10 @@ def set_null_edge_columns(graph):


 def main():
-    print("No main() action defined, please manually call "
-          "set_null_edge_columns(graph)")
+    print(
+        "No main() action defined, please manually call " "set_null_edge_columns(graph)"
+    )


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/migrations/update_case_cache_append_only.py b/migrations/update_case_cache_append_only.py
index ae20bfe4..0ca4c508 100644
--- a/migrations/update_case_cache_append_only.py
+++ b/migrations/update_case_cache_append_only.py
@@ -7,7 +7,7 @@
 CACHE_EDGES = {
     Node.get_subclass_named(edge.__src_class__): edge
     for edge in Edge.get_subclasses()
-    if 'RelatesToCase' in edge.__name__
+    if "RelatesToCase" in edge.__name__
 }

@@ -68,13 +68,10 @@ def max_distances_from_case():
         cls, level = to_visit.pop(0)

         if cls not in distances:
-            children = (
-                link['src_type']
-                for _, link in cls._pg_backrefs.items()
-            )
-            to_visit.extend((child, level+1) for child in children)
+            children = (link["src_type"] for _, link in cls._pg_backrefs.items())
+            to_visit.extend((child, level + 1) for child in children)

-        distances[cls] = max(distances.get(cls, level+1), level)
+        distances[cls] = max(distances.get(cls, level + 1), level)

     return distances

@@ -89,9 +86,8 @@ def get_levels():
     distinct_distances = set(distances.values())

     levels = {
-        level: [
-            cls for cls, distance in distances.items() if distance == level
-        ] for level in distinct_distances
+        level: [cls for cls, distance in distances.items() if distance == level]
+        for level in distinct_distances
     }

     return levels
@@ -106,7 +102,7 @@ def append_cache_from_parent(graph, child, parent):

     """

-    description = child.label + ' -> ' + parent.label + ' -> case'
+    description = child.label + " -> " + parent.label + " -> case"

     if parent not in CACHE_EDGES:
         print("skipping:", description, ": parent is not cached")
@@ -137,10 +133,7 @@ def append_cache_from_parents(graph, cls):

     """

-    parents = {
-        link['dst_type']
-        for link in cls._pg_links.itervalues()
-    }
+    parents = {link["dst_type"] for link in cls._pg_links.itervalues()}

     for parent in parents:
         append_cache_from_parent(graph, cls, parent)
@@ -168,7 +161,7 @@ def seed_level_1(graph, cls):
         cls_to_case_edge_table=case_edge.__tablename__,
     )

-    print('Seeding {} through {}'.format(cls.get_label(), case_edge.__name__))
+    print("Seeding {} through {}".format(cls.get_label(), case_edge.__name__))
     graph.current_session().execute(statement)

@@ -195,9 +188,11 @@ def update_case_cache_append_only(graph):


 def main():
-    print("No main() action defined, please manually call "
-          "update_case_cache_append_only(graph)")
+    print(
+        "No main() action defined, please manually call "
+        "update_case_cache_append_only(graph)"
+    )


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/migrations/update_legacy_states.py b/migrations/update_legacy_states.py
index 8faed978..4d280524 100644
--- a/migrations/update_legacy_states.py
+++ b/migrations/update_legacy_states.py
@@ -44,14 +44,12 @@
 from gdcdatamodel import models as md

 CLS_WITH_PROJECT_ID = {
-    cls for cls in Node.get_subclasses()
-    if 'project_id' in cls.__pg_properties__
+    cls for cls in Node.get_subclasses() if "project_id" in cls.__pg_properties__
 }

 CLS_WITH_STATE = {
-    cls for cls in Node.get_subclasses()
-    if 'state' in cls.__pg_properties__
+    cls for cls in Node.get_subclasses() if "state" in cls.__pg_properties__
 }

@@ -59,34 +57,13 @@

 # Determines state and file_state based on existing state
 STATE_MAP = {
-    None: {
-        'state': 'submitted',
-        'file_state': None
-    },
-    'error': {
-        'state': 'validated',
-        'file_state': 'error'
-    },
-    'invalid': {
-        'state': 'validated',
-        'file_state': 'error'
-    },
-    'live': {
-        'state': 'submitted',
-        'file_state': 'submitted'
-    },
-    'submitted': {
-        'state': 'submitted',
-        'file_state': 'registered'
-    },
-    'uploaded': {
-        'state': 'submitted',
-        'file_state': 'uploaded'
-    },
-    'validated': {
-        'state': 'submitted',
-        'file_state': 'validated'
-    },
+    None: {"state": "submitted", "file_state": None},
+    "error": {"state": "validated", "file_state": "error"},
+    "invalid": {"state": "validated", "file_state": "error"},
+    "live": {"state": "submitted", "file_state": "submitted"},
+    "submitted": {"state": "submitted", "file_state": "registered"},
+    "uploaded": {"state": "submitted", "file_state": "uploaded"},
+    "validated": {"state": "submitted", "file_state": "validated"},
 }

@@ -101,15 +78,12 @@ def legacy_filter(query, legacy_projects):

     """
     legacy_filters = [
-        query.entity().project_id.astext ==
-        project.programs[0].name + '-' + project.code
+        query.entity().project_id.astext
+        == project.programs[0].name + "-" + project.code
         for project in legacy_projects
     ]

-    return query.filter(or_(
-        null_prop(query.entity(), 'project_id'),
-        *legacy_filters
-    ))
+    return query.filter(or_(null_prop(query.entity(), "project_id"), *legacy_filters))


 def null_prop(cls, key):
@@ -130,8 +104,11 @@ def print_cls_query_summary(graph):
     }

     print(
-        "%s: %d" % ("legacy_stateless_nodes".ljust(40),
-                    sum([query.count() for query in cls_queries.itervalues()]))
+        "%s: %d"
+        % (
+            "legacy_stateless_nodes".ljust(40),
+            sum([query.count() for query in cls_queries.itervalues()]),
+        )
     )

     for label, query in cls_queries.items():
@@ -143,19 +120,18 @@ def print_cls_query_summary(graph):

 def cls_query(graph, cls):
     """Returns query for legacy nodes with state in {null, 'live'}"""
-    legacy_projects = graph.nodes(md.Project).props(state='legacy').all()
+    legacy_projects = graph.nodes(md.Project).props(state="legacy").all()

     options = [
         # state
-        null_prop(cls, 'state'),
+        null_prop(cls, "state"),
         cls.state.astext.in_(STATE_MAP),
     ]

-    if 'file_state' in cls.__pg_properties__:
-        options += [null_prop(cls, 'file_state')]
+    if "file_state" in cls.__pg_properties__:
+        options += [null_prop(cls, "file_state")]

-    return (legacy_filter(graph.nodes(cls), legacy_projects)
-            .filter(or_(*options)))
+    return legacy_filter(graph.nodes(cls), legacy_projects).filter(or_(*options))


 def update_cls(graph, cls):
@@ -167,32 +143,32 @@ def update_cls(graph, cls):
     if count == 0:
         return

-    logger.info('Loading %d %s nodes', count, cls.label)
+    logger.info("Loading %d %s nodes", count, cls.label)
     nodes = query.all()
-    logger.info('Loaded %d %s nodes', len(nodes), cls.label)
+    logger.info("Loaded %d %s nodes", len(nodes), cls.label)

     for node in nodes:
-        state = node._props.get('state', None)
-        file_state = node._props.get('file_state', None)
+        state = node._props.get("state", None)
+        file_state = node._props.get("file_state", None)

         if state in STATE_MAP:
-            node.state = STATE_MAP[state]['state']
+            node.state = STATE_MAP[state]["state"]

         set_file_state = (
-            'file_state' in node.__pg_properties__
+            "file_state" in node.__pg_properties__
             and file_state is None
             and state in STATE_MAP
         )

         if set_file_state:
-            node.file_state = STATE_MAP[state]['file_state']
+            node.file_state = STATE_MAP[state]["file_state"]

-        node.sysan['legacy_state'] = state
-        node.sysan['legacy_file_state'] = file_state
+        node.sysan["legacy_state"] = state
+        node.sysan["legacy_file_state"] = file_state

-    logger.info('Committing %s nodes', cls.label)
+    logger.info("Committing %s nodes", cls.label)
     graph.current_session().commit()
-    logger.info('Done with %s nodes', cls.label)
+    logger.info("Done with %s nodes", cls.label)


 def update_classes(graph_kwargs, input_q):
diff --git a/test/conftest.py b/test/conftest.py
index 726d7ea4..4b38beb0 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -22,25 +22,27 @@
 from gdcdatamodel.models import basic  # noqa


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def db_config():
     return {
-        'host': 'localhost',
-        'user': 'test',
-        'password': 'test',
-        'database': 'automated_test',
+        "host": "localhost",
+        "user": "test",
+        "password": "test",
+        "database": "automated_test",
     }


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def tables_created(db_config):
     """
     Create necessary tables
     """
     engine = create_engine(
         "postgres://{user}:{pwd}@{host}/{db}".format(
-            user=db_config['user'], host=db_config['host'],
-            pwd=db_config['password'], db=db_config['database']
+            user=db_config["user"],
+            host=db_config["host"],
+            pwd=db_config["password"],
+            db=db_config["database"],
         )
     )

@@ -51,14 +53,14 @@ def tables_created(db_config):
     truncate(engine)


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def g(db_config, tables_created):
     """Fixture for database driver"""

     return PsqlGraphDriver(**db_config)


-@pytest.fixture(scope='class')
+@pytest.fixture(scope="class")
 def db_class(request, g):
     """
     Sets g property on a test class
@@ -66,9 +68,10 @@ def db_class(request, g):
     request.cls.g = g


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def indexes(g):
-    rows = g.engine.execute("""
+    rows = g.engine.execute(
+        """
         SELECT i.relname as indname,
                ARRAY(
                    SELECT pg_get_indexdef(idx.indexrelid, k + 1, true)
                    FROM generate_subscripts(idx.indkey, 1) as k
                    ORDER BY k
                ) as indkey_names
         FROM pg_index as idx
         JOIN pg_class as i
@@ -80,14 +83,15 @@ def indexes(g):
                ON i.oid = idx.indexrelid
                JOIN pg_am as am
                ON i.relam = am.oid;
-    """).fetchall()
+    """
+    ).fetchall()

-    return { row[0]: row[1] for row in rows }
+    return {row[0]: row[1] for row in rows}


 @pytest.fixture()
 def redacted_fixture(g):
-    """ Creates a redacted log entry"""
+    """Creates a redacted log entry"""

     with g.session_scope() as sxn:
         log = models.redaction.RedactionLog()
@@ -100,7 +104,9 @@ def redacted_fixture(g):
         count = 0
         for i in range(random.randint(2, 5)):
             count += 1
-            entry = models.redaction.RedactionEntry(node_id=str(uuid.uuid4()), node_type="Aligned Reads")
+            entry = models.redaction.RedactionEntry(
+                node_id=str(uuid.uuid4()), node_type="Aligned Reads"
+            )
             log.entries.append(entry)

         sxn.add(log)
@@ -116,7 +122,7 @@ def redacted_fixture(g):
         sxn.delete(log)


-@pytest.mark.usefixtures('db_class')
+@pytest.mark.usefixtures("db_class")
 class BaseTestCase(unittest.TestCase):
     def setUp(self):
         truncate(self.g.engine)
diff --git a/test/helpers.py b/test/helpers.py
index b8ce1c46..0b8266b1 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -16,10 +16,10 @@ def truncate(engine, namespace=None):
     conn = engine.connect()
     for table in abstract_node.get_subclass_table_names():
         if table != abstract_node.__tablename__:
-            conn.execute('delete from {}'.format(table))
+            conn.execute("delete from {}".format(table))
     for table in abstract_edge.get_subclass_table_names():
         if table != abstract_edge.__tablename__:
-            conn.execute('delete from {}'.format(table))
+            conn.execute("delete from {}".format(table))

     if not namespace:
         # add ng models only to main graph model
diff --git a/test/models.py b/test/models.py
index 177879dd..b1c5f787 100644
--- a/test/models.py
+++ b/test/models.py
@@ -8,7 +8,6 @@ def _load(name):


 class Dictionary:
-
     def __init__(self, name):
         self.schema = _load(name)
diff --git a/test/test_admin_script.py b/test/test_admin_script.py
index e5c434f4..54b82383 100644
--- a/test/test_admin_script.py
+++ b/test/test_admin_script.py
@@ -8,12 +8,7 @@


 def get_base_args(host="localhost", database="automated_test", namespace=None):
-    return [
-        '-H', host,
-        '-U', "postgres",
-        '-D', database,
-        "-N", namespace or ""
-    ]
+    return ["-H", host, "-U", "postgres", "-D", database, "-N", namespace or ""]


 def get_admin_driver(db_config, namespace=None):
@@ -24,14 +19,18 @@ def get_admin_driver(db_config, namespace=None):
         host=db_config["host"],
         user="postgres",
         password=None,
-        database=db_config["database"]
+        database=db_config["database"],
     )

     return g


 def drop_all_tables(g):
-    orm_base = ext.get_orm_base(g.package_namespace) if g.package_namespace else psqlgraph.base.ORMBase
+    orm_base = (
+        ext.get_orm_base(g.package_namespace)
+        if g.package_namespace
+        else psqlgraph.base.ORMBase
+    )
     psqlgraph.base.drop_all(g.engine, orm_base)


@@ -44,12 +43,12 @@ def run_admin_command(args, namespace=None):
 def invalid_write_access_fn(g):
     with pytest.raises(ProgrammingError):
         with g.session_scope() as s:
-            s.add(models.Case('1'))
+            s.add(models.Case("1"))


 def valid_write_access_fn(g):
     with g.session_scope() as s:
-        s.merge(models.Case('1'))
+        s.merge(models.Case("1"))
     yield
     with g.session_scope() as s:
         n = g.nodes().get("1")
@@ -77,7 +76,9 @@ def add_test_database_user(db_config):
     g = get_admin_driver(db_config)

     try:
-        g.engine.execute("CREATE USER {} WITH PASSWORD '{}'".format(dummy_user, dummy_pwd))
+        g.engine.execute(
+            "CREATE USER {} WITH PASSWORD '{}'".format(dummy_user, dummy_pwd)
+        )
         g.engine.execute("GRANT USAGE ON SCHEMA public TO {}".format(dummy_user))
         yield dummy_user, dummy_pwd
     finally:
@@ -86,10 +87,10 @@ def add_test_database_user(db_config):

 @pytest.mark.parametrize("namespace", [None, "gdc"], ids=["default", "custom"])
 def test_create_tables(db_config, namespace):
-    """ Tests tables can be created with the admin script using either a custom dictionary or the default
-        Args:
-            db_config (dict[str,str]): db connection config
-            namespace (str): module namespace, None for default
+    """Tests tables can be created with the admin script using either a custom dictionary or the default
+    Args:
+        db_config (dict[str,str]): db connection config
+        namespace (str): module namespace, None for default
     """

     # simulate loading a different dictionary
@@ -114,7 +115,9 @@ def test_create_tables(db_config, namespace):
     assert 'relation "node_case" does not exist' in str(e.value)

     # create tables using admin script
-    args = ['graph-create', '--delay', '1', '--retries', '0'] + get_base_args(namespace=namespace)
+    args = ["graph-create", "--delay", "1", "--retries", "0"] + get_base_args(
+        namespace=namespace
+    )
     parsed_args = pgadmin.get_parser().parse_args(args)
     pgadmin.main(parsed_args)

@@ -125,13 +128,22 @@ def test_create_tables(db_config, namespace):


 @pytest.mark.parametrize("namespace", [None, "gdc"], ids=["default", "custom"])
-@pytest.mark.parametrize("permission, invalid_permission_fn, valid_permission_fn", [
-        ("read", invalid_read_access_fn, valid_read_access_fn),
-        ("write", invalid_write_access_fn, valid_write_access_fn),
-    ], ids=["read", "write"]
+@pytest.mark.parametrize(
+    "permission, invalid_permission_fn, valid_permission_fn",
+    [
+        ("read", invalid_read_access_fn, valid_read_access_fn),
+        ("write", invalid_write_access_fn, valid_write_access_fn),
+    ],
+    ids=["read", "write"],
 )
-def test_grant_permissions(db_config, namespace, add_test_database_user,
-                           permission, invalid_permission_fn, valid_permission_fn):
+def test_grant_permissions(
+    db_config,
+    namespace,
+    add_test_database_user,
+    permission,
+    invalid_permission_fn,
+    valid_permission_fn,
+):

     # simulate loading a different dictionary
     if namespace:
@@ -140,11 +152,11 @@ def test_grant_permissions(db_config, namespace, add_test_database_user,
     dummy_user, dummy_pwd = add_test_database_user

     g = psqlgraph.PsqlGraphDriver(
-            host=db_config["host"],
-            user=dummy_user,
-            password=dummy_pwd,
-            database=db_config["database"],
-            package_namespace=namespace,
+        host=db_config["host"],
+        user=dummy_user,
+        password=dummy_pwd,
+        database=db_config["database"],
+        package_namespace=namespace,
     )

     # verify user does not have permission
@@ -156,13 +168,22 @@ def test_grant_permissions(db_config, namespace, add_test_database_user,


 @pytest.mark.parametrize("namespace", [None, "gdc"], ids=["default", "custom"])
-@pytest.mark.parametrize("permission, invalid_permission_fn, valid_permission_fn", [
-        ("read", invalid_read_access_fn, valid_read_access_fn),
-        ("write", invalid_write_access_fn, valid_write_access_fn),
-    ], ids=["read", "write"]
+@pytest.mark.parametrize(
+    "permission, invalid_permission_fn, valid_permission_fn",
+    [
+        ("read", invalid_read_access_fn, valid_read_access_fn),
+        ("write", invalid_write_access_fn, valid_write_access_fn),
+    ],
+    ids=["read", "write"],
 )
-def test_revoke_permissions(db_config, namespace, add_test_database_user,
-                            permission, invalid_permission_fn, valid_permission_fn):
+def test_revoke_permissions(
+    db_config,
+    namespace,
+    add_test_database_user,
+    permission,
+    invalid_permission_fn,
+    valid_permission_fn,
+):

     # simulate loading a different dictionary
     if namespace:
@@ -171,11 +192,11 @@ def test_revoke_permissions(db_config, namespace, add_test_database_user,
     dummy_user, dummy_pwd = add_test_database_user

     g = psqlgraph.PsqlGraphDriver(
-            host=db_config["host"],
-            user=dummy_user,
-            password=dummy_pwd,
-            database=db_config["database"],
-            package_namespace=namespace,
+        host=db_config["host"],
+        user=dummy_user,
+        password=dummy_pwd,
+        database=db_config["database"],
+        package_namespace=namespace,
     )

     # grant user permissions
diff --git a/test/test_cache_related_cases.py b/test/test_cache_related_cases.py
index bb3ab094..6a40b2f2 100644
--- a/test/test_cache_related_cases.py
+++ b/test/test_cache_related_cases.py
@@ -6,64 +6,64 @@


 class TestCacheRelatedCases(BaseTestCase):
-
     def test_insert_single_association_proxy(self):
         with self.g.session_scope() as s:
-            case = md.Case('case_id_1')
-            sample = md.Sample('sample_id_1')
+            case = md.Case("case_id_1")
+            sample = md.Sample("sample_id_1")
             sample.cases = [case]
             s.merge(sample)

         with self.g.session_scope() as s:
-            sample = self.g.nodes(md.Sample).subq_path('cases').one()
+            sample = self.g.nodes(md.Sample).subq_path("cases").one()
             assert sample._related_cases_from_cache == [case]

     def test_insert_single_edge(self):
         with self.g.session_scope() as s:
-            case = s.merge(md.Case('case_id_1'))
-            sample = s.merge(md.Sample('sample_id_1'))
+            case = s.merge(md.Case("case_id_1"))
+            sample = s.merge(md.Sample("sample_id_1"))
             edge = md.SampleDerivedFromCase(sample.node_id, case.node_id)
             s.merge(edge)

         with self.g.session_scope() as s:
-            sample = self.g.nodes(md.Sample).subq_path('cases').one()
+            sample = self.g.nodes(md.Sample).subq_path("cases").one()
             assert sample._related_cases_from_cache == [case]

     def test_insert_double_edge_in(self):
         with self.g.session_scope() as s:
-            case = md.Case('case_id_1')
-            sample1 = md.Sample('sample_id_1')
-            sample2 = md.Sample('sample_id_2')
+            case = md.Case("case_id_1")
+            sample1 = md.Sample("sample_id_1")
+            sample2 = md.Sample("sample_id_2")
             case.samples = [sample1, sample2]
             s.merge(case)

         with self.g.session_scope() as s:
-            samples = self.g.nodes(md.Sample).subq_path('cases').all()
+            samples = self.g.nodes(md.Sample).subq_path("cases").all()
             self.assertEqual(len(samples), 2)
             for sample in samples:
                 assert sample._related_cases_from_cache == [case]

     def test_insert_double_edge_out(self):
         with self.g.session_scope() as s:
-            case1 = md.Case('case_id_1')
-            case2 = md.Case('case_id_2')
-            sample = md.Sample('sample_id_1')
+            case1 = md.Case("case_id_1")
+            case2 = md.Case("case_id_2")
+            sample = md.Sample("sample_id_1")
             sample.cases = [case1, case2]
             s.merge(sample)

         with self.g.session_scope() as s:
-            sample = self.g.nodes(md.Sample).subq_path('cases').one()
-            assert {c.node_id for c in sample._related_cases} == \
-                {c.node_id for c in [case1, case2]}
+            sample = self.g.nodes(md.Sample).subq_path("cases").one()
+            assert {c.node_id for c in sample._related_cases} == {
+                c.node_id for c in [case1, case2]
+            }

     def test_insert_multiple_edges(self):
         with self.g.session_scope() as s:
-            case = md.Case('case_id_1')
-            sample = md.Sample('sample_id_1')
-            portion = md.Portion('portion_id_1')
-            analyte = md.Analyte('analyte_id_1')
-            aliquot = md.Aliquot('aliquot_id_1')
-            general_file = md.File('file_id_1')
+            case = md.Case("case_id_1")
+            sample = md.Sample("sample_id_1")
+            portion = md.Portion("portion_id_1")
+            analyte = md.Analyte("analyte_id_1")
+            aliquot = md.Aliquot("aliquot_id_1")
+            general_file = md.File("file_id_1")

             sample.cases = [case]
             portion.samples = [sample]
@@ -74,17 +74,17 @@ def test_insert_multiple_edges(self):

         with self.g.session_scope() as s:
             nodes = self.g.nodes(Node).all()
-            nodes = [n for n in nodes if n.label not in ['case']]
+            nodes = [n for n in nodes if n.label not in ["case"]]
             for node in nodes:
                 assert node._related_cases == [case]

     def test_insert_update_children(self):
         with self.g.session_scope() as s:
-            aliquot = s.merge(md.Aliquot('aliquot_id_1'))
-            sample = s.merge(md.Sample('sample_id_1'))
+            aliquot = s.merge(md.Aliquot("aliquot_id_1"))
+            sample = s.merge(md.Sample("sample_id_1"))
             aliquot.samples = [sample]
-            s.merge(md.Case('case_id_1'))
+            s.merge(md.Case("case_id_1"))

         with self.g.session_scope() as s:
             case = self.g.nodes(md.Case).one()
@@ -100,9 +100,9 @@ def test_insert_update_children(self):

     def test_delete_dst_association_proxy(self):
         with self.g.session_scope() as s:
-            case = md.Case('case_id_1')
-            aliquot = md.Aliquot('aliquot_id_1')
-            sample = md.Sample('sample_id_1')
+            case = md.Case("case_id_1")
+            aliquot = md.Aliquot("aliquot_id_1")
+            sample = md.Sample("sample_id_1")
             aliquot.samples = [sample]
             sample.cases = [case]
             s.merge(case)
@@ -112,7 +112,7 @@ def test_delete_dst_association_proxy(self):
             case.samples = []

         with self.g.session_scope() as s:
-            assert not self.g.nodes(md.Sample).subq_path('cases').count()
+            assert not self.g.nodes(md.Sample).subq_path("cases").count()

         with self.g.session_scope() as s:
             sample = self.g.nodes(md.Sample).one()
@@ -123,9 +123,9 @@ def test_delete_dst_association_proxy(self):

     def test_delete_src_association_proxy(self):
         with self.g.session_scope() as s:
-            case = md.Case('case_id_1')
-            aliquot = md.Aliquot('aliquot_id_1')
-            sample = md.Sample('sample_id_1')
+            case = md.Case("case_id_1")
+            aliquot = md.Aliquot("aliquot_id_1")
+            sample = md.Sample("sample_id_1")
             aliquot.samples = [sample]
             sample.cases = [case]
             s.merge(case)
@@ -135,7 +135,7 @@ def test_delete_src_association_proxy(self):
             sample.cases = []

         with self.g.session_scope() as s:
-            assert not self.g.nodes(md.Sample).subq_path('cases').count()
+            assert not self.g.nodes(md.Sample).subq_path("cases").count()

         with self.g.session_scope() as s:
             sample = self.g.nodes(md.Sample).one()
@@ -146,9 +146,9 @@ def test_delete_src_association_proxy(self):

     def test_delete_edge(self):
         with self.g.session_scope() as s:
-            case = md.Case('case_id_1')
-            aliquot = md.Aliquot('aliquot_id_1')
-            sample = md.Sample('sample_id_1')
+            case = md.Case("case_id_1")
+            aliquot = md.Aliquot("aliquot_id_1")
+            sample = md.Sample("sample_id_1")
             aliquot.samples = [sample]
             sample.cases = [case]
             s.merge(case)
@@ -156,13 +156,14 @@ def test_delete_edge(self):
         with self.g.session_scope() as s:
             case = self.g.nodes(md.Case).one()
             edge = [
                 e for e in case.edges_in
                 if e.label != 'relates_to' and
e.src.label == 'sample' + e + for e in case.edges_in + if e.label != "relates_to" and e.src.label == "sample" ][0] s.delete(edge) with self.g.session_scope() as s: - assert not self.g.nodes(md.Sample).subq_path('cases').count() + assert not self.g.nodes(md.Sample).subq_path("cases").count() with self.g.session_scope() as s: sample = self.g.nodes(md.Sample).one() @@ -173,8 +174,8 @@ def test_delete_edge(self): def test_delete_parent(self): with self.g.session_scope() as s: - case = md.Case('case_id_1') - sample = md.Sample('sample_id_1') + case = md.Case("case_id_1") + sample = md.Sample("sample_id_1") sample.cases = [case] s.merge(case) @@ -189,14 +190,14 @@ def test_delete_parent(self): def test_delete_one_parent(self): with self.g.session_scope() as s: - case1 = md.Case('case_id_1') - case2 = md.Case('case_id_2') - sample = md.Sample('sample_id_1') + case1 = md.Case("case_id_1") + case2 = md.Case("case_id_2") + sample = md.Sample("sample_id_1") sample.cases = [case1, case2] s.merge(sample) with self.g.session_scope() as s: - case1 = self.g.nodes(md.Case).ids('case_id_1').one() + case1 = self.g.nodes(md.Case).ids("case_id_1").one() s.delete(case1) with self.g.session_scope() as s: @@ -206,7 +207,7 @@ def test_delete_one_parent(self): def test_preserve_timestamps(self): """Confirm cache changes do not affect the case's timestamps.""" with self.g.session_scope() as s: - s.merge(md.Case('case_id_1')) + s.merge(md.Case("case_id_1")) with self.g.session_scope(): case = self.g.nodes(md.Case).one() @@ -214,16 +215,16 @@ def test_preserve_timestamps(self): old_updated_datetime = case.updated_datetime # Test addition of cache edges. - sample = md.Sample('sample_id_1') - portion = md.Portion('portion_id_1') - analyte = md.Analyte('analyte_id_1') - aliquot = md.Aliquot('aliquot_id_1') + sample = md.Sample("sample_id_1") + portion = md.Portion("portion_id_1") + analyte = md.Analyte("analyte_id_1") + aliquot = md.Aliquot("aliquot_id_1") sample.cases = [case] portion.samples = [sample] analyte.portions = [portion] aliquot.analytes = [analyte] - sample2 = md.Sample('sample_id_2') + sample2 = md.Sample("sample_id_2") sample2.cases = [case] with self.g.session_scope() as s: @@ -231,7 +232,7 @@ def test_preserve_timestamps(self): # Exercise a few cache edge removal use cases as well. 
analyte = self.g.nodes(md.Analyte).one() - sample2 = self.g.nodes(md.Sample).get('sample_id_2') + sample2 = self.g.nodes(md.Sample).get("sample_id_2") s.delete(analyte) sample2.cases = [] diff --git a/test/test_datamodel.py b/test/test_datamodel.py index c0c322d3..e638b38f 100644 --- a/test/test_datamodel.py +++ b/test/test_datamodel.py @@ -13,10 +13,10 @@ class TestDataModel(unittest.TestCase): @classmethod def setUpClass(cls): - host = 'localhost' - user = 'test' - password = 'test' - database = 'automated_test' + host = "localhost" + user = "test" + password = "test" + database = "automated_test" cls.g = PsqlGraphDriver(host, user, password, database) cls._clear_tables() @@ -27,59 +27,59 @@ def tearDown(self): @classmethod def _clear_tables(cls): conn = cls.g.engine.connect() - conn.execute('commit') + conn.execute("commit") for table in Node().get_subclass_table_names(): if table != Node.__tablename__: - conn.execute('delete from {}'.format(table)) + conn.execute("delete from {}".format(table)) for table in Edge.get_subclass_table_names(): if table != Edge.__tablename__: - conn.execute('delete from {}'.format(table)) - conn.execute('delete from versioned_nodes') - conn.execute('delete from _voided_nodes') - conn.execute('delete from _voided_edges') + conn.execute("delete from {}".format(table)) + conn.execute("delete from versioned_nodes") + conn.execute("delete from _voided_nodes") + conn.execute("delete from _voided_edges") conn.close() def test_type_validation(self): f = md.File() with self.assertRaises(ValidationError): - f.file_size = '0' + f.file_size = "0" f.file_size = 0 f = md.File() with self.assertRaises(ValidationError): f.file_name = 0 - f.file_name = '0' + f.file_name = "0" s = md.Sample() with self.assertRaises(ValidationError): - s.is_ffpe = 'false' + s.is_ffpe = "false" s.is_ffpe = False s = md.Slide() with self.assertRaises(ValidationError): - s.percent_necrosis = '0.0' + s.percent_necrosis = "0.0" s.percent_necrosis = 0.0 def test_link_clobber_prevention(self): with self.assertRaises(AssertionError): md.EdgeFactory( - 'Testedge', - 'test', - 'sample', - 'aliquot', - 'aliquots', - '_uncontended_backref', + "Testedge", + "test", + "sample", + "aliquot", + "aliquots", + "_uncontended_backref", ) def test_backref_clobber_prevention(self): with self.assertRaises(AssertionError): md.EdgeFactory( - 'Testedge', - 'test', - 'sample', - 'aliquot', - '_uncontended_link', - 'samples', + "Testedge", + "test", + "sample", + "aliquot", + "_uncontended_link", + "samples", ) def test_created_datetime_hook(self): @@ -87,7 +87,7 @@ def test_created_datetime_hook(self): time_before = datetime.now().isoformat() with self.g.session_scope() as s: - s.add(md.Case('case1')) + s.add(md.Case("case1")) time_after = datetime.now().isoformat() @@ -102,14 +102,14 @@ def test_created_datetime_hook(self): def test_updated_datetime_hook(self): """Test setting updated datetime when a node is updated.""" with self.g.session_scope() as s: - s.merge(md.Case('case1')) + s.merge(md.Case("case1")) with self.g.session_scope(): case = self.g.nodes(md.Case).one() old_created_datetime = case.created_datetime old_updated_datetime = case.updated_datetime - case.primary_site = 'Kidney' + case.primary_site = "Kidney" with self.g.session_scope(): updated_case = self.g.nodes(md.Case).one() @@ -119,14 +119,14 @@ def test_updated_datetime_hook(self): def test_no_datetime_update_for_new_edge(self): """Verify new inbound edges do not affect a node's updated datetime.""" with self.g.session_scope() as s: - 
s.merge(md.Case('case1'))
+            s.merge(md.Case("case1"))

         with self.g.session_scope() as s:
             case = self.g.nodes(md.Case).one()
             old_created_datetime = case.created_datetime
             old_updated_datetime = case.updated_datetime

-            sample = s.merge(md.Sample('sample1'))
+            sample = s.merge(md.Sample("sample1"))
             case.samples.append(sample)

         with self.g.session_scope():
@@ -137,9 +137,9 @@ def test_default_values(self):
     def test_default_values(self):
         p = md.Project()
         project_defaults = p._defaults
-        assert project_defaults['state'] == 'open'
-        assert project_defaults['submission_enabled'] == True
-        assert project_defaults['released'] == False
+        assert project_defaults["state"] == "open"
+        assert project_defaults["submission_enabled"] == True
+        assert project_defaults["released"] == False


 def test_file_pg_edges():
diff --git a/test/test_dictionary_loadiing.py b/test/test_dictionary_loadiing.py
index 87a5018e..b7673a01 100644
--- a/test/test_dictionary_loadiing.py
+++ b/test/test_dictionary_loadiing.py
@@ -3,16 +3,20 @@
 from gdcdatamodel import models


-@pytest.mark.parametrize("namespace, expectation", [
-    (None, "gdcdatamodel.models"),
-    ("d1", "gdcdatamodel.models.d1"),
-    ("d2", "gdcdatamodel.models.d2"),
-], ids=["default module", "custom 1", "custom 2"])
+@pytest.mark.parametrize(
+    "namespace, expectation",
+    [
+        (None, "gdcdatamodel.models"),
+        ("d1", "gdcdatamodel.models.d1"),
+        ("d2", "gdcdatamodel.models.d2"),
+    ],
+    ids=["default module", "custom 1", "custom 2"],
+)
 def test_get_package_for_class(namespace, expectation):
-    """ Tests retrieving the proper module to insert generated classes for a given dictionary
-    Args:
-        namespace (str): package namespace used to logically divided classes
-        expectation (str): final module name
+    """Tests retrieving the proper module to insert generated classes for a given dictionary
+    Args:
+        namespace (str): package namespace used to logically divide classes
+        expectation (str): final module name
     """

     module_name = models.get_cls_package(package_namespace=namespace)
@@ -20,8 +24,8 @@ def test_get_package_for_class(namespace, expectation):


 def test_loading_same_dictionary():
-    """ Tests loading gdcdictionary into a different namespace,
-    even though it might already be loaded into the default
+    """Tests loading gdcdictionary into a different namespace,
+    even though it might already be loaded into the default
     """
     # assert the custom module does not currently exist
     with pytest.raises(ImportError):
@@ -31,6 +35,7 @@ def test_loading_same_dictionary():

     # assert module & models now exist
     from gdcdatamodel.models import gdcx  # noqa
+
     assert gdcx.Project and gdcx.Program and gdcx.Case


@@ -43,6 +48,7 @@ def test_case_cache_related_edge_resolution():
     models.load_dictionary(dictionary=None, package_namespace="gdc")

     from gdcdatamodel.models import gdc  # noqa
+
     ns = models.caching.get_related_case_edge_cls(gdc.AlignedReads())
     class_name = "{}.{}".format(ns.__module__, ns.__name__)
     assert "gdcdatamodel.models.gdc.AlignedReadsRelatesToCase" == class_name
diff --git a/test/test_gdc_postgres_admin.py b/test/test_gdc_postgres_admin.py
index e08d080e..e7899533 100644
--- a/test/test_gdc_postgres_admin.py
+++ b/test/test_gdc_postgres_admin.py
@@ -16,58 +16,60 @@

 class TestGDCPostgresAdmin(unittest.TestCase):
-
-    logger = logging.getLogger('TestGDCPostgresAdmin')
+    logger = logging.getLogger("TestGDCPostgresAdmin")
     logger.setLevel(logging.INFO)

-    host = 'localhost'
-    user = 'postgres'
-    database = 'automated_test'
+    host = "localhost"
+    user = "postgres"
+    database = "automated_test"

     base_args = [
-        '-H', host,
-        '-U', user,
-        '-D', database,
+        "-H",
+        host,
+        "-U",
+        user,
+        "-D",
+        database,
     ]

-    g = PsqlGraphDriver(host, user, '', database)
+    g = PsqlGraphDriver(host, user, "", database)

     root_con_str = "postgres://{user}:{pwd}@{host}/{db}".format(
-        user=user, host=host, pwd='', db=database)
+        user=user, host=host, pwd="", db=database
+    )

     engine = pgadmin.create_engine(root_con_str)

     @classmethod
     def tearDownClass(cls):
-        """Recreate the database for tests that follow.
-
-        """
+        """Recreate the database for tests that follow."""
         cls.create_all_tables()

         # Re-grant permissions to test user
         for scls in Node.get_subclasses() + Edge.get_subclasses():
-            statment = ("GRANT ALL PRIVILEGES ON TABLE {} TO test"
-                        .format(scls.__tablename__))
-            cls.engine.execute('BEGIN; %s; COMMIT;' % statment)
+            statment = "GRANT ALL PRIVILEGES ON TABLE {} TO test".format(
+                scls.__tablename__
+            )
+            cls.engine.execute("BEGIN; %s; COMMIT;" % statment)

     @classmethod
     def drop_all_tables(cls):
         for scls in Node.get_subclasses():
             try:
-                cls.engine.execute("DROP TABLE {} CASCADE"
-                                   .format(scls.__tablename__))
+                cls.engine.execute("DROP TABLE {} CASCADE".format(scls.__tablename__))
             except Exception as e:
                 cls.logger.warning(e)

     @classmethod
     def create_all_tables(cls):
         parser = pgadmin.get_parser()
-        args = parser.parse_args([
-            'graph-create', '--delay', '1', '--retries', '0'
-        ] + cls.base_args)
+        args = parser.parse_args(
+            ["graph-create", "--delay", "1", "--retries", "0"] + cls.base_args
+        )
         pgadmin.main(args)

     @classmethod
     def drop_a_table(cls):
-        cls.engine.execute('DROP TABLE edge_clinicaldescribescase')
-        cls.engine.execute('DROP TABLE node_clinical')
+        cls.engine.execute("DROP TABLE edge_clinicaldescribescase")
+        cls.engine.execute("DROP TABLE node_clinical")

     def startTestRun(self):
         self.drop_all_tables()
@@ -77,25 +79,29 @@ def setUp(self):

     def test_args(self):
         parser = pgadmin.get_parser()
-        parser.parse_args(['graph-create'] + self.base_args)
+        parser.parse_args(["graph-create"] + self.base_args)

     def test_create_single(self):
         """Test simple table creation"""

-        pgadmin.main(pgadmin.get_parser().parse_args([
-            'graph-create', '--delay', '1', '--retries', '0'
-        ] + self.base_args))
+        pgadmin.main(
+            pgadmin.get_parser().parse_args(
+                ["graph-create", "--delay", "1", "--retries", "0"] + self.base_args
+            )
+        )

-        self.engine.execute('SELECT * from node_case')
+        self.engine.execute("SELECT * from node_case")

     def test_create_double(self):
         """Test idempotency of table creation"""

-        pgadmin.main(pgadmin.get_parser().parse_args([
-            'graph-create', '--delay', '1', '--retries', '0'
-        ] + self.base_args))
+        pgadmin.main(
+            pgadmin.get_parser().parse_args(
+                ["graph-create", "--delay", "1", "--retries", "0"] + self.base_args
+            )
+        )

-        self.engine.execute('SELECT * from node_case')
+        self.engine.execute("SELECT * from node_case")

     def test_priv_grant_read(self):
         """Test ability to grant read but not write privs"""
@@ -106,23 +112,29 @@

         try:
             self.engine.execute("CREATE USER pytest WITH PASSWORD 'pyt3st'")
             self.engine.execute("GRANT USAGE ON SCHEMA public TO pytest")

-            g = PsqlGraphDriver(self.host, 'pytest', 'pyt3st', self.database)
+            g = PsqlGraphDriver(self.host, "pytest", "pyt3st", self.database)

             #: If this fails, this test (not the code) is wrong!
with self.assertRaises(ProgrammingError): with g.session_scope(): g.nodes().count() - pgadmin.main(pgadmin.get_parser().parse_args([ - 'graph-grant', '--read=pytest', - ] + self.base_args)) + pgadmin.main( + pgadmin.get_parser().parse_args( + [ + "graph-grant", + "--read=pytest", + ] + + self.base_args + ) + ) with g.session_scope(): g.nodes().count() with self.assertRaises(ProgrammingError): with g.session_scope() as s: - s.merge(models.Case('1')) + s.merge(models.Case("1")) finally: self.engine.execute("DROP OWNED BY pytest; DROP USER pytest") @@ -136,14 +148,20 @@ def test_priv_grant_write(self): self.engine.execute("CREATE USER pytest WITH PASSWORD 'pyt3st'") self.engine.execute("GRANT USAGE ON SCHEMA public TO pytest") - g = PsqlGraphDriver(self.host, 'pytest', 'pyt3st', self.database) - pgadmin.main(pgadmin.get_parser().parse_args([ - 'graph-grant', '--write=pytest', - ] + self.base_args)) + g = PsqlGraphDriver(self.host, "pytest", "pyt3st", self.database) + pgadmin.main( + pgadmin.get_parser().parse_args( + [ + "graph-grant", + "--write=pytest", + ] + + self.base_args + ) + ) with g.session_scope() as s: g.nodes().count() - s.merge(models.Case('1')) + s.merge(models.Case("1")) finally: self.engine.execute("DROP OWNED BY pytest; DROP USER pytest") @@ -157,20 +175,32 @@ def test_priv_revoke_read(self): self.engine.execute("CREATE USER pytest WITH PASSWORD 'pyt3st'") self.engine.execute("GRANT USAGE ON SCHEMA public TO pytest") - g = PsqlGraphDriver(self.host, 'pytest', 'pyt3st', self.database) - - pgadmin.main(pgadmin.get_parser().parse_args([ - 'graph-grant', '--read=pytest', - ] + self.base_args)) - - pgadmin.main(pgadmin.get_parser().parse_args([ - 'graph-revoke', '--read=pytest', - ] + self.base_args)) + g = PsqlGraphDriver(self.host, "pytest", "pyt3st", self.database) + + pgadmin.main( + pgadmin.get_parser().parse_args( + [ + "graph-grant", + "--read=pytest", + ] + + self.base_args + ) + ) + + pgadmin.main( + pgadmin.get_parser().parse_args( + [ + "graph-revoke", + "--read=pytest", + ] + + self.base_args + ) + ) with self.assertRaises(ProgrammingError): with g.session_scope() as s: g.nodes().count() - s.merge(models.Case('1')) + s.merge(models.Case("1")) finally: self.engine.execute("DROP OWNED BY pytest; DROP USER pytest") @@ -184,22 +214,34 @@ def test_priv_revoke_write(self): self.engine.execute("CREATE USER pytest WITH PASSWORD 'pyt3st'") self.engine.execute("GRANT USAGE ON SCHEMA public TO pytest") - g = PsqlGraphDriver(self.host, 'pytest', 'pyt3st', self.database) - - pgadmin.main(pgadmin.get_parser().parse_args([ - 'graph-grant', '--write=pytest', - ] + self.base_args)) - - pgadmin.main(pgadmin.get_parser().parse_args([ - 'graph-revoke', '--write=pytest', - ] + self.base_args)) + g = PsqlGraphDriver(self.host, "pytest", "pyt3st", self.database) + + pgadmin.main( + pgadmin.get_parser().parse_args( + [ + "graph-grant", + "--write=pytest", + ] + + self.base_args + ) + ) + + pgadmin.main( + pgadmin.get_parser().parse_args( + [ + "graph-revoke", + "--write=pytest", + ] + + self.base_args + ) + ) with g.session_scope() as s: g.nodes().count() with self.assertRaises(ProgrammingError): with g.session_scope() as s: - s.merge(models.Case('1')) + s.merge(models.Case("1")) finally: self.engine.execute("DROP OWNED BY pytest; DROP USER pytest") diff --git a/test/test_indexes.py b/test/test_indexes.py index cd0f4050..2307e0ea 100644 --- a/test/test_indexes.py +++ b/test/test_indexes.py @@ -8,7 +8,7 @@ def test_secondary_key_indexes(indexes): - assert 'index_node_datasubtype_name_lower' 
in indexes - assert 'index_node_analyte_project_id' in indexes - assert 'index_4df72441_famihist_submitte_id_lower' in indexes - assert 'transaction_logs_project_id_idx' in indexes + assert "index_node_datasubtype_name_lower" in indexes + assert "index_node_analyte_project_id" in indexes + assert "index_4df72441_famihist_submitte_id_lower" in indexes + assert "transaction_logs_project_id_idx" in indexes diff --git a/test/test_node_tagging.py b/test/test_node_tagging.py index 65bbb46b..64ce70df 100644 --- a/test/test_node_tagging.py +++ b/test/test_node_tagging.py @@ -6,16 +6,16 @@ from gdcdatamodel.models import basic, versioning # noqa -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def bg(): """Fixture for database driver""" cfg = { - 'host': 'localhost', - 'user': 'test', - 'password': 'test', - 'database': 'dev_models', - 'package_namespace': 'basic', + "host": "localhost", + "user": "test", + "password": "test", + "database": "dev_models", + "package_namespace": "basic", } g = PsqlGraphDriver(**cfg) @@ -31,7 +31,10 @@ def create_samples(sample_data, bg): version_2s = [] for node in sample_data: # delay adding version 2 - if node.node_id in ["a2b2d27a-6523-4ddd-8b2e-e94437a2aa23", "5ffb4b0e-969e-4643-8187-536ce7130e9c"]: + if node.node_id in [ + "a2b2d27a-6523-4ddd-8b2e-e94437a2aa23", + "5ffb4b0e-969e-4643-8187-536ce7130e9c", + ]: version_2s.append(node) continue s.add(node) @@ -45,16 +48,51 @@ def create_samples(sample_data, bg): bg.node_delete(n.node_id) -@pytest.mark.parametrize("node_id, tag, version", [ - ("be66197b-f6cc-4366-bded-365856ec4f63", "84044bd2-54a4-5837-b83d-f920eb97c18d", 1), - ("a2b2d27a-6523-4ddd-8b2e-e94437a2aa23", "84044bd2-54a4-5837-b83d-f920eb97c18d", 2), - ("813f97c4-dffc-4f94-b3f6-66a93476a233", "9a81bbad-b525-568c-b85d-d269a8bdc70a", 1), - ("6974c692-be47-4cb8-b8d6-9bd815983cd9", "55814b2f-fc23-5bed-9eab-c73c52c105df", 1), - ("5ffb4b0e-969e-4643-8187-536ce7130e9c", "55814b2f-fc23-5bed-9eab-c73c52c105df", 2), - ("c6a795f6-ee4a-4fcd-bfed-79348e07cd49", "8cc95392-5861-5524-8b98-a85e18d8294c", 1), - ("ed9aa864-1e40-4657-9378-7e3dc26551cc", "fddc5826-8853-5c1a-847d-5850d58ccb3e", 1), - ("fb69d25b-5c5d-4879-8955-8f2126e57524", "293d5dd3-117c-5a0a-8030-a428fdf2681b", 1), -]) +@pytest.mark.parametrize( + "node_id, tag, version", + [ + ( + "be66197b-f6cc-4366-bded-365856ec4f63", + "84044bd2-54a4-5837-b83d-f920eb97c18d", + 1, + ), + ( + "a2b2d27a-6523-4ddd-8b2e-e94437a2aa23", + "84044bd2-54a4-5837-b83d-f920eb97c18d", + 2, + ), + ( + "813f97c4-dffc-4f94-b3f6-66a93476a233", + "9a81bbad-b525-568c-b85d-d269a8bdc70a", + 1, + ), + ( + "6974c692-be47-4cb8-b8d6-9bd815983cd9", + "55814b2f-fc23-5bed-9eab-c73c52c105df", + 1, + ), + ( + "5ffb4b0e-969e-4643-8187-536ce7130e9c", + "55814b2f-fc23-5bed-9eab-c73c52c105df", + 2, + ), + ( + "c6a795f6-ee4a-4fcd-bfed-79348e07cd49", + "8cc95392-5861-5524-8b98-a85e18d8294c", + 1, + ), + ( + "ed9aa864-1e40-4657-9378-7e3dc26551cc", + "fddc5826-8853-5c1a-847d-5850d58ccb3e", + 1, + ), + ( + "fb69d25b-5c5d-4879-8955-8f2126e57524", + "293d5dd3-117c-5a0a-8030-a428fdf2681b", + 1, + ), + ], +) def test_1(create_samples, bg, node_id, tag, version): with bg.session_scope(): diff --git a/test/test_update_case_cache.py b/test/test_update_case_cache.py index 87170de0..0c681e30 100644 --- a/test/test_update_case_cache.py +++ b/test/test_update_case_cache.py @@ -16,22 +16,22 @@ def case_tree(g): """Create tree to test cache cache on""" - case = md.Case('case') + case = md.Case("case") case.samples = [ - md.Sample('sample1'), - 
md.Sample('sample2'), + md.Sample("sample1"), + md.Sample("sample2"), ] case.samples[0].portions = [ - md.Portion('portion1'), - md.Portion('portion2'), + md.Portion("portion1"), + md.Portion("portion2"), ] case.samples[0].portions[0].analytes = [ - md.Analyte('analyte1'), - md.Analyte('analyte2'), + md.Analyte("analyte1"), + md.Analyte("analyte2"), ] case.samples[0].portions[0].analytes[0].aliquots = [ - md.Aliquot('aliquot1'), - md.Aliquot('alituoq2'), + md.Aliquot("aliquot1"), + md.Aliquot("alituoq2"), ] with g.session_scope() as session: session.merge(case) @@ -58,7 +58,7 @@ def test_update_case_cache(g, case_tree_no_cache): with g.session_scope(): nodes = g.nodes().all() - nodes = (node for node in nodes if hasattr(node, '_related_cases')) + nodes = (node for node in nodes if hasattr(node, "_related_cases")) for node in nodes: assert node._related_cases diff --git a/test/test_validators.py b/test/test_validators.py index 6a4c5d79..fd760138 100644 --- a/test/test_validators.py +++ b/test/test_validators.py @@ -25,44 +25,59 @@ def setUp(self): self.entities = [MockSubmissionEntity()] def test_json_validator_with_insufficient_properties(self): - self.entities[0].doc = {'type': 'aliquot', - 'centers': {'submitter_id': 'test'}} + self.entities[0].doc = {"type": "aliquot", "centers": {"submitter_id": "test"}} self.json_validator.record_errors(self.entities) - self.assertEqual(self.entities[0].errors[0]['keys'], ['submitter_id']) + self.assertEqual(self.entities[0].errors[0]["keys"], ["submitter_id"]) self.assertEqual(1, len(self.entities[0].errors)) def test_json_validator_with_wrong_node_type(self): - self.entities[0].doc = {'type': 'aliquo'} + self.entities[0].doc = {"type": "aliquo"} self.json_validator.record_errors(self.entities) - self.assertEqual(self.entities[0].errors[0]['keys'], ['type']) + self.assertEqual(self.entities[0].errors[0]["keys"], ["type"]) self.assertEqual(1, len(self.entities[0].errors)) def test_json_validator_with_wrong_property_type(self): - self.entities[0].doc = {'type': 'aliquot', - 'submitter_id': 1, 'centers': {'submitter_id': 'test'}} + self.entities[0].doc = { + "type": "aliquot", + "submitter_id": 1, + "centers": {"submitter_id": "test"}, + } self.json_validator.record_errors(self.entities) - self.assertEqual(['submitter_id'], self.entities[0].errors[0]['keys']) + self.assertEqual(["submitter_id"], self.entities[0].errors[0]["keys"]) self.assertEqual(1, len(self.entities[0].errors)) def test_json_validator_with_multiple_errors(self): - self.entities[0].doc = {'type': 'aliquot', 'submitter_id': 1, - 'test': 'test', - 'centers': {'submitter_id': 'test'}} + self.entities[0].doc = { + "type": "aliquot", + "submitter_id": 1, + "test": "test", + "centers": {"submitter_id": "test"}, + } self.json_validator.record_errors(self.entities) self.assertEqual(2, len(self.entities[0].errors)) def test_json_validator_with_nested_error_keys(self): - self.entities[0].doc = {'type': 'aliquot', 'submitter_id': 'test', - 'centers': {'submitter_id': True}} + self.entities[0].doc = { + "type": "aliquot", + "submitter_id": "test", + "centers": {"submitter_id": True}, + } self.json_validator.record_errors(self.entities) - self.assertEqual(['centers'], self.entities[0].errors[0]['keys']) + self.assertEqual(["centers"], self.entities[0].errors[0]["keys"]) def test_json_validator_with_multiple_entities(self): - self.entities[0].doc = {'type': 'aliquot', 'submitter_id': 1, 'test': 'test', - 'centers': {'submitter_id': 'test'}} + self.entities[0].doc = { + "type": "aliquot", + 
"submitter_id": 1, + "test": "test", + "centers": {"submitter_id": "test"}, + } entity = MockSubmissionEntity() - entity.doc = {'type': 'aliquot', 'submitter_id': 'test', - 'centers': {'submitter_id': 'test'}} + entity.doc = { + "type": "aliquot", + "submitter_id": "test", + "centers": {"submitter_id": "test"}, + } self.entities.append(entity) self.json_validator.record_errors(self.entities) @@ -71,10 +86,10 @@ def test_json_validator_with_multiple_entities(self): def test_json_validator_with_array_prop(self): entity_doc = { - 'type': 'diagnosis', - 'submitter_id': 'test', - 'age_at_diagnosis': 10, - 'primary_diagnosis': 'Abdominal desmoid', + "type": "diagnosis", + "submitter_id": "test", + "age_at_diagnosis": 10, + "primary_diagnosis": "Abdominal desmoid", "morphology": "8000/0", "tissue_or_organ_of_origin": "Abdomen, NOS", "site_of_resection_or_biopsy": "Abdomen, NOS", @@ -107,10 +122,10 @@ def mock_doc(sites_of_involvement): self.assertIn("diagnosis_is_primary_disease", error_keys_one) def create_node(self, doc, session): - cls = Node.get_subclass(doc['type']) + cls = Node.get_subclass(doc["type"]) node = cls(str(uuid.uuid4())) - node.props = doc['props'] - for key, value in doc['edges'].items(): + node.props = doc["props"] + for key, value in doc["edges"].items(): for target_id in value: edge = self.g.nodes().ids(target_id).first() node[key].append(edge) @@ -122,164 +137,264 @@ def update_schema(self, entity, key, schema): def test_graph_validator_without_required_link(self): with self.g.session_scope() as session: - node = self.create_node({'type': 'aliquot', - 'props': {'submitter_id': 'test'}, - 'edges': {}}, session) + node = self.create_node( + {"type": "aliquot", "props": {"submitter_id": "test"}, "edges": {}}, + session, + ) self.entities[0].node = node self.update_schema( - 'aliquot', - 'links', - [{'name': 'analytes', - 'backref': 'aliquots', - 'label': 'derived_from', - 'multiplicity': 'many_to_one', - 'target_type': 'analyte', - 'required': True}]) + "aliquot", + "links", + [ + { + "name": "analytes", + "backref": "aliquots", + "label": "derived_from", + "multiplicity": "many_to_one", + "target_type": "analyte", + "required": True, + } + ], + ) self.graph_validator.record_errors(self.g, self.entities) - self.assertEqual(['analytes'], self.entities[0].errors[0]['keys']) + self.assertEqual(["analytes"], self.entities[0].errors[0]["keys"]) def test_graph_validator_with_exclusive_link(self): with self.g.session_scope() as session: analyte = self.create_node( - {'type': 'analyte', - 'props': {'submitter_id': 'test', - 'analyte_type_id': 'D', - 'analyte_type': 'DNA'}, - 'edges': {}}, session) - sample = self.create_node({'type': 'sample', - 'props': {'submitter_id': 'test', - 'sample_type': 'DNA', - 'sample_type_id': '01'}, - 'edges': {}}, session) + { + "type": "analyte", + "props": { + "submitter_id": "test", + "analyte_type_id": "D", + "analyte_type": "DNA", + }, + "edges": {}, + }, + session, + ) + sample = self.create_node( + { + "type": "sample", + "props": { + "submitter_id": "test", + "sample_type": "DNA", + "sample_type_id": "01", + }, + "edges": {}, + }, + session, + ) node = self.create_node( - {'type': 'aliquot', - 'props': {'submitter_id': 'test'}, - 'edges': {'analytes': [analyte.node_id], - 'samples': [sample.node_id]}}, session) + { + "type": "aliquot", + "props": {"submitter_id": "test"}, + "edges": { + "analytes": [analyte.node_id], + "samples": [sample.node_id], + }, + }, + session, + ) self.entities[0].node = node self.update_schema( - 'aliquot', - 'links', - 
[{'exclusive': True, - 'required': True, - 'subgroup': [ - {'name': 'analytes', - 'backref': 'aliquots', - 'label': 'derived_from', - 'multiplicity': 'many_to_one', - 'target_type': 'analyte'}, - {'name': 'samples', - 'backref': 'aliquots', - 'label': 'derived_from', - 'multiplicity': 'many_to_one', - 'target_type': 'sample'}]}]) + "aliquot", + "links", + [ + { + "exclusive": True, + "required": True, + "subgroup": [ + { + "name": "analytes", + "backref": "aliquots", + "label": "derived_from", + "multiplicity": "many_to_one", + "target_type": "analyte", + }, + { + "name": "samples", + "backref": "aliquots", + "label": "derived_from", + "multiplicity": "many_to_one", + "target_type": "sample", + }, + ], + } + ], + ) self.graph_validator.record_errors(self.g, self.entities) - self.assertEqual(['analytes', 'samples'], - self.entities[0].errors[0]['keys']) + self.assertEqual( + ["analytes", "samples"], self.entities[0].errors[0]["keys"] + ) def test_graph_validator_with_wrong_multiplicity(self): with self.g.session_scope() as session: - analyte = self.create_node({'type': 'analyte', - 'props': {'submitter_id': 'test', - 'analyte_type_id': 'D', - 'analyte_type': 'DNA'}, - 'edges': {}}, session) - - analyte_b = self.create_node({'type': 'analyte', - 'props': {'submitter_id': 'testb', - 'analyte_type_id': 'H', - 'analyte_type': 'RNA'}, - 'edges': {}}, session) - - node = self.create_node({'type': 'aliquot', - 'props': {'submitter_id': 'test'}, - 'edges': {'analytes': [analyte.node_id, - analyte_b.node_id]}}, - session) + analyte = self.create_node( + { + "type": "analyte", + "props": { + "submitter_id": "test", + "analyte_type_id": "D", + "analyte_type": "DNA", + }, + "edges": {}, + }, + session, + ) + + analyte_b = self.create_node( + { + "type": "analyte", + "props": { + "submitter_id": "testb", + "analyte_type_id": "H", + "analyte_type": "RNA", + }, + "edges": {}, + }, + session, + ) + + node = self.create_node( + { + "type": "aliquot", + "props": {"submitter_id": "test"}, + "edges": {"analytes": [analyte.node_id, analyte_b.node_id]}, + }, + session, + ) self.entities[0].node = node self.update_schema( - 'aliquot', - 'links', - [{'exclusive': False, - 'required': True, - 'subgroup': [ - {'name': 'analytes', - 'backref': 'aliquots', - 'label': 'derived_from', - 'multiplicity': 'many_to_one', - 'target_type': 'analyte'}, - {'name': 'samples', - 'backref': 'aliquots', - 'label': 'derived_from', - 'multiplicity': 'many_to_one', - 'target_type': 'sample'}]}]) + "aliquot", + "links", + [ + { + "exclusive": False, + "required": True, + "subgroup": [ + { + "name": "analytes", + "backref": "aliquots", + "label": "derived_from", + "multiplicity": "many_to_one", + "target_type": "analyte", + }, + { + "name": "samples", + "backref": "aliquots", + "label": "derived_from", + "multiplicity": "many_to_one", + "target_type": "sample", + }, + ], + } + ], + ) self.graph_validator.record_errors(self.g, self.entities) - self.assertEqual(['analytes'], self.entities[0].errors[0]['keys']) + self.assertEqual(["analytes"], self.entities[0].errors[0]["keys"]) def test_graph_validator_with_correct_node(self): with self.g.session_scope() as session: - analyte = self.create_node({'type': 'analyte', - 'props': {'submitter_id': 'test', - 'analyte_type_id': 'D', - 'analyte_type': 'DNA'}, - 'edges': {}}, session) - - node = self.create_node({'type': 'aliquot', - 'props': {'submitter_id': 'test'}, - 'edges': {'analytes': [analyte.node_id]}}, - session) + analyte = self.create_node( + { + "type": "analyte", + "props": { + 
"submitter_id": "test", + "analyte_type_id": "D", + "analyte_type": "DNA", + }, + "edges": {}, + }, + session, + ) + + node = self.create_node( + { + "type": "aliquot", + "props": {"submitter_id": "test"}, + "edges": {"analytes": [analyte.node_id]}, + }, + session, + ) self.entities[0].node = node self.update_schema( - 'aliquot', - 'links', - [{'exclusive': False, - 'required': True, - 'subgroup': [ - {'name': 'analytes', - 'backref': 'aliquots', - 'label': 'derived_from', - 'multiplicity': 'many_to_one', - 'target_type': 'analyte'}, - {'name': 'samples', - 'backref': 'aliquots', - 'label': 'derived_from', - 'multiplicity': 'many_to_one', - 'target_type': 'sample'}]}]) + "aliquot", + "links", + [ + { + "exclusive": False, + "required": True, + "subgroup": [ + { + "name": "analytes", + "backref": "aliquots", + "label": "derived_from", + "multiplicity": "many_to_one", + "target_type": "analyte", + }, + { + "name": "samples", + "backref": "aliquots", + "label": "derived_from", + "multiplicity": "many_to_one", + "target_type": "sample", + }, + ], + } + ], + ) self.graph_validator.record_errors(self.g, self.entities) self.assertEqual(0, len(self.entities[0].errors)) def test_graph_validator_with_existing_unique_keys(self): with self.g.session_scope() as session: - node = self.create_node({'type': 'data_format', - 'props': {'name': 'test'}, - 'edges': {}}, - session) - node = self.create_node({'type': 'data_format', - 'props': {'name': 'test'}, - 'edges': {}}, - session) - self.update_schema('data_format', 'uniqueKeys', [['name']]) + node = self.create_node( + {"type": "data_format", "props": {"name": "test"}, "edges": {}}, session + ) + node = self.create_node( + {"type": "data_format", "props": {"name": "test"}, "edges": {}}, session + ) + self.update_schema("data_format", "uniqueKeys", [["name"]]) self.entities[0].node = node self.graph_validator.record_errors(self.g, self.entities) - self.assertEqual(['name'], self.entities[0].errors[0]['keys']) + self.assertEqual(["name"], self.entities[0].errors[0]["keys"]) def test_graph_validator_with_existing_unique_keys_for_different_node_types(self): with self.g.session_scope() as session: - node = self.create_node({'type': 'sample', - 'props': {'submitter_id': 'test','project_id':'A'}, - 'edges': {}}, - session) - node = self.create_node({'type': 'aliquot', - 'props': {'submitter_id': 'test', 'project_id':'A'}, - 'edges': {}}, - session) - self.update_schema('data_format', 'uniqueKeys', [['submitter_id', 'project_id']]) + node = self.create_node( + { + "type": "sample", + "props": {"submitter_id": "test", "project_id": "A"}, + "edges": {}, + }, + session, + ) + node = self.create_node( + { + "type": "aliquot", + "props": {"submitter_id": "test", "project_id": "A"}, + "edges": {}, + }, + session, + ) + self.update_schema( + "data_format", "uniqueKeys", [["submitter_id", "project_id"]] + ) self.entities[0].node = node self.graph_validator.record_errors(self.g, self.entities) # Check (project_id, submitter_id) uniqueness is captured - self.assertTrue(any({'project_id', 'submitter_id'} == set(e['keys']) - for e in self.entities[0].errors)) + self.assertTrue( + any( + {"project_id", "submitter_id"} == set(e["keys"]) + for e in self.entities[0].errors + ) + ) # Check that missing edges is captured - self.assertTrue(any({'analytes', 'samples'} == set(e['keys']) - for e in self.entities[0].errors)) + self.assertTrue( + any( + {"analytes", "samples"} == set(e["keys"]) + for e in self.entities[0].errors + ) + ) diff --git a/test/test_versioned_nodes.py 
b/test/test_versioned_nodes.py index 82f82856..d0256e38 100644 --- a/test/test_versioned_nodes.py +++ b/test/test_versioned_nodes.py @@ -4,32 +4,35 @@ class TestValidators(BaseTestCase): - @staticmethod def new_portion(): - portion = md.Portion(**{ - 'node_id': 'case1', - 'is_ffpe': False, - 'portion_number': '01', - 'project_id': 'CGCI-BLGSP', - 'state': 'validated', - 'submitter_id': 'PORTION-1', - 'weight': 54.0 - }) - portion.acl = ['acl1'] - portion.sysan.update({'key1': 'val1'}) + portion = md.Portion( + **{ + "node_id": "case1", + "is_ffpe": False, + "portion_number": "01", + "project_id": "CGCI-BLGSP", + "state": "validated", + "submitter_id": "PORTION-1", + "weight": 54.0, + } + ) + portion.acl = ["acl1"] + portion.sysan.update({"key1": "val1"}) return portion @staticmethod def new_analyte(): - return md.Analyte(**{ - 'node_id': 'analyte1', - 'analyte_type': 'Repli-G (Qiagen) DNA', - 'analyte_type_id': 'W', - 'project_id': 'CGCI-BLGSP', - 'state': 'validated', - 'submitter_id': 'TCGA-AR-A1AR-01A-31W', - }) + return md.Analyte( + **{ + "node_id": "analyte1", + "analyte_type": "Repli-G (Qiagen) DNA", + "analyte_type_id": "W", + "project_id": "CGCI-BLGSP", + "state": "validated", + "submitter_id": "TCGA-AR-A1AR-01A-31W", + } + ) def test_round_trip(self): with self.g.session_scope() as session: @@ -46,12 +49,12 @@ def test_round_trip(self): with self.g.session_scope(): v_node = self.g.nodes(md.VersionedNode).one() - self.assertEqual(v_node.properties['is_ffpe'], False) - self.assertEqual(v_node.properties['state'], 'validated') - self.assertEqual(v_node.properties['state'], 'validated') - self.assertEqual(v_node.system_annotations, {'key1': 'val1'}) - self.assertEqual(v_node.acl, ['acl1']) - self.assertEqual(v_node.neighbors, ['analyte1']) + self.assertEqual(v_node.properties["is_ffpe"], False) + self.assertEqual(v_node.properties["state"], "validated") + self.assertEqual(v_node.properties["state"], "validated") + self.assertEqual(v_node.system_annotations, {"key1": "val1"}) + self.assertEqual(v_node.acl, ["acl1"]) + self.assertEqual(v_node.neighbors, ["analyte1"]) self.assertIsNotNone(v_node.versioned) self.assertIsNotNone(v_node.key)
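A reformat-only change like this one can be verified mechanically: rerunning black over the reformatted tree should be a no-op. Below is a minimal sketch of such a check, assuming the pinned black revision is importable; the file path is only an example, and the same idea applies to every file touched by this patch.

    # Sketch: assert a file is already black-clean, i.e. the reformat is a
    # fixed point. black.format_str returns its input unchanged in that case.
    import black

    path = "test/conftest.py"  # illustrative; repeat for any reformatted file
    with open(path) as f:
        src = f.read()

    assert black.format_str(src, mode=black.FileMode()) == src, "not black-clean"

The equivalent end-to-end check is to run the hooks themselves, e.g. pre-commit run --all-files, which should exit clean immediately after this commit.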