From 3f1058ff498823eff008395544aa9d1b9c55a1af Mon Sep 17 00:00:00 2001
From: Qiao Qiao <68757394+qiaouchicago@users.noreply.github.com>
Date: Mon, 17 Oct 2022 09:42:07 -0500
Subject: [PATCH] DEV-1483: add black and reformat (#389)

Reformat all files using black.
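
To reproduce this formatting locally, one option (a sketch, assuming
pre-commit 1.21.0+ is installed with the black hook below) is:

    pre-commit run black --all-files

or, with the pinned versions installed directly:

    pip install black==20.8b1 click==8.0.4
    black .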
---
.pre-commit-config.yaml | 5 +
.secrets.baseline | 4 +-
bin/update_related_case_caches.py | 40 +-
docs/bin/schemata_to_graphviz.py | 24 +-
gdcdatamodel/__main__.py | 51 ++-
gdcdatamodel/gdc_postgres_admin.py | 142 ++++---
gdcdatamodel/models/__init__.py | 4 +-
gdcdatamodel/models/caching.py | 71 ++--
gdcdatamodel/models/indexes.py | 27 +-
gdcdatamodel/models/utils.py | 4 +-
gdcdatamodel/models/versioned_nodes.py | 30 +-
gdcdatamodel/models/versioning.py | 6 +-
gdcdatamodel/query.py | 65 ++-
gdcdatamodel/validators/graph_validators.py | 111 +++---
gdcdatamodel/validators/json_validators.py | 28 +-
gdcdatamodel/viz/__init__.py | 10 +-
migrations/async_transactions.py | 16 +-
migrations/index_secondary_keys.py | 13 +-
migrations/notifications.py | 4 +-
migrations/set_null_edge_columns.py | 9 +-
migrations/update_case_cache_append_only.py | 33 +-
migrations/update_legacy_states.py | 90 ++---
test/conftest.py | 40 +-
test/helpers.py | 4 +-
test/models.py | 1 -
test/test_admin_script.py | 97 +++--
test/test_cache_related_cases.py | 109 +++---
test/test_datamodel.py | 68 ++--
test/test_dictionary_loadiing.py | 28 +-
test/test_gdc_postgres_admin.py | 164 +++++---
test/test_indexes.py | 8 +-
test/test_node_tagging.py | 72 +++-
test/test_update_case_cache.py | 20 +-
test/test_validators.py | 413 +++++++++++++-------
test/test_versioned_nodes.py | 55 +--
35 files changed, 1083 insertions(+), 783 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7fd0f026..f5177be2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,6 +20,11 @@ repos:
args: [--autofix]
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
+ - repo: https://github.com/psf/black
+      rev: 20.8b1 # for pre-commit 1.21.0 in Jenkins
+ hooks:
+ - id: black
+ additional_dependencies: [click==8.0.4]
- repo: https://github.com/pycqa/isort
rev: 5.6.4 # last version to support pre-commit 1.21.0 in Jenkins
hooks:
diff --git a/.secrets.baseline b/.secrets.baseline
index 9e872771..e18f4983 100644
--- a/.secrets.baseline
+++ b/.secrets.baseline
@@ -3,7 +3,7 @@
"files": "^.secrets.baseline$",
"lines": null
},
- "generated_at": "2022-10-11T18:55:31Z",
+ "generated_at": "2022-10-14T21:28:31Z",
"plugins_used": [
{
"name": "AWSKeyDetector"
@@ -69,7 +69,7 @@
"hashed_secret": "5d0fa74acf95d1d6bebd0d37f76a94e77d604fd9",
"is_secret": false,
"is_verified": false,
- "line_number": 33,
+ "line_number": 36,
"type": "Basic Auth Credentials"
}
]
diff --git a/bin/update_related_case_caches.py b/bin/update_related_case_caches.py
index 9cd660d9..52ea94c6 100644
--- a/bin/update_related_case_caches.py
+++ b/bin/update_related_case_caches.py
@@ -25,8 +25,11 @@ def recursive_update_related_case_caches(node, case, visited_ids=set()):
"""
- logger.info("{}: | case: {} | project: {}".format(
- node, case, node._props.get('project_id', '?')))
+ logger.info(
+ "{}: | case: {} | project: {}".format(
+ node, case, node._props.get("project_id", "?")
+ )
+ )
visited_ids.add(node.node_id)
@@ -34,10 +37,10 @@ def recursive_update_related_case_caches(node, case, visited_ids=set()):
if edge.src is None:
continue
- if edge.__class__.__name__.endswith('RelatesToCase'):
+ if edge.__class__.__name__.endswith("RelatesToCase"):
continue
- if not hasattr(edge.src, '_related_cases'):
+ if not hasattr(edge.src, "_related_cases"):
continue
original = set(edge.src._related_cases)
@@ -61,14 +64,23 @@ def update_project_related_case_cache(project):
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("-H", "--host", type=str, action="store",
- required=True, help="psql-server host")
- parser.add_argument("-U", "--user", type=str, action="store",
- required=True, help="psql test user")
- parser.add_argument("-D", "--database", type=str, action="store",
- required=True, help="psql test database")
- parser.add_argument("-P", "--password", type=str, action="store",
- help="psql test password")
+ parser.add_argument(
+ "-H", "--host", type=str, action="store", required=True, help="psql-server host"
+ )
+ parser.add_argument(
+ "-U", "--user", type=str, action="store", required=True, help="psql test user"
+ )
+ parser.add_argument(
+ "-D",
+ "--database",
+ type=str,
+ action="store",
+ required=True,
+ help="psql test database",
+ )
+ parser.add_argument(
+ "-P", "--password", type=str, action="store", help="psql test password"
+ )
args = parser.parse_args()
prompt = "Password for {}:".format(args.user)
@@ -76,12 +88,12 @@ def main():
g = PsqlGraphDriver(args.host, args.user, password, args.database)
with g.session_scope():
- projects = g.nodes(md.Project).not_props(state='legacy').all()
+ projects = g.nodes(md.Project).not_props(state="legacy").all()
for p in projects:
update_project_related_case_cache(p)
print("Done.")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/docs/bin/schemata_to_graphviz.py b/docs/bin/schemata_to_graphviz.py
index e5bb2407..f6866000 100644
--- a/docs/bin/schemata_to_graphviz.py
+++ b/docs/bin/schemata_to_graphviz.py
@@ -6,19 +6,21 @@
def build_visualization():
- print('Building schema documentation...')
+ print ("Building schema documentation...")
# Load directory tree info
bin_dir = os.path.dirname(os.path.realpath(__file__))
- root_dir = os.path.join(os.path.abspath(
- os.path.join(bin_dir, os.pardir, os.pardir)))
+ root_dir = os.path.join(
+ os.path.abspath(os.path.join(bin_dir, os.pardir, os.pardir))
+ )
# Create graph
dot = Digraph(
- comment="High level graph representation of GDC data model", format='pdf')
- dot.graph_attr['rankdir'] = 'RL'
- dot.node_attr['fillcolor'] = 'lightblue'
- dot.node_attr['style'] = 'filled'
+ comment="High level graph representation of GDC data model", format="pdf"
+ )
+ dot.graph_attr["rankdir"] = "RL"
+ dot.node_attr["fillcolor"] = "lightblue"
+ dot.node_attr["style"] = "filled"
# Add nodes
for node in m.Node.get_subclasses():
@@ -28,7 +30,7 @@ def build_visualization():
# Add edges
for edge in m.Edge.get_subclasses():
- if edge.__dst_class__ == 'Case' and edge.label == 'relates_to':
+ if edge.__dst_class__ == "Case" and edge.label == "relates_to":
# Skip case cache edges
continue
@@ -36,10 +38,10 @@ def build_visualization():
dst = m.Node.get_subclass_named(edge.__dst_class__)
dot.edge(src.get_label(), dst.get_label(), edge.get_label())
- gv_path = os.path.join(root_dir, 'docs', 'viz', 'gdc_data_model.gv')
+ gv_path = os.path.join(root_dir, "docs", "viz", "gdc_data_model.gv")
dot.render(gv_path)
- print('graphviz output to {}'.format(gv_path))
+ print ("graphviz output to {}".format(gv_path))
-if __name__ == '__main__':
+if __name__ == "__main__":
build_visualization()
diff --git a/gdcdatamodel/__main__.py b/gdcdatamodel/__main__.py
index f3a9e387..7e1c4011 100644
--- a/gdcdatamodel/__main__.py
+++ b/gdcdatamodel/__main__.py
@@ -1,7 +1,6 @@
import argparse
import getpass
-import psqlgraph
from models import * # noqa
from models.versioned_nodes import VersionedNode # noqa
from psqlgraph import * # noqa
@@ -9,12 +8,18 @@
try:
import IPython
+
ipython = True
except Exception as e:
- print(('{}, using standard interactive console. '
- 'If you install IPython, then it will automatically '
- 'be used for this repl.').format(e))
+ print(
+ (
+ "{}, using standard interactive console. "
+ "If you install IPython, then it will automatically "
+ "be used for this repl."
+ ).format(e)
+ )
import code
+
ipython = False
@@ -31,17 +36,32 @@
"""
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('-d', '--database', default='test', type=str,
- help='name of the database to connect to')
- parser.add_argument('-i', '--host', default='localhost', type=str,
- help='host of the postgres server')
- parser.add_argument('-u', '--user', default='test', type=str,
- help='user to connect to postgres as')
- parser.add_argument('-p', '--password', default=None, type=str,
- help='password for given user. If no '
- 'password given, one will be prompted.')
+ parser.add_argument(
+ "-d",
+ "--database",
+ default="test",
+ type=str,
+ help="name of the database to connect to",
+ )
+ parser.add_argument(
+ "-i",
+ "--host",
+ default="localhost",
+ type=str,
+ help="host of the postgres server",
+ )
+ parser.add_argument(
+ "-u", "--user", default="test", type=str, help="user to connect to postgres as"
+ )
+ parser.add_argument(
+ "-p",
+ "--password",
+ default=None,
+ type=str,
+ help="password for given user. If no " "password given, one will be prompted.",
+ )
args = parser.parse_args()
@@ -49,8 +69,7 @@
if args.password is None:
args.password = getpass.getpass()
- g = psqlgraph.PsqlGraphDriver(
- args.host, args.user, args.password, args.database)
+ g = psqlgraph.PsqlGraphDriver(args.host, args.user, args.password, args.database)
with g.session_scope() as s:
rb = s.rollback
diff --git a/gdcdatamodel/gdc_postgres_admin.py b/gdcdatamodel/gdc_postgres_admin.py
index 21f40ad8..71216d4c 100644
--- a/gdcdatamodel/gdc_postgres_admin.py
+++ b/gdcdatamodel/gdc_postgres_admin.py
@@ -75,7 +75,7 @@ def execute_for_all_graph_tables(engine, sql, namespace=None, **kwargs):
edge_cls = ext.get_abstract_edge(namespace)
for cls in node_cls.get_subclasses() + edge_cls.get_subclasses():
- _kwargs = dict(kwargs, **{'table': cls.__tablename__})
+ _kwargs = dict(kwargs, **{"table": cls.__tablename__})
statement = sql.format(**_kwargs)
execute(engine, statement)
@@ -100,13 +100,13 @@ def create_graph_tables(engine, timeout, namespace=None):
"""
create a table
"""
- logger.info('Creating tables (timeout: %d)', timeout)
+ logger.info("Creating tables (timeout: %d)", timeout)
connection = engine.connect()
trans = connection.begin()
logger.info("Setting lock_timeout to %d", timeout)
- timeout_str = '{}s'.format(int(timeout+1))
+ timeout_str = "{}s".format(int(timeout + 1))
connection.execute("SET LOCAL lock_timeout = %s;", timeout_str)
orm_base = ext.get_orm_base(namespace) if namespace else ORMBase
@@ -123,25 +123,25 @@ def create_tables(engine, delay, retries, namespace=None):
"""
- logger.info('Running table creator named %s', app_name)
+ logger.info("Running table creator named %s", app_name)
try:
return create_graph_tables(engine, delay, namespace=namespace)
except OperationalError as e:
- if 'timeout' in str(e):
- logger.warning('Attempt timed out')
+ if "timeout" in str(e):
+ logger.warning("Attempt timed out")
else:
raise
if retries <= 0:
- raise RuntimeError('Max retries exceeded')
+ raise RuntimeError("Max retries exceeded")
logger.info(
- 'Trying again in {} seconds ({} retries remaining)'
- .format(delay, retries))
+ "Trying again in {} seconds ({} retries remaining)".format(delay, retries)
+ )
time.sleep(delay)
- create_tables(engine, delay, retries-1, namespace=namespace)
+ create_tables(engine, delay, retries - 1, namespace=namespace)
def subcommand_create(args):
@@ -153,10 +153,7 @@ def subcommand_create(args):
logger.info("Running subcommand 'create'")
engine = get_engine(args.host, args.user, args.password, args.database)
kwargs = dict(
- engine=engine,
- delay=args.delay,
- retries=args.retries,
- namespace=args.namespace
+ engine=engine, delay=args.delay, retries=args.retries, namespace=args.namespace
)
return create_tables(**kwargs)
@@ -172,15 +169,15 @@ def subcommand_grant(args):
logger.info("Running subcommand 'grant'")
engine = get_engine(args.host, args.user, args.password, args.database)
- assert args.read or args.write, 'No premission types/users specified.'
+    assert args.read or args.write, "No permission types/users specified."
if args.read:
- users_read = [u for u in args.read.split(',') if u]
+ users_read = [u for u in args.read.split(",") if u]
for user in users_read:
grant_read_permissions_to_graph(engine, user, args.namespace)
if args.write:
- users_write = [u for u in args.write.split(',') if u]
+ users_write = [u for u in args.write.split(",") if u]
for user in users_write:
grant_write_permissions_to_graph(engine, user, args.namespace)
@@ -196,73 +193,104 @@ def subcommand_revoke(args):
engine = get_engine(args.host, args.user, args.password, args.database)
if args.read:
- users_read = [u for u in args.read.split(',') if u]
+ users_read = [u for u in args.read.split(",") if u]
for user in users_read:
revoke_read_permissions_to_graph(engine, user, args.namespace)
if args.write:
- users_write = [u for u in args.write.split(',') if u]
+ users_write = [u for u in args.write.split(",") if u]
for user in users_write:
revoke_write_permissions_to_graph(engine, user, args.namespace)
def add_base_args(subparser):
- subparser.add_argument("-H", "--host", type=str, action="store",
- required=True, help="psql-server host")
- subparser.add_argument("-U", "--user", type=str, action="store",
- required=True, help="psql test user")
- subparser.add_argument("-D", "--database", type=str, action="store",
- required=True, help="psql test database")
- subparser.add_argument("-P", "--password", type=str, action="store",
- default='', help="psql test password")
- subparser.add_argument("-N", "--namespace", type=lambda x: x if x else None,
- help="psqlgraph model namespace")
+ subparser.add_argument(
+ "-H", "--host", type=str, action="store", required=True, help="psql-server host"
+ )
+ subparser.add_argument(
+ "-U", "--user", type=str, action="store", required=True, help="psql test user"
+ )
+ subparser.add_argument(
+ "-D",
+ "--database",
+ type=str,
+ action="store",
+ required=True,
+ help="psql test database",
+ )
+ subparser.add_argument(
+ "-P",
+ "--password",
+ type=str,
+ action="store",
+ default="",
+ help="psql test password",
+ )
+ subparser.add_argument(
+ "-N",
+ "--namespace",
+ type=lambda x: x if x else None,
+ help="psqlgraph model namespace",
+ )
return subparser
def add_subcommand_create(subparsers):
- parser = add_base_args(subparsers.add_parser(
- 'graph-create',
- help=subcommand_create.__doc__
- ))
+ parser = add_base_args(
+ subparsers.add_parser("graph-create", help=subcommand_create.__doc__)
+ )
parser.add_argument(
- "--delay", type=int, action="store", default=60,
- help="How many seconds to wait for blocking processes to finish before retrying."
+ "--delay",
+ type=int,
+ action="store",
+ default=60,
+ help="How many seconds to wait for blocking processes to finish before retrying.",
)
parser.add_argument(
- "--retries", type=int, action="store", default=10,
- help="If blocked by important process, how many times to retry after waiting `delay` seconds."
+ "--retries",
+ type=int,
+ action="store",
+ default=10,
+ help="If blocked by important process, how many times to retry after waiting `delay` seconds.",
)
def add_subcommand_grant(subparsers):
- parser = add_base_args(subparsers.add_parser(
- 'graph-grant',
- help=subcommand_grant.__doc__
- ))
+ parser = add_base_args(
+ subparsers.add_parser("graph-grant", help=subcommand_grant.__doc__)
+ )
parser.add_argument(
- "--read", type=str, action="store",
- help="Users to grant read access to (comma separated)."
+ "--read",
+ type=str,
+ action="store",
+ help="Users to grant read access to (comma separated).",
)
parser.add_argument(
- "--write", type=str, action="store",
- help="Users to grant read/write access to (comma separated)."
+ "--write",
+ type=str,
+ action="store",
+ help="Users to grant read/write access to (comma separated).",
)
def add_subcommand_revoke(subparsers):
- parser = add_base_args(subparsers.add_parser(
- 'graph-revoke',
- help=subcommand_revoke.__doc__
- ))
+ parser = add_base_args(
+ subparsers.add_parser("graph-revoke", help=subcommand_revoke.__doc__)
+ )
parser.add_argument(
- "--read", type=str, action="store",
- help="Users to revoke read access from (comma separated)."
+ "--read",
+ type=str,
+ action="store",
+ help="Users to revoke read access from (comma separated).",
)
parser.add_argument(
- "--write", type=str, action="store",
- help=("Users to revoke write access from (comma separated). "
- "NOTE: The user will still have read privs!!")
+ "--write",
+ type=str,
+ action="store",
+ help=(
+ "Users to revoke write access from (comma separated). "
+ "NOTE: The user will still have read privs!!"
+ ),
)
@@ -284,9 +312,9 @@ def main(args=None):
logger.info("[ NAMESPACE : %-10s ]", args.namespace or "default")
return_value = {
- 'graph-create': subcommand_create,
- 'graph-grant': subcommand_grant,
- 'graph-revoke': subcommand_revoke,
+ "graph-create": subcommand_create,
+ "graph-grant": subcommand_grant,
+ "graph-revoke": subcommand_revoke,
}[args.subcommand](args)
logger.info("Done.")
diff --git a/gdcdatamodel/models/__init__.py b/gdcdatamodel/models/__init__.py
index 10aa9e03..e0be12ec 100644
--- a/gdcdatamodel/models/__init__.py
+++ b/gdcdatamodel/models/__init__.py
@@ -1,4 +1,3 @@
-
"""gdcdatamodel.models
----------------------------------
@@ -738,7 +737,7 @@ def load_edges(dictionary, node_cls=Node, edge_cls=Edge, package_namespace=None)
src_cls._pg_links[link["name"]] = {
"edge_out": edge_name,
"dst_type": node_cls.get_subclass(link["target_type"]),
- "backref": link["backref"]
+ "backref": link["backref"],
}
for src_cls in node_cls.get_subclasses():
@@ -795,6 +794,7 @@ def inject_pg_edges(node_cls):
    { <link name>: {'backref': <backref name>, 'type': <source type>} }
"""
+
def cls_inject_forward_edges(cls):
"""We should have already added the links that go OUT from this class,
so let's add them to `_pg_edges`
diff --git a/gdcdatamodel/models/caching.py b/gdcdatamodel/models/caching.py
index 6c3308c9..92b13e84 100644
--- a/gdcdatamodel/models/caching.py
+++ b/gdcdatamodel/models/caching.py
@@ -16,17 +16,17 @@
import logging
-logger = logging.getLogger('gdcdatamodel')
+logger = logging.getLogger("gdcdatamodel")
#: This variable contains the link name for the case shortcut
#: association proxy
-RELATED_CASES_LINK_NAME = '_related_cases'
+RELATED_CASES_LINK_NAME = "_related_cases"
#: This variable specifies the categories for which we won't create
# short cut : edges to case
NOT_RELATED_CASES_CATEGORIES = {
- 'administrative',
- 'TBD',
+ "administrative",
+ "TBD",
}
@@ -53,7 +53,7 @@ def get_related_case_edge_cls_name(node):
"""
- return '{}RelatesToCase'.format(node.__class__.__name__)
+ return "{}RelatesToCase".format(node.__class__.__name__)
def get_edge_src(edge):
@@ -69,18 +69,19 @@ def get_edge_src(edge):
src = edge.src
elif edge.src_id is not None:
src_class = node_cls.get_subclass_named(edge.__src_class__)
- src = (edge.get_session().query(src_class)
- .filter(src_class.node_id == edge.src_id)
- .first())
+ src = (
+ edge.get_session()
+ .query(src_class)
+ .filter(src_class.node_id == edge.src_id)
+ .first()
+ )
else:
src = None
return src
def get_edge_dst(edge, allow_query=False):
- """Return the edge's destination or None.
-
- """
+ """Return the edge's destination or None."""
node_cls = edge.get_node_class()
if edge.dst:
@@ -88,9 +89,12 @@ def get_edge_dst(edge, allow_query=False):
elif edge.dst_id is not None and allow_query:
dst_class = node_cls.get_subclass_named(edge.__dst_class__)
- dst = (edge.get_session().query(dst_class)
- .filter(dst_class.node_id == edge.dst_id)
- .first())
+ dst = (
+ edge.get_session()
+ .query(dst_class)
+ .filter(dst_class.node_id == edge.dst_id)
+ .first()
+ )
else:
dst = None
@@ -120,9 +124,7 @@ def related_cases_from_parents(node):
"""
- skip_edges_named = [
- get_related_case_edge_cls_name(node)
- ]
+ skip_edges_named = [get_related_case_edge_cls_name(node)]
# Make sure the edges haven't been expunged
edges_out = [e for e in node.edges_out if e in node.get_session()]
@@ -142,17 +144,15 @@ def related_cases_from_parents(node):
continue
node_cls = edge.get_node_class()
dst_class = node_cls.get_subclass_named(edge.__dst_class__)
- if dst_class.label == 'case' and edge.dst:
+ if dst_class.label == "case" and edge.dst:
cases.add(edge.dst)
return list(filter(None, cases))
-def cache_related_cases_recursive(node,
- session,
- flush_context,
- instances,
- visited_nodes=None):
+def cache_related_cases_recursive(
+ node, session, flush_context, instances, visited_nodes=None
+):
"""Update the related case cache on source node and its children
    recursively iff this update changes the related case source
node's shortcut edges.
@@ -219,21 +219,17 @@ def update_cache_edges(node, session, correct_cases):
# Get information about the existing edges
edge_name = get_related_case_edge_cls_name(node)
- existing_edges = getattr(node, '_{}_out'.format(edge_name))
+ existing_edges = getattr(node, "_{}_out".format(edge_name))
# Remove edges that should no longer exist
cases_disconnected = [
- edge.dst
- for edge in existing_edges
- if edge.dst_id not in correct_cases
+ edge.dst for edge in existing_edges if edge.dst_id not in correct_cases
]
for case in cases_disconnected:
assoc_proxy.remove(case)
- existing_edge_dst_case_ids = {
- edge.dst_id for edge in existing_edges
- }
+ existing_edge_dst_case_ids = {edge.dst_id for edge in existing_edges}
cases_connected = [
case
@@ -245,10 +241,7 @@ def update_cache_edges(node, session, correct_cases):
assoc_proxy.append(case)
-def cache_related_cases_on_insert(target,
- session,
- flush_context,
- instances):
+def cache_related_cases_on_insert(target, session, flush_context, instances):
"""Hook on updated edges. Update the related case cache on source
    node and its children iff this update changes the related case
source node's cache.
@@ -281,10 +274,7 @@ def cache_related_cases_on_insert(target,
)
-def cache_related_cases_on_update(target,
- session,
- flush_context,
- instances):
+def cache_related_cases_on_update(target, session, flush_context, instances):
"""Hook on deleted edges. Update the related case cache on source
node and its children.
@@ -309,10 +299,7 @@ def cache_related_cases_on_update(target,
)
-def cache_related_cases_on_delete(target,
- session,
- flush_context,
- instances):
+def cache_related_cases_on_delete(target, session, flush_context, instances):
"""Hook on deleted edges. Update the related case cache on source
node and its children.
diff --git a/gdcdatamodel/models/indexes.py b/gdcdatamodel/models/indexes.py
index eb500cba..9b633445 100644
--- a/gdcdatamodel/models/indexes.py
+++ b/gdcdatamodel/models/indexes.py
@@ -1,4 +1,3 @@
-
"""gdcdatamodel.models.indexes
----------------------------------
@@ -33,20 +32,20 @@ def index_name(cls, description):
"""
- name = 'index_{}_{}'.format(cls.__tablename__, description)
+ name = "index_{}_{}".format(cls.__tablename__, description)
# If the name is too long, prepend it with the first 8 hex of it's hash
# truncate the each part of the name
if len(name) > 40:
-        oldname = index_name
+        oldname = name
- logger.debug('Edge tablename {} too long, shortening'.format(oldname))
- name = 'index_{}_{}_{}'.format(
+ logger.debug("Edge tablename {} too long, shortening".format(oldname))
+ name = "index_{}_{}_{}".format(
hashlib.md5(py3_to_bytes(cls.__tablename__)).hexdigest()[:8],
- ''.join([a[:4] for a in cls.get_label().split('_')])[:20],
- '_'.join([a[:8] for a in description.split('_')])[:25],
+ "".join([a[:4] for a in cls.get_label().split("_")])[:20],
+ "_".join([a[:8] for a in description.split("_")])[:25],
)
- logger.debug('Shortening {} -> {}'.format(oldname, index_name))
+ logger.debug("Shortening {} -> {}".format(oldname, index_name))
return name
@@ -62,7 +61,7 @@ def get_secondary_key_indexes(cls):
"""
#: use text_pattern_ops, allows LIKE statements not starting with %
- index_op = 'text_pattern_ops'
+ index_op = "text_pattern_ops"
secondary_keys = {key for pair in cls.__pg_secondary_keys for key in pair}
key_indexes = (
@@ -70,15 +69,17 @@ def get_secondary_key_indexes(cls):
index_name(cls, key),
cls._props[key].astext.label(key),
postgresql_ops={key: index_op},
- ) for key in secondary_keys
+ )
+ for key in secondary_keys
)
lower_key_indexes = (
Index(
- index_name(cls, key+'_lower'),
- func.lower(cls._props[key].astext).label(key+'_lower'),
- postgresql_ops={key+'_lower': index_op},
- ) for key in secondary_keys
+ index_name(cls, key + "_lower"),
+ func.lower(cls._props[key].astext).label(key + "_lower"),
+ postgresql_ops={key + "_lower": index_op},
+ )
+ for key in secondary_keys
)
return tuple(key_indexes) + tuple(lower_key_indexes)
diff --git a/gdcdatamodel/models/utils.py b/gdcdatamodel/models/utils.py
index 75cab86a..f77f23d0 100644
--- a/gdcdatamodel/models/utils.py
+++ b/gdcdatamodel/models/utils.py
@@ -7,11 +7,13 @@ def decorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
return f(*args, **kwargs)
+
return f
+
return decorator
def py3_to_bytes(bytes_or_str):
if sys.version_info[0] > 2 and isinstance(bytes_or_str, str):
- return bytes_or_str.encode('utf-8')
+ return bytes_or_str.encode("utf-8")
return bytes_or_str
diff --git a/gdcdatamodel/models/versioned_nodes.py b/gdcdatamodel/models/versioned_nodes.py
index 24bcfea9..03df64f5 100644
--- a/gdcdatamodel/models/versioned_nodes.py
+++ b/gdcdatamodel/models/versioned_nodes.py
@@ -9,21 +9,18 @@
class VersionedNode(Base):
- __tablename__ = 'versioned_nodes'
+ __tablename__ = "versioned_nodes"
__table_args__ = (
- Index('submitted_node_id_idx', 'node_id'),
- Index('submitted_node_gdc_versions_idx', 'node_id'),
+ Index("submitted_node_id_idx", "node_id"),
+ Index("submitted_node_gdc_versions_idx", "node_id"),
)
def __repr__(self):
- return (""
- .format(self.key, self.label, self.node_id))
+ return "".format(
+ self.key, self.label, self.node_id
+ )
- key = Column(
- BigInteger,
- primary_key=True,
- nullable=False
- )
+ key = Column(BigInteger, primary_key=True, nullable=False)
label = Column(
Text,
@@ -52,7 +49,7 @@ def __repr__(self):
versioned = Column(
DateTime(timezone=True),
nullable=False,
- server_default=text('now()'),
+ server_default=text("now()"),
)
acl = Column(
@@ -79,14 +76,13 @@ def clone(node):
return VersionedNode(
label=copy(node.label),
node_id=copy(node.node_id),
- project_id=copy(node._props.get('project_id')),
+ project_id=copy(node._props.get("project_id")),
created=copy(node.created),
acl=copy(node.acl),
system_annotations=copy(node.system_annotations),
properties=copy(node.properties),
- neighbors=copy([
- edge.dst_id for edge in node.edges_out
- ] + [
- edge.src_id for edge in node.edges_in
- ])
+ neighbors=copy(
+ [edge.dst_id for edge in node.edges_out]
+ + [edge.src_id for edge in node.edges_in]
+ ),
)
diff --git a/gdcdatamodel/models/versioning.py b/gdcdatamodel/models/versioning.py
index d382c658..00935126 100644
--- a/gdcdatamodel/models/versioning.py
+++ b/gdcdatamodel/models/versioning.py
@@ -102,10 +102,10 @@ def _constraints(self):
def is_taggable(self, node):
"""Returns true if node supports tagging else False. Ideally, instances that return false will not
- have tag and version number set on them
+ have tag and version number set on them
- Returns:
- bool: True for nodes that can be tagged
+ Returns:
+ bool: True for nodes that can be tagged
"""
return not any(criteria.match(node) for criteria in self._constraints())
diff --git a/gdcdatamodel/query.py b/gdcdatamodel/query.py
index 3a4f9ba6..0c4507ed 100644
--- a/gdcdatamodel/query.py
+++ b/gdcdatamodel/query.py
@@ -1,14 +1,30 @@
from psqlgraph import Edge, Node
traversals = {}
-terminal_nodes = ['annotations', 'centers', 'archives', 'tissue_source_sites',
- 'files', 'related_files', 'describing_files',
- 'clinical_metadata_files', 'experiment_metadata_files', 'run_metadata_files',
- 'analysis_metadata_files', 'biospecimen_metadata_files', 'aligned_reads_metrics',
- 'read_group_metrics', 'pathology_reports', 'simple_germline_variations',
- 'aligned_reads_indexes', 'mirna_expressions', 'exon_expressions',
- 'simple_somatic_mutations', 'gene_expressions', 'aggregated_somatic_mutations',
- ]
+terminal_nodes = [
+ "annotations",
+ "centers",
+ "archives",
+ "tissue_source_sites",
+ "files",
+ "related_files",
+ "describing_files",
+ "clinical_metadata_files",
+ "experiment_metadata_files",
+ "run_metadata_files",
+ "analysis_metadata_files",
+ "biospecimen_metadata_files",
+ "aligned_reads_metrics",
+ "read_group_metrics",
+ "pathology_reports",
+ "simple_germline_variations",
+ "aligned_reads_indexes",
+ "mirna_expressions",
+ "exon_expressions",
+ "simple_somatic_mutations",
+ "gene_expressions",
+ "aggregated_somatic_mutations",
+]
def construct_traversals(root, node, visited, path):
@@ -18,27 +34,38 @@ def construct_traversals(root, node, visited, path):
and neighbor not in visited
and neighbor != node
# no traveling THROUGH terminal nodes
- and (path[-1] not in terminal_nodes
- if path else neighbor.label not in terminal_nodes)
- and (not path[-1].startswith('_related')
- if path else not neighbor.label.startswith('_related')))
+ and (
+ path[-1] not in terminal_nodes
+ if path
+ else neighbor.label not in terminal_nodes
+ )
+ and (
+ not path[-1].startswith("_related")
+ if path
+ else not neighbor.label.startswith("_related")
+ )
+ )
for edge in Edge._get_edges_with_src(node.__name__):
- neighbor = [n for n in Node.get_subclasses()
- if n.__name__ == edge.__dst_class__][0]
+ neighbor = [
+ n for n in Node.get_subclasses() if n.__name__ == edge.__dst_class__
+ ][0]
if recurse(neighbor):
construct_traversals(
- root, neighbor, visited+[node], path+[edge.__src_dst_assoc__])
+ root, neighbor, visited + [node], path + [edge.__src_dst_assoc__]
+ )
for edge in Edge._get_edges_with_dst(node.__name__):
- neighbor = [n for n in Node.get_subclasses()
- if n.__name__ == edge.__src_class__][0]
+ neighbor = [
+ n for n in Node.get_subclasses() if n.__name__ == edge.__src_class__
+ ][0]
if recurse(neighbor):
construct_traversals(
- root, neighbor, visited+[node], path+[edge.__dst_src_assoc__])
+ root, neighbor, visited + [node], path + [edge.__dst_src_assoc__]
+ )
traversals[root][node.label] = traversals[root].get(node.label) or set()
- traversals[root][node.label].add('.'.join(path))
+ traversals[root][node.label].add(".".join(path))
def construct_traversals_for_all_nodes():
diff --git a/gdcdatamodel/validators/graph_validators.py b/gdcdatamodel/validators/graph_validators.py
index 2df33bb4..093e220e 100644
--- a/gdcdatamodel/validators/graph_validators.py
+++ b/gdcdatamodel/validators/graph_validators.py
@@ -2,16 +2,17 @@
class GDCGraphValidator(object):
- '''
+ """
Validator that validates entities' relationship with existing nodes in
database.
- '''
+ """
+
def __init__(self):
self.schemas = gdcdictionary
self.required_validators = {
- 'links_validator': GDCLinksValidator(),
- 'uniqueKeys_validator': GDCUniqueKeysValidator()
+ "links_validator": GDCLinksValidator(),
+ "uniqueKeys_validator": GDCUniqueKeysValidator(),
}
self.optional_validators = {}
@@ -21,20 +22,19 @@ def record_errors(self, graph, entities):
for entity in entities:
schema = self.schemas.schema[entity.node.label]
- validators = schema.get('validators')
+ validators = schema.get("validators")
if validators:
for validator_name in validators:
self.optional_validators[validator_name].validate()
class GDCLinksValidator(object):
-
def validate(self, entities, graph=None):
for entity in entities:
- for link in gdcdictionary.schema[entity.node.label]['links']:
- if 'name' in link:
+ for link in gdcdictionary.schema[entity.node.label]["links"]:
+ if "name" in link:
self.validate_edge(link, entity)
- elif 'subgroup' in link:
+ elif "subgroup" in link:
self.validate_edge_group(link, entity)
def validate_edge_group(self, schema, entity):
@@ -42,90 +42,97 @@ def validate_edge_group(self, schema, entity):
schema_links = []
num_of_edges = 0
- for group in schema['subgroup']:
- if 'subgroup' in schema['subgroup']:
+ for group in schema["subgroup"]:
+ if "subgroup" in schema["subgroup"]:
# nested subgroup
result = self.validate_edge_group(group, entity)
- if 'name' in group:
+ if "name" in group:
result = self.validate_edge(group, entity)
- if result['length'] > 0:
+ if result["length"] > 0:
submitted_links.append(result)
- num_of_edges += result['length']
- schema_links.append(result['name'])
+ num_of_edges += result["length"]
+ schema_links.append(result["name"])
- if schema.get('required') is True and len(submitted_links) == 0:
- names = ", ".join(
- schema_links[:-2] + [" or ".join(schema_links[-2:])])
+ if schema.get("required") is True and len(submitted_links) == 0:
+ names = ", ".join(schema_links[:-2] + [" or ".join(schema_links[-2:])])
entity.record_error(
- "Entity is missing a required link to {}"
- .format(names), keys=schema_links)
+ "Entity is missing a required link to {}".format(names),
+ keys=schema_links,
+ )
if schema.get("exclusive") is True and len(submitted_links) > 1:
- names = ", ".join(
- schema_links[:-2] + [" and ".join(schema_links[-2:])])
+ names = ", ".join(schema_links[:-2] + [" and ".join(schema_links[-2:])])
entity.record_error(
- "Links to {} are exclusive. More than one was provided: {}"
- .format(schema_links, entity.node.edges_out), keys=schema_links)
+ "Links to {} are exclusive. More than one was provided: {}".format(
+ schema_links, entity.node.edges_out
+ ),
+ keys=schema_links,
+ )
for edge in entity.node.edges_out:
- entity.record_error('{}'.format(edge.dst.submitter_id))
+ entity.record_error("{}".format(edge.dst.submitter_id))
- result = {'length': num_of_edges, 'name': ", ".join(schema_links)}
+ result = {"length": num_of_edges, "name": ", ".join(schema_links)}
def validate_edge(self, link_sub_schema, entity):
- association = link_sub_schema['name']
+ association = link_sub_schema["name"]
node = entity.node
targets = node[association]
- result = {'length': len(targets), 'name': association}
+ result = {"length": len(targets), "name": association}
if len(targets) > 0:
- multi = link_sub_schema['multiplicity']
+ multi = link_sub_schema["multiplicity"]
- if multi in ['many_to_one', 'one_to_one']:
+ if multi in ["many_to_one", "one_to_one"]:
if len(targets) > 1:
entity.record_error(
- "'{}' link has to be {}"
- .format(association, multi),
- keys=[association])
+ "'{}' link has to be {}".format(association, multi),
+ keys=[association],
+ )
- if multi in ['one_to_many', 'one_to_one']:
+ if multi in ["one_to_many", "one_to_one"]:
for target in targets:
- if len(target[link_sub_schema['backref']]) > 1:
+ if len(target[link_sub_schema["backref"]]) > 1:
entity.record_error(
- "'{}' link has to be {}, target node {} already has {}"
- .format(association, multi,
- target.label, link_sub_schema['backref']),
- keys=[association])
+ "'{}' link has to be {}, target node {} already has {}".format(
+ association,
+ multi,
+ target.label,
+ link_sub_schema["backref"],
+ ),
+ keys=[association],
+ )
- if multi == 'many_to_many':
+ if multi == "many_to_many":
pass
else:
- if link_sub_schema.get('required') is True:
+ if link_sub_schema.get("required") is True:
entity.record_error(
- "Entity is missing required link to {}"
- .format(association),
- keys=[association])
+ "Entity is missing required link to {}".format(association),
+ keys=[association],
+ )
return result
class GDCUniqueKeysValidator(object):
-
def validate(self, entities, graph=None):
for entity in entities:
schema = gdcdictionary.schema[entity.node.label]
node = entity.node
- for keys in schema['uniqueKeys']:
+ for keys in schema["uniqueKeys"]:
props = {}
- if keys == ['id']:
+ if keys == ["id"]:
continue
for key in keys:
- prop = schema['properties'][key].get('systemAlias')
+ prop = schema["properties"][key].get("systemAlias")
if prop:
props[prop] = node[prop]
else:
props[key] = node[key]
if graph.nodes().props(props).count() > 1:
- entity.record_error(
- '{} with {} already exists in the GDC'
- .format(node.label, props), keys=list(props.keys())
- )
+ entity.record_error(
+ "{} with {} already exists in the GDC".format(
+ node.label, props
+ ),
+ keys=list(props.keys()),
+ )
diff --git a/gdcdatamodel/validators/json_validators.py b/gdcdatamodel/validators/json_validators.py
index 13923170..506713a7 100644
--- a/gdcdatamodel/validators/json_validators.py
+++ b/gdcdatamodel/validators/json_validators.py
@@ -3,8 +3,10 @@
from gdcdictionary import gdcdictionary
from jsonschema import Draft4Validator
-missing_prop_re = re.compile("\'([a-zA-Z_-]+)\' is a required property")
-extra_prop_re = re.compile("Additional properties are not allowed \(u\'([a-zA-Z_-]+)\' was unexpected\)")
+missing_prop_re = re.compile("'([a-zA-Z_-]+)' is a required property")
+extra_prop_re = re.compile(
+    r"Additional properties are not allowed \(u'([a-zA-Z_-]+)' was unexpected\)"
+)
def get_keys(error_msg):
@@ -27,29 +29,33 @@ def __init__(self):
def iter_errors(self, doc):
# Note whenever gdcdictionary use a newer version of jsonschema
# we need to update the Validator
- validator = Draft4Validator(self.schemas.schema[doc['type']])
+ validator = Draft4Validator(self.schemas.schema[doc["type"]])
return validator.iter_errors(doc)
def record_errors(self, entities):
for entity in entities:
json_doc = entity.doc
- if 'type' not in json_doc:
- entity.record_error(
- "'type' is a required property", keys=['type'])
+ if "type" not in json_doc:
+ entity.record_error("'type' is a required property", keys=["type"])
break
- if json_doc['type'] not in self.schemas.schema:
+ if json_doc["type"] not in self.schemas.schema:
entity.record_error(
- "specified type: {} is not in the current data model"
- .format(json_doc['type']), keys=['type'])
+ "specified type: {} is not in the current data model".format(
+ json_doc["type"]
+ ),
+ keys=["type"],
+ )
break
for error in self.iter_errors(json_doc):
# the key will be property.sub property for nested properties
errors = [str(e) for e in error.path if error.path]
- keys = ['.'.join(errors)] if errors else []
+ keys = [".".join(errors)] if errors else []
if not keys:
keys = get_keys(error.message)
message = error.message
if error.context:
- message += ': {}'.format(' and '.join([c.message for c in error.context]))
+ message += ": {}".format(
+ " and ".join([c.message for c in error.context])
+ )
entity.record_error(message, keys=keys)
# additional validators go here
diff --git a/gdcdatamodel/viz/__init__.py b/gdcdatamodel/viz/__init__.py
index ceeaa4be..e7c6eac4 100644
--- a/gdcdatamodel/viz/__init__.py
+++ b/gdcdatamodel/viz/__init__.py
@@ -8,18 +8,18 @@ def create_graphviz(nodes, include_case_cache_edges=False):
"""
dot = Digraph()
- dot.graph_attr['rankdir'] = 'RL'
+ dot.graph_attr["rankdir"] = "RL"
edges_added = set()
nodes = {node.node_id: node for node in nodes}
def is_edge_drawn(edge, neighbor):
- is_case_cache_edge = 'RelatesToCase' in edge.__class__.__name__
+ is_case_cache_edge = "RelatesToCase" in edge.__class__.__name__
return (
- (include_case_cache_edges or not is_case_cache_edge) and
- edge not in edges_added and
- neighbor in nodes
+ (include_case_cache_edges or not is_case_cache_edge)
+ and edge not in edges_added
+ and neighbor in nodes
)
for node in nodes.values():
diff --git a/migrations/async_transactions.py b/migrations/async_transactions.py
index 16ad58f9..eaee33b3 100644
--- a/migrations/async_transactions.py
+++ b/migrations/async_transactions.py
@@ -19,9 +19,10 @@
def up_transaction(connection):
- logger.info('Migrating async-transactions: up')
+ logger.info("Migrating async-transactions: up")
- connection.execute("""
+ connection.execute(
+ """
ALTER TABLE transaction_logs ADD COLUMN state TEXT;
ALTER TABLE transaction_logs ADD COLUMN committed_by INTEGER;
ALTER TABLE transaction_logs ADD COLUMN is_dry_run BOOLEAN;
@@ -37,13 +38,15 @@ def up_transaction(connection):
UPDATE transaction_logs SET state = 'SUCCEEDED' WHERE state IS NULL;
UPDATE transaction_logs SET is_dry_run = FALSE WHERE is_dry_run IS NULL;
- """)
+ """
+ )
def down_transaction(connection):
- logger.info('Migrating async-transactions: down')
+ logger.info("Migrating async-transactions: down")
- connection.execute("""
+ connection.execute(
+ """
DROP INDEX transaction_logs_program_idx;
DROP INDEX transaction_logs_project_idx;
DROP INDEX transaction_logs_is_dry_run_idx;
@@ -56,7 +59,8 @@ def down_transaction(connection):
ALTER TABLE transaction_logs DROP COLUMN closed;
ALTER TABLE transaction_logs DROP COLUMN committed_by;
ALTER TABLE transaction_logs DROP COLUMN is_dry_run;
- """)
+ """
+ )
def up(connection):
diff --git a/migrations/index_secondary_keys.py b/migrations/index_secondary_keys.py
index aa60fc1c..d6852f14 100644
--- a/migrations/index_secondary_keys.py
+++ b/migrations/index_secondary_keys.py
@@ -24,26 +24,27 @@
TX_LOG_PROJECT_ID_IDX = Index(
- 'transaction_logs_project_id_idx',
- TransactionLog.program+'_'+TransactionLog.project)
+ "transaction_logs_project_id_idx",
+ TransactionLog.program + "_" + TransactionLog.project,
+)
def up_transaction(connection):
- logger.info('Migrating async-transactions: up')
+ logger.info("Migrating async-transactions: up")
for cls in Node.get_subclasses():
for index in get_secondary_key_indexes(cls):
- logger.info('Creating %s', index.name)
+ logger.info("Creating %s", index.name)
index.create(connection)
TX_LOG_PROJECT_ID_IDX.create(connection)
def down_transaction(connection):
- logger.info('Migrating async-transactions: down')
+ logger.info("Migrating async-transactions: down")
for cls in Node.get_subclasses():
for index in get_secondary_key_indexes(cls):
- logger.info('Dropping %s', index.name)
+ logger.info("Dropping %s", index.name)
index.drop(connection)
TX_LOG_PROJECT_ID_IDX.drop(connection)
diff --git a/migrations/notifications.py b/migrations/notifications.py
index 74857dfb..6a757267 100644
--- a/migrations/notifications.py
+++ b/migrations/notifications.py
@@ -14,7 +14,7 @@
def up(connection):
- logger.info('Migrating notifications: up')
+ logger.info("Migrating notifications: up")
models.notifications.Base.metadata.create_all(connection)
models.redaction.Base.metadata.create_all(connection)
@@ -22,7 +22,7 @@ def up(connection):
def down(connection):
- logger.info('Migrating notifications: down')
+ logger.info("Migrating notifications: down")
models.notifications.Base.metadata.drop_all(connection)
models.redaction.Base.metadata.drop_all(connection)
diff --git a/migrations/set_null_edge_columns.py b/migrations/set_null_edge_columns.py
index 0e97d176..b6d2881d 100644
--- a/migrations/set_null_edge_columns.py
+++ b/migrations/set_null_edge_columns.py
@@ -7,7 +7,7 @@
CACHE_EDGES = {
Node.get_subclass_named(edge.__src_class__): edge
for edge in Edge.get_subclasses()
- if 'RelatesToCase' in edge.__name__
+ if "RelatesToCase" in edge.__name__
}
@@ -32,9 +32,10 @@ def set_null_edge_columns(graph):
def main():
- print("No main() action defined, please manually call "
- "set_null_edge_columns(graph)")
+ print(
+ "No main() action defined, please manually call " "set_null_edge_columns(graph)"
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/migrations/update_case_cache_append_only.py b/migrations/update_case_cache_append_only.py
index ae20bfe4..0ca4c508 100644
--- a/migrations/update_case_cache_append_only.py
+++ b/migrations/update_case_cache_append_only.py
@@ -7,7 +7,7 @@
CACHE_EDGES = {
Node.get_subclass_named(edge.__src_class__): edge
for edge in Edge.get_subclasses()
- if 'RelatesToCase' in edge.__name__
+ if "RelatesToCase" in edge.__name__
}
@@ -68,13 +68,10 @@ def max_distances_from_case():
cls, level = to_visit.pop(0)
if cls not in distances:
- children = (
- link['src_type']
- for _, link in cls._pg_backrefs.items()
- )
- to_visit.extend((child, level+1) for child in children)
+ children = (link["src_type"] for _, link in cls._pg_backrefs.items())
+ to_visit.extend((child, level + 1) for child in children)
- distances[cls] = max(distances.get(cls, level+1), level)
+ distances[cls] = max(distances.get(cls, level + 1), level)
return distances
@@ -89,9 +86,8 @@ def get_levels():
distinct_distances = set(distances.values())
levels = {
- level: [
- cls for cls, distance in distances.items() if distance == level
- ] for level in distinct_distances
+ level: [cls for cls, distance in distances.items() if distance == level]
+ for level in distinct_distances
}
return levels
@@ -106,7 +102,7 @@ def append_cache_from_parent(graph, child, parent):
"""
- description = child.label + ' -> ' + parent.label + ' -> case'
+ description = child.label + " -> " + parent.label + " -> case"
if parent not in CACHE_EDGES:
print("skipping:", description, ": parent is not cached")
@@ -137,10 +133,7 @@ def append_cache_from_parents(graph, cls):
"""
- parents = {
- link['dst_type']
- for link in cls._pg_links.itervalues()
- }
+ parents = {link["dst_type"] for link in cls._pg_links.itervalues()}
for parent in parents:
append_cache_from_parent(graph, cls, parent)
@@ -168,7 +161,7 @@ def seed_level_1(graph, cls):
cls_to_case_edge_table=case_edge.__tablename__,
)
- print('Seeding {} through {}'.format(cls.get_label(), case_edge.__name__))
+ print("Seeding {} through {}".format(cls.get_label(), case_edge.__name__))
graph.current_session().execute(statement)
@@ -195,9 +188,11 @@ def update_case_cache_append_only(graph):
def main():
- print("No main() action defined, please manually call "
- "update_case_cache_append_only(graph)")
+ print(
+ "No main() action defined, please manually call "
+ "update_case_cache_append_only(graph)"
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/migrations/update_legacy_states.py b/migrations/update_legacy_states.py
index 8faed978..4d280524 100644
--- a/migrations/update_legacy_states.py
+++ b/migrations/update_legacy_states.py
@@ -44,14 +44,12 @@
from gdcdatamodel import models as md
CLS_WITH_PROJECT_ID = {
- cls for cls in Node.get_subclasses()
- if 'project_id' in cls.__pg_properties__
+ cls for cls in Node.get_subclasses() if "project_id" in cls.__pg_properties__
}
CLS_WITH_STATE = {
- cls for cls in Node.get_subclasses()
- if 'state' in cls.__pg_properties__
+ cls for cls in Node.get_subclasses() if "state" in cls.__pg_properties__
}
@@ -59,34 +57,13 @@
# Determines state and file_state based on existing state
STATE_MAP = {
- None: {
- 'state': 'submitted',
- 'file_state': None
- },
- 'error': {
- 'state': 'validated',
- 'file_state': 'error'
- },
- 'invalid': {
- 'state': 'validated',
- 'file_state': 'error'
- },
- 'live': {
- 'state': 'submitted',
- 'file_state': 'submitted'
- },
- 'submitted': {
- 'state': 'submitted',
- 'file_state': 'registered'
- },
- 'uploaded': {
- 'state': 'submitted',
- 'file_state': 'uploaded'
- },
- 'validated': {
- 'state': 'submitted',
- 'file_state': 'validated'
- },
+ None: {"state": "submitted", "file_state": None},
+ "error": {"state": "validated", "file_state": "error"},
+ "invalid": {"state": "validated", "file_state": "error"},
+ "live": {"state": "submitted", "file_state": "submitted"},
+ "submitted": {"state": "submitted", "file_state": "registered"},
+ "uploaded": {"state": "submitted", "file_state": "uploaded"},
+ "validated": {"state": "submitted", "file_state": "validated"},
}
@@ -101,15 +78,12 @@ def legacy_filter(query, legacy_projects):
"""
legacy_filters = [
- query.entity().project_id.astext ==
- project.programs[0].name + '-' + project.code
+ query.entity().project_id.astext
+ == project.programs[0].name + "-" + project.code
for project in legacy_projects
]
- return query.filter(or_(
- null_prop(query.entity(), 'project_id'),
- *legacy_filters
- ))
+ return query.filter(or_(null_prop(query.entity(), "project_id"), *legacy_filters))
def null_prop(cls, key):
@@ -130,8 +104,11 @@ def print_cls_query_summary(graph):
}
print(
- "%s: %d" % ("legacy_stateless_nodes".ljust(40),
- sum([query.count() for query in cls_queries.itervalues()]))
+ "%s: %d"
+ % (
+ "legacy_stateless_nodes".ljust(40),
+ sum([query.count() for query in cls_queries.itervalues()]),
+ )
)
for label, query in cls_queries.items():
@@ -143,19 +120,18 @@ def print_cls_query_summary(graph):
def cls_query(graph, cls):
"""Returns query for legacy nodes with state in {null, 'live'}"""
- legacy_projects = graph.nodes(md.Project).props(state='legacy').all()
+ legacy_projects = graph.nodes(md.Project).props(state="legacy").all()
options = [
# state
- null_prop(cls, 'state'),
+ null_prop(cls, "state"),
cls.state.astext.in_(STATE_MAP),
]
- if 'file_state' in cls.__pg_properties__:
- options += [null_prop(cls, 'file_state')]
+ if "file_state" in cls.__pg_properties__:
+ options += [null_prop(cls, "file_state")]
- return (legacy_filter(graph.nodes(cls), legacy_projects)
- .filter(or_(*options)))
+ return legacy_filter(graph.nodes(cls), legacy_projects).filter(or_(*options))
def update_cls(graph, cls):
@@ -167,32 +143,32 @@ def update_cls(graph, cls):
if count == 0:
return
- logger.info('Loading %d %s nodes', count, cls.label)
+ logger.info("Loading %d %s nodes", count, cls.label)
nodes = query.all()
- logger.info('Loaded %d %s nodes', len(nodes), cls.label)
+ logger.info("Loaded %d %s nodes", len(nodes), cls.label)
for node in nodes:
- state = node._props.get('state', None)
- file_state = node._props.get('file_state', None)
+ state = node._props.get("state", None)
+ file_state = node._props.get("file_state", None)
if state in STATE_MAP:
- node.state = STATE_MAP[state]['state']
+ node.state = STATE_MAP[state]["state"]
set_file_state = (
- 'file_state' in node.__pg_properties__
+ "file_state" in node.__pg_properties__
and file_state is None
and state in STATE_MAP
)
if set_file_state:
- node.file_state = STATE_MAP[state]['file_state']
+ node.file_state = STATE_MAP[state]["file_state"]
- node.sysan['legacy_state'] = state
- node.sysan['legacy_file_state'] = file_state
+ node.sysan["legacy_state"] = state
+ node.sysan["legacy_file_state"] = file_state
- logger.info('Committing %s nodes', cls.label)
+ logger.info("Committing %s nodes", cls.label)
graph.current_session().commit()
- logger.info('Done with %s nodes', cls.label)
+ logger.info("Done with %s nodes", cls.label)
def update_classes(graph_kwargs, input_q):
diff --git a/test/conftest.py b/test/conftest.py
index 726d7ea4..4b38beb0 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -22,25 +22,27 @@
from gdcdatamodel.models import basic # noqa
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def db_config():
return {
- 'host': 'localhost',
- 'user': 'test',
- 'password': 'test',
- 'database': 'automated_test',
+ "host": "localhost",
+ "user": "test",
+ "password": "test",
+ "database": "automated_test",
}
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def tables_created(db_config):
"""
Create necessary tables
"""
engine = create_engine(
"postgres://{user}:{pwd}@{host}/{db}".format(
- user=db_config['user'], host=db_config['host'],
- pwd=db_config['password'], db=db_config['database']
+ user=db_config["user"],
+ host=db_config["host"],
+ pwd=db_config["password"],
+ db=db_config["database"],
)
)
@@ -51,14 +53,14 @@ def tables_created(db_config):
truncate(engine)
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def g(db_config, tables_created):
"""Fixture for database driver"""
return PsqlGraphDriver(**db_config)
-@pytest.fixture(scope='class')
+@pytest.fixture(scope="class")
def db_class(request, g):
"""
Sets g property on a test class
@@ -66,9 +68,10 @@ def db_class(request, g):
request.cls.g = g
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
def indexes(g):
- rows = g.engine.execute("""
+ rows = g.engine.execute(
+ """
SELECT i.relname as indname,
ARRAY(
SELECT pg_get_indexdef(idx.indexrelid, k + 1, true)
@@ -80,14 +83,15 @@ def indexes(g):
ON i.oid = idx.indexrelid
JOIN pg_am as am
ON i.relam = am.oid;
- """).fetchall()
+ """
+ ).fetchall()
- return { row[0]: row[1] for row in rows }
+ return {row[0]: row[1] for row in rows}
@pytest.fixture()
def redacted_fixture(g):
- """ Creates a redacted log entry"""
+ """Creates a redacted log entry"""
with g.session_scope() as sxn:
log = models.redaction.RedactionLog()
@@ -100,7 +104,9 @@ def redacted_fixture(g):
count = 0
for i in range(random.randint(2, 5)):
count += 1
- entry = models.redaction.RedactionEntry(node_id=str(uuid.uuid4()), node_type="Aligned Reads")
+ entry = models.redaction.RedactionEntry(
+ node_id=str(uuid.uuid4()), node_type="Aligned Reads"
+ )
log.entries.append(entry)
sxn.add(log)
@@ -116,7 +122,7 @@ def redacted_fixture(g):
sxn.delete(log)
-@pytest.mark.usefixtures('db_class')
+@pytest.mark.usefixtures("db_class")
class BaseTestCase(unittest.TestCase):
def setUp(self):
truncate(self.g.engine)
diff --git a/test/helpers.py b/test/helpers.py
index b8ce1c46..0b8266b1 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -16,10 +16,10 @@ def truncate(engine, namespace=None):
conn = engine.connect()
for table in abstract_node.get_subclass_table_names():
if table != abstract_node.__tablename__:
- conn.execute('delete from {}'.format(table))
+ conn.execute("delete from {}".format(table))
for table in abstract_edge.get_subclass_table_names():
if table != abstract_edge.__tablename__:
- conn.execute('delete from {}'.format(table))
+ conn.execute("delete from {}".format(table))
if not namespace:
# add ng models only to main graph model
diff --git a/test/models.py b/test/models.py
index 177879dd..b1c5f787 100644
--- a/test/models.py
+++ b/test/models.py
@@ -8,7 +8,6 @@ def _load(name):
class Dictionary:
-
def __init__(self, name):
self.schema = _load(name)
diff --git a/test/test_admin_script.py b/test/test_admin_script.py
index e5c434f4..54b82383 100644
--- a/test/test_admin_script.py
+++ b/test/test_admin_script.py
@@ -8,12 +8,7 @@
def get_base_args(host="localhost", database="automated_test", namespace=None):
- return [
- '-H', host,
- '-U', "postgres",
- '-D', database,
- "-N", namespace or ""
- ]
+ return ["-H", host, "-U", "postgres", "-D", database, "-N", namespace or ""]
def get_admin_driver(db_config, namespace=None):
@@ -24,14 +19,18 @@ def get_admin_driver(db_config, namespace=None):
host=db_config["host"],
user="postgres",
password=None,
- database=db_config["database"]
+ database=db_config["database"],
)
return g
def drop_all_tables(g):
- orm_base = ext.get_orm_base(g.package_namespace) if g.package_namespace else psqlgraph.base.ORMBase
+ orm_base = (
+ ext.get_orm_base(g.package_namespace)
+ if g.package_namespace
+ else psqlgraph.base.ORMBase
+ )
psqlgraph.base.drop_all(g.engine, orm_base)
@@ -44,12 +43,12 @@ def run_admin_command(args, namespace=None):
def invalid_write_access_fn(g):
with pytest.raises(ProgrammingError):
with g.session_scope() as s:
- s.add(models.Case('1'))
+ s.add(models.Case("1"))
def valid_write_access_fn(g):
with g.session_scope() as s:
- s.merge(models.Case('1'))
+ s.merge(models.Case("1"))
yield
with g.session_scope() as s:
n = g.nodes().get("1")
@@ -77,7 +76,9 @@ def add_test_database_user(db_config):
g = get_admin_driver(db_config)
try:
- g.engine.execute("CREATE USER {} WITH PASSWORD '{}'".format(dummy_user, dummy_pwd))
+ g.engine.execute(
+ "CREATE USER {} WITH PASSWORD '{}'".format(dummy_user, dummy_pwd)
+ )
g.engine.execute("GRANT USAGE ON SCHEMA public TO {}".format(dummy_user))
yield dummy_user, dummy_pwd
finally:
@@ -86,10 +87,10 @@ def add_test_database_user(db_config):
@pytest.mark.parametrize("namespace", [None, "gdc"], ids=["default", "custom"])
def test_create_tables(db_config, namespace):
- """ Tests tables can be created with the admin script using either a custom dictionary or the default
- Args:
- db_config (dict[str,str]): db connection config
- namespace (str): module namespace, None for default
+ """Tests tables can be created with the admin script using either a custom dictionary or the default
+ Args:
+ db_config (dict[str,str]): db connection config
+ namespace (str): module namespace, None for default
"""
# simulate loading a different dictionary
@@ -114,7 +115,9 @@ def test_create_tables(db_config, namespace):
assert 'relation "node_case" does not exist' in str(e.value)
# create tables using admin script
- args = ['graph-create', '--delay', '1', '--retries', '0'] + get_base_args(namespace=namespace)
+ args = ["graph-create", "--delay", "1", "--retries", "0"] + get_base_args(
+ namespace=namespace
+ )
parsed_args = pgadmin.get_parser().parse_args(args)
pgadmin.main(parsed_args)
@@ -125,13 +128,22 @@ def test_create_tables(db_config, namespace):
@pytest.mark.parametrize("namespace", [None, "gdc"], ids=["default", "custom"])
-@pytest.mark.parametrize("permission, invalid_permission_fn, valid_permission_fn", [
- ("read", invalid_read_access_fn, valid_read_access_fn),
- ("write", invalid_write_access_fn, valid_write_access_fn),
- ], ids=["read", "write"]
+@pytest.mark.parametrize(
+ "permission, invalid_permission_fn, valid_permission_fn",
+ [
+ ("read", invalid_read_access_fn, valid_read_access_fn),
+ ("write", invalid_write_access_fn, valid_write_access_fn),
+ ],
+ ids=["read", "write"],
)
-def test_grant_permissions(db_config, namespace, add_test_database_user,
- permission, invalid_permission_fn, valid_permission_fn):
+def test_grant_permissions(
+ db_config,
+ namespace,
+ add_test_database_user,
+ permission,
+ invalid_permission_fn,
+ valid_permission_fn,
+):
# simulate loading a different dictionary
if namespace:
@@ -140,11 +152,11 @@ def test_grant_permissions(db_config, namespace, add_test_database_user,
dummy_user, dummy_pwd = add_test_database_user
g = psqlgraph.PsqlGraphDriver(
- host=db_config["host"],
- user=dummy_user,
- password=dummy_pwd,
- database=db_config["database"],
- package_namespace=namespace,
+ host=db_config["host"],
+ user=dummy_user,
+ password=dummy_pwd,
+ database=db_config["database"],
+ package_namespace=namespace,
)
# verify user does not have permission
@@ -156,13 +168,22 @@ def test_grant_permissions(db_config, namespace, add_test_database_user,
@pytest.mark.parametrize("namespace", [None, "gdc"], ids=["default", "custom"])
-@pytest.mark.parametrize("permission, invalid_permission_fn, valid_permission_fn", [
- ("read", invalid_read_access_fn, valid_read_access_fn),
- ("write", invalid_write_access_fn, valid_write_access_fn),
- ], ids=["read", "write"]
+@pytest.mark.parametrize(
+ "permission, invalid_permission_fn, valid_permission_fn",
+ [
+ ("read", invalid_read_access_fn, valid_read_access_fn),
+ ("write", invalid_write_access_fn, valid_write_access_fn),
+ ],
+ ids=["read", "write"],
)
-def test_revoke_permissions(db_config, namespace, add_test_database_user,
- permission, invalid_permission_fn, valid_permission_fn):
+def test_revoke_permissions(
+ db_config,
+ namespace,
+ add_test_database_user,
+ permission,
+ invalid_permission_fn,
+ valid_permission_fn,
+):
# simulate loading a different dictionary
if namespace:
@@ -171,11 +192,11 @@ def test_revoke_permissions(db_config, namespace, add_test_database_user,
dummy_user, dummy_pwd = add_test_database_user
g = psqlgraph.PsqlGraphDriver(
- host=db_config["host"],
- user=dummy_user,
- password=dummy_pwd,
- database=db_config["database"],
- package_namespace=namespace,
+ host=db_config["host"],
+ user=dummy_user,
+ password=dummy_pwd,
+ database=db_config["database"],
+ package_namespace=namespace,
)
# grant user permissions
diff --git a/test/test_cache_related_cases.py b/test/test_cache_related_cases.py
index bb3ab094..6a40b2f2 100644
--- a/test/test_cache_related_cases.py
+++ b/test/test_cache_related_cases.py
@@ -6,64 +6,64 @@
class TestCacheRelatedCases(BaseTestCase):
-
def test_insert_single_association_proxy(self):
with self.g.session_scope() as s:
- case = md.Case('case_id_1')
- sample = md.Sample('sample_id_1')
+ case = md.Case("case_id_1")
+ sample = md.Sample("sample_id_1")
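+            # assigning through the association proxy should build the related-case cache on merge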
sample.cases = [case]
s.merge(sample)
with self.g.session_scope() as s:
- sample = self.g.nodes(md.Sample).subq_path('cases').one()
+ sample = self.g.nodes(md.Sample).subq_path("cases").one()
assert sample._related_cases_from_cache == [case]
def test_insert_single_edge(self):
with self.g.session_scope() as s:
- case = s.merge(md.Case('case_id_1'))
- sample = s.merge(md.Sample('sample_id_1'))
+ case = s.merge(md.Case("case_id_1"))
+ sample = s.merge(md.Sample("sample_id_1"))
edge = md.SampleDerivedFromCase(sample.node_id, case.node_id)
s.merge(edge)
with self.g.session_scope() as s:
- sample = self.g.nodes(md.Sample).subq_path('cases').one()
+ sample = self.g.nodes(md.Sample).subq_path("cases").one()
assert sample._related_cases_from_cache == [case]
def test_insert_double_edge_in(self):
with self.g.session_scope() as s:
- case = md.Case('case_id_1')
- sample1 = md.Sample('sample_id_1')
- sample2 = md.Sample('sample_id_2')
+ case = md.Case("case_id_1")
+ sample1 = md.Sample("sample_id_1")
+ sample2 = md.Sample("sample_id_2")
case.samples = [sample1, sample2]
s.merge(case)
with self.g.session_scope() as s:
- samples = self.g.nodes(md.Sample).subq_path('cases').all()
+ samples = self.g.nodes(md.Sample).subq_path("cases").all()
self.assertEqual(len(samples), 2)
for sample in samples:
assert sample._related_cases_from_cache == [case]
def test_insert_double_edge_out(self):
with self.g.session_scope() as s:
- case1 = md.Case('case_id_1')
- case2 = md.Case('case_id_2')
- sample = md.Sample('sample_id_1')
+ case1 = md.Case("case_id_1")
+ case2 = md.Case("case_id_2")
+ sample = md.Sample("sample_id_1")
sample.cases = [case1, case2]
s.merge(sample)
with self.g.session_scope() as s:
- sample = self.g.nodes(md.Sample).subq_path('cases').one()
- assert {c.node_id for c in sample._related_cases} == \
- {c.node_id for c in [case1, case2]}
+ sample = self.g.nodes(md.Sample).subq_path("cases").one()
+ assert {c.node_id for c in sample._related_cases} == {
+ c.node_id for c in [case1, case2]
+ }
def test_insert_multiple_edges(self):
with self.g.session_scope() as s:
- case = md.Case('case_id_1')
- sample = md.Sample('sample_id_1')
- portion = md.Portion('portion_id_1')
- analyte = md.Analyte('analyte_id_1')
- aliquot = md.Aliquot('aliquot_id_1')
- general_file = md.File('file_id_1')
+ case = md.Case("case_id_1")
+ sample = md.Sample("sample_id_1")
+ portion = md.Portion("portion_id_1")
+ analyte = md.Analyte("analyte_id_1")
+ aliquot = md.Aliquot("aliquot_id_1")
+ general_file = md.File("file_id_1")
sample.cases = [case]
portion.samples = [sample]
@@ -74,17 +74,17 @@ def test_insert_multiple_edges(self):
with self.g.session_scope() as s:
nodes = self.g.nodes(Node).all()
- nodes = [n for n in nodes if n.label not in ['case']]
+ nodes = [n for n in nodes if n.label not in ["case"]]
for node in nodes:
assert node._related_cases == [case]
def test_insert_update_children(self):
with self.g.session_scope() as s:
- aliquot = s.merge(md.Aliquot('aliquot_id_1'))
- sample = s.merge(md.Sample('sample_id_1'))
+ aliquot = s.merge(md.Aliquot("aliquot_id_1"))
+ sample = s.merge(md.Sample("sample_id_1"))
aliquot.samples = [sample]
- s.merge(md.Case('case_id_1'))
+ s.merge(md.Case("case_id_1"))
with self.g.session_scope() as s:
case = self.g.nodes(md.Case).one()
@@ -100,9 +100,9 @@ def test_insert_update_children(self):
def test_delete_dst_association_proxy(self):
with self.g.session_scope() as s:
- case = md.Case('case_id_1')
- aliquot = md.Aliquot('aliquot_id_1')
- sample = md.Sample('sample_id_1')
+ case = md.Case("case_id_1")
+ aliquot = md.Aliquot("aliquot_id_1")
+ sample = md.Sample("sample_id_1")
aliquot.samples = [sample]
sample.cases = [case]
s.merge(case)
@@ -112,7 +112,7 @@ def test_delete_dst_association_proxy(self):
case.samples = []
with self.g.session_scope() as s:
- assert not self.g.nodes(md.Sample).subq_path('cases').count()
+ assert not self.g.nodes(md.Sample).subq_path("cases").count()
with self.g.session_scope() as s:
sample = self.g.nodes(md.Sample).one()
@@ -123,9 +123,9 @@ def test_delete_dst_association_proxy(self):
def test_delete_src_association_proxy(self):
with self.g.session_scope() as s:
- case = md.Case('case_id_1')
- aliquot = md.Aliquot('aliquot_id_1')
- sample = md.Sample('sample_id_1')
+ case = md.Case("case_id_1")
+ aliquot = md.Aliquot("aliquot_id_1")
+ sample = md.Sample("sample_id_1")
aliquot.samples = [sample]
sample.cases = [case]
s.merge(case)
@@ -135,7 +135,7 @@ def test_delete_src_association_proxy(self):
sample.cases = []
with self.g.session_scope() as s:
- assert not self.g.nodes(md.Sample).subq_path('cases').count()
+ assert not self.g.nodes(md.Sample).subq_path("cases").count()
with self.g.session_scope() as s:
sample = self.g.nodes(md.Sample).one()
@@ -146,9 +146,9 @@ def test_delete_src_association_proxy(self):
def test_delete_edge(self):
with self.g.session_scope() as s:
- case = md.Case('case_id_1')
- aliquot = md.Aliquot('aliquot_id_1')
- sample = md.Sample('sample_id_1')
+ case = md.Case("case_id_1")
+ aliquot = md.Aliquot("aliquot_id_1")
+ sample = md.Sample("sample_id_1")
aliquot.samples = [sample]
sample.cases = [case]
s.merge(case)
@@ -156,13 +156,14 @@ def test_delete_edge(self):
with self.g.session_scope() as s:
case = self.g.nodes(md.Case).one()
edge = [
- e for e in case.edges_in
- if e.label != 'relates_to' and e.src.label == 'sample'
+ e
+ for e in case.edges_in
+ if e.label != "relates_to" and e.src.label == "sample"
][0]
s.delete(edge)
with self.g.session_scope() as s:
- assert not self.g.nodes(md.Sample).subq_path('cases').count()
+ assert not self.g.nodes(md.Sample).subq_path("cases").count()
with self.g.session_scope() as s:
sample = self.g.nodes(md.Sample).one()
@@ -173,8 +174,8 @@ def test_delete_edge(self):
def test_delete_parent(self):
with self.g.session_scope() as s:
- case = md.Case('case_id_1')
- sample = md.Sample('sample_id_1')
+ case = md.Case("case_id_1")
+ sample = md.Sample("sample_id_1")
sample.cases = [case]
s.merge(case)
@@ -189,14 +190,14 @@ def test_delete_parent(self):
def test_delete_one_parent(self):
with self.g.session_scope() as s:
- case1 = md.Case('case_id_1')
- case2 = md.Case('case_id_2')
- sample = md.Sample('sample_id_1')
+ case1 = md.Case("case_id_1")
+ case2 = md.Case("case_id_2")
+ sample = md.Sample("sample_id_1")
sample.cases = [case1, case2]
s.merge(sample)
with self.g.session_scope() as s:
- case1 = self.g.nodes(md.Case).ids('case_id_1').one()
+ case1 = self.g.nodes(md.Case).ids("case_id_1").one()
s.delete(case1)
with self.g.session_scope() as s:
@@ -206,7 +207,7 @@ def test_delete_one_parent(self):
def test_preserve_timestamps(self):
"""Confirm cache changes do not affect the case's timestamps."""
with self.g.session_scope() as s:
- s.merge(md.Case('case_id_1'))
+ s.merge(md.Case("case_id_1"))
with self.g.session_scope():
case = self.g.nodes(md.Case).one()
@@ -214,16 +215,16 @@ def test_preserve_timestamps(self):
old_updated_datetime = case.updated_datetime
# Test addition of cache edges.
- sample = md.Sample('sample_id_1')
- portion = md.Portion('portion_id_1')
- analyte = md.Analyte('analyte_id_1')
- aliquot = md.Aliquot('aliquot_id_1')
+ sample = md.Sample("sample_id_1")
+ portion = md.Portion("portion_id_1")
+ analyte = md.Analyte("analyte_id_1")
+ aliquot = md.Aliquot("aliquot_id_1")
sample.cases = [case]
portion.samples = [sample]
analyte.portions = [portion]
aliquot.analytes = [analyte]
- sample2 = md.Sample('sample_id_2')
+ sample2 = md.Sample("sample_id_2")
sample2.cases = [case]
with self.g.session_scope() as s:
@@ -231,7 +232,7 @@ def test_preserve_timestamps(self):
# Exercise a few cache edge removal use cases as well.
analyte = self.g.nodes(md.Analyte).one()
- sample2 = self.g.nodes(md.Sample).get('sample_id_2')
+ sample2 = self.g.nodes(md.Sample).get("sample_id_2")
s.delete(analyte)
sample2.cases = []
diff --git a/test/test_datamodel.py b/test/test_datamodel.py
index c0c322d3..e638b38f 100644
--- a/test/test_datamodel.py
+++ b/test/test_datamodel.py
@@ -13,10 +13,10 @@
class TestDataModel(unittest.TestCase):
@classmethod
def setUpClass(cls):
- host = 'localhost'
- user = 'test'
- password = 'test'
- database = 'automated_test'
+ host = "localhost"
+ user = "test"
+ password = "test"
+ database = "automated_test"
cls.g = PsqlGraphDriver(host, user, password, database)
cls._clear_tables()
@@ -27,59 +27,59 @@ def tearDown(self):
@classmethod
def _clear_tables(cls):
conn = cls.g.engine.connect()
- conn.execute('commit')
+ conn.execute("commit")
for table in Node().get_subclass_table_names():
if table != Node.__tablename__:
- conn.execute('delete from {}'.format(table))
+ conn.execute("delete from {}".format(table))
for table in Edge.get_subclass_table_names():
if table != Edge.__tablename__:
- conn.execute('delete from {}'.format(table))
- conn.execute('delete from versioned_nodes')
- conn.execute('delete from _voided_nodes')
- conn.execute('delete from _voided_edges')
+ conn.execute("delete from {}".format(table))
+ conn.execute("delete from versioned_nodes")
+ conn.execute("delete from _voided_nodes")
+ conn.execute("delete from _voided_edges")
conn.close()
def test_type_validation(self):
f = md.File()
with self.assertRaises(ValidationError):
- f.file_size = '0'
+ f.file_size = "0"
f.file_size = 0
f = md.File()
with self.assertRaises(ValidationError):
f.file_name = 0
- f.file_name = '0'
+ f.file_name = "0"
s = md.Sample()
with self.assertRaises(ValidationError):
- s.is_ffpe = 'false'
+ s.is_ffpe = "false"
s.is_ffpe = False
s = md.Slide()
with self.assertRaises(ValidationError):
- s.percent_necrosis = '0.0'
+ s.percent_necrosis = "0.0"
s.percent_necrosis = 0.0
def test_link_clobber_prevention(self):
with self.assertRaises(AssertionError):
md.EdgeFactory(
- 'Testedge',
- 'test',
- 'sample',
- 'aliquot',
- 'aliquots',
- '_uncontended_backref',
+ "Testedge",
+ "test",
+ "sample",
+ "aliquot",
+ "aliquots",
+ "_uncontended_backref",
)
def test_backref_clobber_prevention(self):
with self.assertRaises(AssertionError):
md.EdgeFactory(
- 'Testedge',
- 'test',
- 'sample',
- 'aliquot',
- '_uncontended_link',
- 'samples',
+ "Testedge",
+ "test",
+ "sample",
+ "aliquot",
+ "_uncontended_link",
+ "samples",
)
def test_created_datetime_hook(self):
@@ -87,7 +87,7 @@ def test_created_datetime_hook(self):
time_before = datetime.now().isoformat()
with self.g.session_scope() as s:
- s.add(md.Case('case1'))
+ s.add(md.Case("case1"))
time_after = datetime.now().isoformat()
@@ -102,14 +102,14 @@ def test_created_datetime_hook(self):
def test_updated_datetime_hook(self):
"""Test setting updated datetime when a node is updated."""
with self.g.session_scope() as s:
- s.merge(md.Case('case1'))
+ s.merge(md.Case("case1"))
with self.g.session_scope():
case = self.g.nodes(md.Case).one()
old_created_datetime = case.created_datetime
old_updated_datetime = case.updated_datetime
- case.primary_site = 'Kidney'
+ case.primary_site = "Kidney"
with self.g.session_scope():
updated_case = self.g.nodes(md.Case).one()
@@ -119,14 +119,14 @@ def test_updated_datetime_hook(self):
def test_no_datetime_update_for_new_edge(self):
"""Verify new inbound edges do not affect a node's updated datetime."""
with self.g.session_scope() as s:
- s.merge(md.Case('case1'))
+ s.merge(md.Case("case1"))
with self.g.session_scope() as s:
case = self.g.nodes(md.Case).one()
old_created_datetime = case.created_datetime
old_updated_datetime = case.updated_datetime
- sample = s.merge(md.Sample('sample1'))
+ sample = s.merge(md.Sample("sample1"))
case.samples.append(sample)
with self.g.session_scope():
@@ -137,9 +137,9 @@ def test_no_datetime_update_for_new_edge(self):
def test_default_values(self):
p = md.Project()
project_defaults = p._defaults
- assert project_defaults['state'] == 'open'
- assert project_defaults['submission_enabled'] == True
- assert project_defaults['released'] == False
+ assert project_defaults["state"] == "open"
+ assert project_defaults["submission_enabled"] == True
+ assert project_defaults["released"] == False
def test_file_pg_edges():
diff --git a/test/test_dictionary_loadiing.py b/test/test_dictionary_loadiing.py
index 87a5018e..b7673a01 100644
--- a/test/test_dictionary_loadiing.py
+++ b/test/test_dictionary_loadiing.py
@@ -3,16 +3,20 @@
from gdcdatamodel import models
-@pytest.mark.parametrize("namespace, expectation", [
- (None, "gdcdatamodel.models"),
- ("d1", "gdcdatamodel.models.d1"),
- ("d2", "gdcdatamodel.models.d2"),
-], ids=["default module", "custom 1", "custom 2"])
+@pytest.mark.parametrize(
+ "namespace, expectation",
+ [
+ (None, "gdcdatamodel.models"),
+ ("d1", "gdcdatamodel.models.d1"),
+ ("d2", "gdcdatamodel.models.d2"),
+ ],
+ ids=["default module", "custom 1", "custom 2"],
+)
def test_get_package_for_class(namespace, expectation):
- """ Tests retrieving the proper module to insert generated classes for a given dictionary
- Args:
- namespace (str): package namespace used to logically divided classes
- expectation (str): final module name
+ """Tests retrieving the proper module to insert generated classes for a given dictionary
+ Args:
+        namespace (str): package namespace used to logically divide classes
+ expectation (str): final module name
"""
module_name = models.get_cls_package(package_namespace=namespace)
@@ -20,8 +24,8 @@ def test_get_package_for_class(namespace, expectation):
def test_loading_same_dictionary():
- """ Tests loading gdcdictionary into a different namespace,
- even though it might already be loaded into the default
+ """Tests loading gdcdictionary into a different namespace,
+ even though it might already be loaded into the default
"""
# assert the custom module does not currently exist
with pytest.raises(ImportError):
@@ -31,6 +35,7 @@ def test_loading_same_dictionary():
# assert module & models now exist
from gdcdatamodel.models import gdcx # noqa
+
assert gdcx.Project and gdcx.Program and gdcx.Case
@@ -43,6 +48,7 @@ def test_case_cache_related_edge_resolution():
models.load_dictionary(dictionary=None, package_namespace="gdc")
from gdcdatamodel.models import gdc # noqa
+
ns = models.caching.get_related_case_edge_cls(gdc.AlignedReads())
class_name = "{}.{}".format(ns.__module__, ns.__name__)
assert "gdcdatamodel.models.gdc.AlignedReadsRelatesToCase" == class_name
diff --git a/test/test_gdc_postgres_admin.py b/test/test_gdc_postgres_admin.py
index e08d080e..e7899533 100644
--- a/test/test_gdc_postgres_admin.py
+++ b/test/test_gdc_postgres_admin.py
@@ -16,58 +16,60 @@
class TestGDCPostgresAdmin(unittest.TestCase):
- logger = logging.getLogger('TestGDCPostgresAdmin')
+ logger = logging.getLogger("TestGDCPostgresAdmin")
logger.setLevel(logging.INFO)
- host = 'localhost'
- user = 'postgres'
- database = 'automated_test'
+ host = "localhost"
+ user = "postgres"
+ database = "automated_test"
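+    # connection flags shared by every admin-script invocation in this class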
base_args = [
- '-H', host,
- '-U', user,
- '-D', database,
+ "-H",
+ host,
+ "-U",
+ user,
+ "-D",
+ database,
]
- g = PsqlGraphDriver(host, user, '', database)
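+    # blank password: assumes the local postgres role allows trust authentication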
+ g = PsqlGraphDriver(host, user, "", database)
root_con_str = "postgres://{user}:{pwd}@{host}/{db}".format(
- user=user, host=host, pwd='', db=database)
+ user=user, host=host, pwd="", db=database
+ )
engine = pgadmin.create_engine(root_con_str)
@classmethod
def tearDownClass(cls):
- """Recreate the database for tests that follow.
-
- """
+ """Recreate the database for tests that follow."""
cls.create_all_tables()
# Re-grant permissions to test user
for scls in Node.get_subclasses() + Edge.get_subclasses():
- statment = ("GRANT ALL PRIVILEGES ON TABLE {} TO test"
- .format(scls.__tablename__))
- cls.engine.execute('BEGIN; %s; COMMIT;' % statment)
+            statement = "GRANT ALL PRIVILEGES ON TABLE {} TO test".format(
+                scls.__tablename__
+            )
+            cls.engine.execute("BEGIN; %s; COMMIT;" % statement)
@classmethod
def drop_all_tables(cls):
for scls in Node.get_subclasses():
try:
- cls.engine.execute("DROP TABLE {} CASCADE"
- .format(scls.__tablename__))
+ cls.engine.execute("DROP TABLE {} CASCADE".format(scls.__tablename__))
except Exception as e:
cls.logger.warning(e)
@classmethod
def create_all_tables(cls):
parser = pgadmin.get_parser()
- args = parser.parse_args([
- 'graph-create', '--delay', '1', '--retries', '0'
- ] + cls.base_args)
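+        # delay/retries tune graph-create's retry loop; 0 retries means fail fast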
+ args = parser.parse_args(
+ ["graph-create", "--delay", "1", "--retries", "0"] + cls.base_args
+ )
pgadmin.main(args)
@classmethod
def drop_a_table(cls):
- cls.engine.execute('DROP TABLE edge_clinicaldescribescase')
- cls.engine.execute('DROP TABLE node_clinical')
+ cls.engine.execute("DROP TABLE edge_clinicaldescribescase")
+ cls.engine.execute("DROP TABLE node_clinical")
def startTestRun(self):
self.drop_all_tables()
@@ -77,25 +79,29 @@ def setUp(self):
def test_args(self):
parser = pgadmin.get_parser()
- parser.parse_args(['graph-create'] + self.base_args)
+ parser.parse_args(["graph-create"] + self.base_args)
def test_create_single(self):
"""Test simple table creation"""
- pgadmin.main(pgadmin.get_parser().parse_args([
- 'graph-create', '--delay', '1', '--retries', '0'
- ] + self.base_args))
+ pgadmin.main(
+ pgadmin.get_parser().parse_args(
+ ["graph-create", "--delay", "1", "--retries", "0"] + self.base_args
+ )
+ )
- self.engine.execute('SELECT * from node_case')
+ self.engine.execute("SELECT * from node_case")
def test_create_double(self):
"""Test idempotency of table creation"""
- pgadmin.main(pgadmin.get_parser().parse_args([
- 'graph-create', '--delay', '1', '--retries', '0'
- ] + self.base_args))
+ pgadmin.main(
+ pgadmin.get_parser().parse_args(
+ ["graph-create", "--delay", "1", "--retries", "0"] + self.base_args
+ )
+ )
- self.engine.execute('SELECT * from node_case')
+ self.engine.execute("SELECT * from node_case")
def test_priv_grant_read(self):
"""Test ability to grant read but not write privs"""
@@ -106,23 +112,29 @@ def test_priv_grant_read(self):
self.engine.execute("CREATE USER pytest WITH PASSWORD 'pyt3st'")
self.engine.execute("GRANT USAGE ON SCHEMA public TO pytest")
- g = PsqlGraphDriver(self.host, 'pytest', 'pyt3st', self.database)
+ g = PsqlGraphDriver(self.host, "pytest", "pyt3st", self.database)
             #: If this fails, this test (not the code) is wrong!
with self.assertRaises(ProgrammingError):
with g.session_scope():
g.nodes().count()
- pgadmin.main(pgadmin.get_parser().parse_args([
- 'graph-grant', '--read=pytest',
- ] + self.base_args))
+ pgadmin.main(
+ pgadmin.get_parser().parse_args(
+ [
+ "graph-grant",
+ "--read=pytest",
+ ]
+ + self.base_args
+ )
+ )
with g.session_scope():
g.nodes().count()
with self.assertRaises(ProgrammingError):
with g.session_scope() as s:
- s.merge(models.Case('1'))
+ s.merge(models.Case("1"))
finally:
self.engine.execute("DROP OWNED BY pytest; DROP USER pytest")
@@ -136,14 +148,20 @@ def test_priv_grant_write(self):
self.engine.execute("CREATE USER pytest WITH PASSWORD 'pyt3st'")
self.engine.execute("GRANT USAGE ON SCHEMA public TO pytest")
- g = PsqlGraphDriver(self.host, 'pytest', 'pyt3st', self.database)
- pgadmin.main(pgadmin.get_parser().parse_args([
- 'graph-grant', '--write=pytest',
- ] + self.base_args))
+ g = PsqlGraphDriver(self.host, "pytest", "pyt3st", self.database)
+ pgadmin.main(
+ pgadmin.get_parser().parse_args(
+ [
+ "graph-grant",
+ "--write=pytest",
+ ]
+ + self.base_args
+ )
+ )
with g.session_scope() as s:
g.nodes().count()
- s.merge(models.Case('1'))
+ s.merge(models.Case("1"))
finally:
self.engine.execute("DROP OWNED BY pytest; DROP USER pytest")
@@ -157,20 +175,32 @@ def test_priv_revoke_read(self):
self.engine.execute("CREATE USER pytest WITH PASSWORD 'pyt3st'")
self.engine.execute("GRANT USAGE ON SCHEMA public TO pytest")
- g = PsqlGraphDriver(self.host, 'pytest', 'pyt3st', self.database)
-
- pgadmin.main(pgadmin.get_parser().parse_args([
- 'graph-grant', '--read=pytest',
- ] + self.base_args))
-
- pgadmin.main(pgadmin.get_parser().parse_args([
- 'graph-revoke', '--read=pytest',
- ] + self.base_args))
+ g = PsqlGraphDriver(self.host, "pytest", "pyt3st", self.database)
+
+ pgadmin.main(
+ pgadmin.get_parser().parse_args(
+ [
+ "graph-grant",
+ "--read=pytest",
+ ]
+ + self.base_args
+ )
+ )
+
+ pgadmin.main(
+ pgadmin.get_parser().parse_args(
+ [
+ "graph-revoke",
+ "--read=pytest",
+ ]
+ + self.base_args
+ )
+ )
with self.assertRaises(ProgrammingError):
with g.session_scope() as s:
g.nodes().count()
- s.merge(models.Case('1'))
+ s.merge(models.Case("1"))
finally:
self.engine.execute("DROP OWNED BY pytest; DROP USER pytest")
@@ -184,22 +214,34 @@ def test_priv_revoke_write(self):
self.engine.execute("CREATE USER pytest WITH PASSWORD 'pyt3st'")
self.engine.execute("GRANT USAGE ON SCHEMA public TO pytest")
- g = PsqlGraphDriver(self.host, 'pytest', 'pyt3st', self.database)
-
- pgadmin.main(pgadmin.get_parser().parse_args([
- 'graph-grant', '--write=pytest',
- ] + self.base_args))
-
- pgadmin.main(pgadmin.get_parser().parse_args([
- 'graph-revoke', '--write=pytest',
- ] + self.base_args))
+ g = PsqlGraphDriver(self.host, "pytest", "pyt3st", self.database)
+
+ pgadmin.main(
+ pgadmin.get_parser().parse_args(
+ [
+ "graph-grant",
+ "--write=pytest",
+ ]
+ + self.base_args
+ )
+ )
+
+ pgadmin.main(
+ pgadmin.get_parser().parse_args(
+ [
+ "graph-revoke",
+ "--write=pytest",
+ ]
+ + self.base_args
+ )
+ )
with g.session_scope() as s:
g.nodes().count()
with self.assertRaises(ProgrammingError):
with g.session_scope() as s:
- s.merge(models.Case('1'))
+ s.merge(models.Case("1"))
finally:
self.engine.execute("DROP OWNED BY pytest; DROP USER pytest")
diff --git a/test/test_indexes.py b/test/test_indexes.py
index cd0f4050..2307e0ea 100644
--- a/test/test_indexes.py
+++ b/test/test_indexes.py
@@ -8,7 +8,7 @@
def test_secondary_key_indexes(indexes):
- assert 'index_node_datasubtype_name_lower' in indexes
- assert 'index_node_analyte_project_id' in indexes
- assert 'index_4df72441_famihist_submitte_id_lower' in indexes
- assert 'transaction_logs_project_id_idx' in indexes
+ assert "index_node_datasubtype_name_lower" in indexes
+ assert "index_node_analyte_project_id" in indexes
+ assert "index_4df72441_famihist_submitte_id_lower" in indexes
+ assert "transaction_logs_project_id_idx" in indexes
diff --git a/test/test_node_tagging.py b/test/test_node_tagging.py
index 65bbb46b..64ce70df 100644
--- a/test/test_node_tagging.py
+++ b/test/test_node_tagging.py
@@ -6,16 +6,16 @@
from gdcdatamodel.models import basic, versioning # noqa
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
def bg():
"""Fixture for database driver"""
cfg = {
- 'host': 'localhost',
- 'user': 'test',
- 'password': 'test',
- 'database': 'dev_models',
- 'package_namespace': 'basic',
+ "host": "localhost",
+ "user": "test",
+ "password": "test",
+ "database": "dev_models",
+ "package_namespace": "basic",
}
g = PsqlGraphDriver(**cfg)
@@ -31,7 +31,10 @@ def create_samples(sample_data, bg):
version_2s = []
for node in sample_data:
# delay adding version 2
- if node.node_id in ["a2b2d27a-6523-4ddd-8b2e-e94437a2aa23", "5ffb4b0e-969e-4643-8187-536ce7130e9c"]:
+ if node.node_id in [
+ "a2b2d27a-6523-4ddd-8b2e-e94437a2aa23",
+ "5ffb4b0e-969e-4643-8187-536ce7130e9c",
+ ]:
version_2s.append(node)
continue
s.add(node)
@@ -45,16 +48,51 @@ def create_samples(sample_data, bg):
bg.node_delete(n.node_id)
-@pytest.mark.parametrize("node_id, tag, version", [
- ("be66197b-f6cc-4366-bded-365856ec4f63", "84044bd2-54a4-5837-b83d-f920eb97c18d", 1),
- ("a2b2d27a-6523-4ddd-8b2e-e94437a2aa23", "84044bd2-54a4-5837-b83d-f920eb97c18d", 2),
- ("813f97c4-dffc-4f94-b3f6-66a93476a233", "9a81bbad-b525-568c-b85d-d269a8bdc70a", 1),
- ("6974c692-be47-4cb8-b8d6-9bd815983cd9", "55814b2f-fc23-5bed-9eab-c73c52c105df", 1),
- ("5ffb4b0e-969e-4643-8187-536ce7130e9c", "55814b2f-fc23-5bed-9eab-c73c52c105df", 2),
- ("c6a795f6-ee4a-4fcd-bfed-79348e07cd49", "8cc95392-5861-5524-8b98-a85e18d8294c", 1),
- ("ed9aa864-1e40-4657-9378-7e3dc26551cc", "fddc5826-8853-5c1a-847d-5850d58ccb3e", 1),
- ("fb69d25b-5c5d-4879-8955-8f2126e57524", "293d5dd3-117c-5a0a-8030-a428fdf2681b", 1),
-])
+@pytest.mark.parametrize(
+ "node_id, tag, version",
+ [
+ (
+ "be66197b-f6cc-4366-bded-365856ec4f63",
+ "84044bd2-54a4-5837-b83d-f920eb97c18d",
+ 1,
+ ),
+ (
+ "a2b2d27a-6523-4ddd-8b2e-e94437a2aa23",
+ "84044bd2-54a4-5837-b83d-f920eb97c18d",
+ 2,
+ ),
+ (
+ "813f97c4-dffc-4f94-b3f6-66a93476a233",
+ "9a81bbad-b525-568c-b85d-d269a8bdc70a",
+ 1,
+ ),
+ (
+ "6974c692-be47-4cb8-b8d6-9bd815983cd9",
+ "55814b2f-fc23-5bed-9eab-c73c52c105df",
+ 1,
+ ),
+ (
+ "5ffb4b0e-969e-4643-8187-536ce7130e9c",
+ "55814b2f-fc23-5bed-9eab-c73c52c105df",
+ 2,
+ ),
+ (
+ "c6a795f6-ee4a-4fcd-bfed-79348e07cd49",
+ "8cc95392-5861-5524-8b98-a85e18d8294c",
+ 1,
+ ),
+ (
+ "ed9aa864-1e40-4657-9378-7e3dc26551cc",
+ "fddc5826-8853-5c1a-847d-5850d58ccb3e",
+ 1,
+ ),
+ (
+ "fb69d25b-5c5d-4879-8955-8f2126e57524",
+ "293d5dd3-117c-5a0a-8030-a428fdf2681b",
+ 1,
+ ),
+ ],
+)
def test_1(create_samples, bg, node_id, tag, version):
with bg.session_scope():
diff --git a/test/test_update_case_cache.py b/test/test_update_case_cache.py
index 87170de0..0c681e30 100644
--- a/test/test_update_case_cache.py
+++ b/test/test_update_case_cache.py
@@ -16,22 +16,22 @@
def case_tree(g):
"""Create tree to test cache cache on"""
- case = md.Case('case')
+ case = md.Case("case")
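+    # hierarchy: case -> samples -> portions -> analytes -> aliquots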
case.samples = [
- md.Sample('sample1'),
- md.Sample('sample2'),
+ md.Sample("sample1"),
+ md.Sample("sample2"),
]
case.samples[0].portions = [
- md.Portion('portion1'),
- md.Portion('portion2'),
+ md.Portion("portion1"),
+ md.Portion("portion2"),
]
case.samples[0].portions[0].analytes = [
- md.Analyte('analyte1'),
- md.Analyte('analyte2'),
+ md.Analyte("analyte1"),
+ md.Analyte("analyte2"),
]
case.samples[0].portions[0].analytes[0].aliquots = [
- md.Aliquot('aliquot1'),
- md.Aliquot('alituoq2'),
+ md.Aliquot("aliquot1"),
+ md.Aliquot("alituoq2"),
]
with g.session_scope() as session:
session.merge(case)
@@ -58,7 +58,7 @@ def test_update_case_cache(g, case_tree_no_cache):
with g.session_scope():
nodes = g.nodes().all()
- nodes = (node for node in nodes if hasattr(node, '_related_cases'))
+ nodes = (node for node in nodes if hasattr(node, "_related_cases"))
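+        # only node types that maintain a _related_cases cache are asserted on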
for node in nodes:
assert node._related_cases
diff --git a/test/test_validators.py b/test/test_validators.py
index 6a4c5d79..fd760138 100644
--- a/test/test_validators.py
+++ b/test/test_validators.py
@@ -25,44 +25,59 @@ def setUp(self):
self.entities = [MockSubmissionEntity()]
def test_json_validator_with_insufficient_properties(self):
- self.entities[0].doc = {'type': 'aliquot',
- 'centers': {'submitter_id': 'test'}}
+ self.entities[0].doc = {"type": "aliquot", "centers": {"submitter_id": "test"}}
self.json_validator.record_errors(self.entities)
- self.assertEqual(self.entities[0].errors[0]['keys'], ['submitter_id'])
+ self.assertEqual(self.entities[0].errors[0]["keys"], ["submitter_id"])
self.assertEqual(1, len(self.entities[0].errors))
def test_json_validator_with_wrong_node_type(self):
- self.entities[0].doc = {'type': 'aliquo'}
+ self.entities[0].doc = {"type": "aliquo"}
self.json_validator.record_errors(self.entities)
- self.assertEqual(self.entities[0].errors[0]['keys'], ['type'])
+ self.assertEqual(self.entities[0].errors[0]["keys"], ["type"])
self.assertEqual(1, len(self.entities[0].errors))
def test_json_validator_with_wrong_property_type(self):
- self.entities[0].doc = {'type': 'aliquot',
- 'submitter_id': 1, 'centers': {'submitter_id': 'test'}}
+ self.entities[0].doc = {
+ "type": "aliquot",
+ "submitter_id": 1,
+ "centers": {"submitter_id": "test"},
+ }
self.json_validator.record_errors(self.entities)
- self.assertEqual(['submitter_id'], self.entities[0].errors[0]['keys'])
+ self.assertEqual(["submitter_id"], self.entities[0].errors[0]["keys"])
self.assertEqual(1, len(self.entities[0].errors))
def test_json_validator_with_multiple_errors(self):
- self.entities[0].doc = {'type': 'aliquot', 'submitter_id': 1,
- 'test': 'test',
- 'centers': {'submitter_id': 'test'}}
+ self.entities[0].doc = {
+ "type": "aliquot",
+ "submitter_id": 1,
+ "test": "test",
+ "centers": {"submitter_id": "test"},
+ }
self.json_validator.record_errors(self.entities)
self.assertEqual(2, len(self.entities[0].errors))
def test_json_validator_with_nested_error_keys(self):
- self.entities[0].doc = {'type': 'aliquot', 'submitter_id': 'test',
- 'centers': {'submitter_id': True}}
+ self.entities[0].doc = {
+ "type": "aliquot",
+ "submitter_id": "test",
+ "centers": {"submitter_id": True},
+ }
self.json_validator.record_errors(self.entities)
- self.assertEqual(['centers'], self.entities[0].errors[0]['keys'])
+ self.assertEqual(["centers"], self.entities[0].errors[0]["keys"])
def test_json_validator_with_multiple_entities(self):
- self.entities[0].doc = {'type': 'aliquot', 'submitter_id': 1, 'test': 'test',
- 'centers': {'submitter_id': 'test'}}
+ self.entities[0].doc = {
+ "type": "aliquot",
+ "submitter_id": 1,
+ "test": "test",
+ "centers": {"submitter_id": "test"},
+ }
entity = MockSubmissionEntity()
- entity.doc = {'type': 'aliquot', 'submitter_id': 'test',
- 'centers': {'submitter_id': 'test'}}
+ entity.doc = {
+ "type": "aliquot",
+ "submitter_id": "test",
+ "centers": {"submitter_id": "test"},
+ }
self.entities.append(entity)
self.json_validator.record_errors(self.entities)
@@ -71,10 +86,10 @@ def test_json_validator_with_multiple_entities(self):
def test_json_validator_with_array_prop(self):
entity_doc = {
- 'type': 'diagnosis',
- 'submitter_id': 'test',
- 'age_at_diagnosis': 10,
- 'primary_diagnosis': 'Abdominal desmoid',
+ "type": "diagnosis",
+ "submitter_id": "test",
+ "age_at_diagnosis": 10,
+ "primary_diagnosis": "Abdominal desmoid",
"morphology": "8000/0",
"tissue_or_organ_of_origin": "Abdomen, NOS",
"site_of_resection_or_biopsy": "Abdomen, NOS",
@@ -107,10 +122,10 @@ def mock_doc(sites_of_involvement):
self.assertIn("diagnosis_is_primary_disease", error_keys_one)
def create_node(self, doc, session):
- cls = Node.get_subclass(doc['type'])
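+        # doc layout: {"type": label, "props": {...}, "edges": {link_name: [target node ids]}}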
+ cls = Node.get_subclass(doc["type"])
node = cls(str(uuid.uuid4()))
- node.props = doc['props']
- for key, value in doc['edges'].items():
+ node.props = doc["props"]
+ for key, value in doc["edges"].items():
for target_id in value:
edge = self.g.nodes().ids(target_id).first()
node[key].append(edge)
@@ -122,164 +137,264 @@ def update_schema(self, entity, key, schema):
def test_graph_validator_without_required_link(self):
with self.g.session_scope() as session:
- node = self.create_node({'type': 'aliquot',
- 'props': {'submitter_id': 'test'},
- 'edges': {}}, session)
+ node = self.create_node(
+ {"type": "aliquot", "props": {"submitter_id": "test"}, "edges": {}},
+ session,
+ )
self.entities[0].node = node
self.update_schema(
- 'aliquot',
- 'links',
- [{'name': 'analytes',
- 'backref': 'aliquots',
- 'label': 'derived_from',
- 'multiplicity': 'many_to_one',
- 'target_type': 'analyte',
- 'required': True}])
+ "aliquot",
+ "links",
+ [
+ {
+ "name": "analytes",
+ "backref": "aliquots",
+ "label": "derived_from",
+ "multiplicity": "many_to_one",
+ "target_type": "analyte",
+ "required": True,
+ }
+ ],
+ )
self.graph_validator.record_errors(self.g, self.entities)
- self.assertEqual(['analytes'], self.entities[0].errors[0]['keys'])
+ self.assertEqual(["analytes"], self.entities[0].errors[0]["keys"])
def test_graph_validator_with_exclusive_link(self):
with self.g.session_scope() as session:
analyte = self.create_node(
- {'type': 'analyte',
- 'props': {'submitter_id': 'test',
- 'analyte_type_id': 'D',
- 'analyte_type': 'DNA'},
- 'edges': {}}, session)
- sample = self.create_node({'type': 'sample',
- 'props': {'submitter_id': 'test',
- 'sample_type': 'DNA',
- 'sample_type_id': '01'},
- 'edges': {}}, session)
+ {
+ "type": "analyte",
+ "props": {
+ "submitter_id": "test",
+ "analyte_type_id": "D",
+ "analyte_type": "DNA",
+ },
+ "edges": {},
+ },
+ session,
+ )
+ sample = self.create_node(
+ {
+ "type": "sample",
+ "props": {
+ "submitter_id": "test",
+ "sample_type": "DNA",
+ "sample_type_id": "01",
+ },
+ "edges": {},
+ },
+ session,
+ )
node = self.create_node(
- {'type': 'aliquot',
- 'props': {'submitter_id': 'test'},
- 'edges': {'analytes': [analyte.node_id],
- 'samples': [sample.node_id]}}, session)
+ {
+ "type": "aliquot",
+ "props": {"submitter_id": "test"},
+ "edges": {
+ "analytes": [analyte.node_id],
+ "samples": [sample.node_id],
+ },
+ },
+ session,
+ )
self.entities[0].node = node
self.update_schema(
- 'aliquot',
- 'links',
- [{'exclusive': True,
- 'required': True,
- 'subgroup': [
- {'name': 'analytes',
- 'backref': 'aliquots',
- 'label': 'derived_from',
- 'multiplicity': 'many_to_one',
- 'target_type': 'analyte'},
- {'name': 'samples',
- 'backref': 'aliquots',
- 'label': 'derived_from',
- 'multiplicity': 'many_to_one',
- 'target_type': 'sample'}]}])
+ "aliquot",
+ "links",
+ [
+ {
+ "exclusive": True,
+ "required": True,
+ "subgroup": [
+ {
+ "name": "analytes",
+ "backref": "aliquots",
+ "label": "derived_from",
+ "multiplicity": "many_to_one",
+ "target_type": "analyte",
+ },
+ {
+ "name": "samples",
+ "backref": "aliquots",
+ "label": "derived_from",
+ "multiplicity": "many_to_one",
+ "target_type": "sample",
+ },
+ ],
+ }
+ ],
+ )
self.graph_validator.record_errors(self.g, self.entities)
- self.assertEqual(['analytes', 'samples'],
- self.entities[0].errors[0]['keys'])
+ self.assertEqual(
+ ["analytes", "samples"], self.entities[0].errors[0]["keys"]
+ )
def test_graph_validator_with_wrong_multiplicity(self):
with self.g.session_scope() as session:
- analyte = self.create_node({'type': 'analyte',
- 'props': {'submitter_id': 'test',
- 'analyte_type_id': 'D',
- 'analyte_type': 'DNA'},
- 'edges': {}}, session)
-
- analyte_b = self.create_node({'type': 'analyte',
- 'props': {'submitter_id': 'testb',
- 'analyte_type_id': 'H',
- 'analyte_type': 'RNA'},
- 'edges': {}}, session)
-
- node = self.create_node({'type': 'aliquot',
- 'props': {'submitter_id': 'test'},
- 'edges': {'analytes': [analyte.node_id,
- analyte_b.node_id]}},
- session)
+ analyte = self.create_node(
+ {
+ "type": "analyte",
+ "props": {
+ "submitter_id": "test",
+ "analyte_type_id": "D",
+ "analyte_type": "DNA",
+ },
+ "edges": {},
+ },
+ session,
+ )
+
+ analyte_b = self.create_node(
+ {
+ "type": "analyte",
+ "props": {
+ "submitter_id": "testb",
+ "analyte_type_id": "H",
+ "analyte_type": "RNA",
+ },
+ "edges": {},
+ },
+ session,
+ )
+
+ node = self.create_node(
+ {
+ "type": "aliquot",
+ "props": {"submitter_id": "test"},
+ "edges": {"analytes": [analyte.node_id, analyte_b.node_id]},
+ },
+ session,
+ )
self.entities[0].node = node
self.update_schema(
- 'aliquot',
- 'links',
- [{'exclusive': False,
- 'required': True,
- 'subgroup': [
- {'name': 'analytes',
- 'backref': 'aliquots',
- 'label': 'derived_from',
- 'multiplicity': 'many_to_one',
- 'target_type': 'analyte'},
- {'name': 'samples',
- 'backref': 'aliquots',
- 'label': 'derived_from',
- 'multiplicity': 'many_to_one',
- 'target_type': 'sample'}]}])
+ "aliquot",
+ "links",
+ [
+ {
+ "exclusive": False,
+ "required": True,
+ "subgroup": [
+ {
+ "name": "analytes",
+ "backref": "aliquots",
+ "label": "derived_from",
+ "multiplicity": "many_to_one",
+ "target_type": "analyte",
+ },
+ {
+ "name": "samples",
+ "backref": "aliquots",
+ "label": "derived_from",
+ "multiplicity": "many_to_one",
+ "target_type": "sample",
+ },
+ ],
+ }
+ ],
+ )
self.graph_validator.record_errors(self.g, self.entities)
- self.assertEqual(['analytes'], self.entities[0].errors[0]['keys'])
+ self.assertEqual(["analytes"], self.entities[0].errors[0]["keys"])
def test_graph_validator_with_correct_node(self):
with self.g.session_scope() as session:
- analyte = self.create_node({'type': 'analyte',
- 'props': {'submitter_id': 'test',
- 'analyte_type_id': 'D',
- 'analyte_type': 'DNA'},
- 'edges': {}}, session)
-
- node = self.create_node({'type': 'aliquot',
- 'props': {'submitter_id': 'test'},
- 'edges': {'analytes': [analyte.node_id]}},
- session)
+ analyte = self.create_node(
+ {
+ "type": "analyte",
+ "props": {
+ "submitter_id": "test",
+ "analyte_type_id": "D",
+ "analyte_type": "DNA",
+ },
+ "edges": {},
+ },
+ session,
+ )
+
+ node = self.create_node(
+ {
+ "type": "aliquot",
+ "props": {"submitter_id": "test"},
+ "edges": {"analytes": [analyte.node_id]},
+ },
+ session,
+ )
self.entities[0].node = node
self.update_schema(
- 'aliquot',
- 'links',
- [{'exclusive': False,
- 'required': True,
- 'subgroup': [
- {'name': 'analytes',
- 'backref': 'aliquots',
- 'label': 'derived_from',
- 'multiplicity': 'many_to_one',
- 'target_type': 'analyte'},
- {'name': 'samples',
- 'backref': 'aliquots',
- 'label': 'derived_from',
- 'multiplicity': 'many_to_one',
- 'target_type': 'sample'}]}])
+ "aliquot",
+ "links",
+ [
+ {
+ "exclusive": False,
+ "required": True,
+ "subgroup": [
+ {
+ "name": "analytes",
+ "backref": "aliquots",
+ "label": "derived_from",
+ "multiplicity": "many_to_one",
+ "target_type": "analyte",
+ },
+ {
+ "name": "samples",
+ "backref": "aliquots",
+ "label": "derived_from",
+ "multiplicity": "many_to_one",
+ "target_type": "sample",
+ },
+ ],
+ }
+ ],
+ )
self.graph_validator.record_errors(self.g, self.entities)
self.assertEqual(0, len(self.entities[0].errors))
def test_graph_validator_with_existing_unique_keys(self):
with self.g.session_scope() as session:
- node = self.create_node({'type': 'data_format',
- 'props': {'name': 'test'},
- 'edges': {}},
- session)
- node = self.create_node({'type': 'data_format',
- 'props': {'name': 'test'},
- 'edges': {}},
- session)
- self.update_schema('data_format', 'uniqueKeys', [['name']])
+ node = self.create_node(
+ {"type": "data_format", "props": {"name": "test"}, "edges": {}}, session
+ )
+ node = self.create_node(
+ {"type": "data_format", "props": {"name": "test"}, "edges": {}}, session
+ )
+ self.update_schema("data_format", "uniqueKeys", [["name"]])
self.entities[0].node = node
self.graph_validator.record_errors(self.g, self.entities)
- self.assertEqual(['name'], self.entities[0].errors[0]['keys'])
+ self.assertEqual(["name"], self.entities[0].errors[0]["keys"])
def test_graph_validator_with_existing_unique_keys_for_different_node_types(self):
with self.g.session_scope() as session:
- node = self.create_node({'type': 'sample',
- 'props': {'submitter_id': 'test','project_id':'A'},
- 'edges': {}},
- session)
- node = self.create_node({'type': 'aliquot',
- 'props': {'submitter_id': 'test', 'project_id':'A'},
- 'edges': {}},
- session)
- self.update_schema('data_format', 'uniqueKeys', [['submitter_id', 'project_id']])
+ node = self.create_node(
+ {
+ "type": "sample",
+ "props": {"submitter_id": "test", "project_id": "A"},
+ "edges": {},
+ },
+ session,
+ )
+ node = self.create_node(
+ {
+ "type": "aliquot",
+ "props": {"submitter_id": "test", "project_id": "A"},
+ "edges": {},
+ },
+ session,
+ )
+ self.update_schema(
+ "data_format", "uniqueKeys", [["submitter_id", "project_id"]]
+ )
self.entities[0].node = node
self.graph_validator.record_errors(self.g, self.entities)
# Check (project_id, submitter_id) uniqueness is captured
- self.assertTrue(any({'project_id', 'submitter_id'} == set(e['keys'])
- for e in self.entities[0].errors))
+ self.assertTrue(
+ any(
+ {"project_id", "submitter_id"} == set(e["keys"])
+ for e in self.entities[0].errors
+ )
+ )
# Check that missing edges is captured
- self.assertTrue(any({'analytes', 'samples'} == set(e['keys'])
- for e in self.entities[0].errors))
+ self.assertTrue(
+ any(
+ {"analytes", "samples"} == set(e["keys"])
+ for e in self.entities[0].errors
+ )
+ )
diff --git a/test/test_versioned_nodes.py b/test/test_versioned_nodes.py
index 82f82856..d0256e38 100644
--- a/test/test_versioned_nodes.py
+++ b/test/test_versioned_nodes.py
@@ -4,32 +4,35 @@
class TestValidators(BaseTestCase):
-
@staticmethod
def new_portion():
- portion = md.Portion(**{
- 'node_id': 'case1',
- 'is_ffpe': False,
- 'portion_number': '01',
- 'project_id': 'CGCI-BLGSP',
- 'state': 'validated',
- 'submitter_id': 'PORTION-1',
- 'weight': 54.0
- })
- portion.acl = ['acl1']
- portion.sysan.update({'key1': 'val1'})
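+        # acl and system annotations are versioned along with properties (asserted in test_round_trip)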
+ portion = md.Portion(
+ **{
+ "node_id": "case1",
+ "is_ffpe": False,
+ "portion_number": "01",
+ "project_id": "CGCI-BLGSP",
+ "state": "validated",
+ "submitter_id": "PORTION-1",
+ "weight": 54.0,
+ }
+ )
+ portion.acl = ["acl1"]
+ portion.sysan.update({"key1": "val1"})
return portion
@staticmethod
def new_analyte():
- return md.Analyte(**{
- 'node_id': 'analyte1',
- 'analyte_type': 'Repli-G (Qiagen) DNA',
- 'analyte_type_id': 'W',
- 'project_id': 'CGCI-BLGSP',
- 'state': 'validated',
- 'submitter_id': 'TCGA-AR-A1AR-01A-31W',
- })
+ return md.Analyte(
+ **{
+ "node_id": "analyte1",
+ "analyte_type": "Repli-G (Qiagen) DNA",
+ "analyte_type_id": "W",
+ "project_id": "CGCI-BLGSP",
+ "state": "validated",
+ "submitter_id": "TCGA-AR-A1AR-01A-31W",
+ }
+ )
def test_round_trip(self):
with self.g.session_scope() as session:
@@ -46,12 +49,12 @@ def test_round_trip(self):
with self.g.session_scope():
v_node = self.g.nodes(md.VersionedNode).one()
- self.assertEqual(v_node.properties['is_ffpe'], False)
- self.assertEqual(v_node.properties['state'], 'validated')
- self.assertEqual(v_node.properties['state'], 'validated')
- self.assertEqual(v_node.system_annotations, {'key1': 'val1'})
- self.assertEqual(v_node.acl, ['acl1'])
- self.assertEqual(v_node.neighbors, ['analyte1'])
+ self.assertEqual(v_node.properties["is_ffpe"], False)
+ self.assertEqual(v_node.properties["state"], "validated")
+ self.assertEqual(v_node.system_annotations, {"key1": "val1"})
+ self.assertEqual(v_node.acl, ["acl1"])
+ self.assertEqual(v_node.neighbors, ["analyte1"])
self.assertIsNotNone(v_node.versioned)
self.assertIsNotNone(v_node.key)