Add --dry-run command line parameter.
This commit fixes issue kbase#56 by adding a --dry-run parameter that validates
the input data and prints an output summary even in the presence of errors. To
support this, the --output parameter is also introduced, giving the option to
display the summary as JSON or in a more user-friendly text format. Data is not
loaded if there are any errors or if the command is invoked with --dry-run.
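
For reference, a typical invocation with the new options might look like the
following; the data directory path is a placeholder, and --output defaults to
text:

    RES_ROOT_DATA_PATH=/path/to/data/dir python -m importers.djornl.parser --dry-run --output json
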
dakotablair committed Nov 19, 2020
1 parent afd0a65 commit 53eb379
Showing 3 changed files with 129 additions and 43 deletions.
158 changes: 122 additions & 36 deletions importers/djornl/parser.py
@@ -15,10 +15,11 @@
RES_ROOT_DATA_PATH=/path/to/data/dir python -m importers.djornl.parser
"""
import argparse
import csv
import json
import requests
import os
import csv
import requests
import yaml

import importers.utils.config as config
@@ -48,7 +49,9 @@ def config(self, value):

def _configure(self):

configuration = config.load_from_env(extra_required=['ROOT_DATA_PATH'])
configuration = config.load_from_env(
extra_required=['ROOT_DATA_PATH']
)

# Collection name config
configuration['node_name'] = 'djornl_node'
@@ -89,7 +92,9 @@ def _get_dataset_schema_dir(self):

if not hasattr(self, '_dataset_schema_dir'):
dir_path = os.path.dirname(os.path.realpath(__file__))
self._dataset_schema_dir = os.path.join(dir_path, '../', '../', 'spec', 'datasets', 'djornl')
self._dataset_schema_dir = os.path.join(
dir_path, '../', '../', 'spec', 'datasets', 'djornl'
)

return self._dataset_schema_dir

@@ -131,7 +136,7 @@ def _get_file_reader(self, fd, file):
'''Given a dict containing file information, instantiate the correct type of parser'''

delimiter = '\t'
if 'file_format' in file and file['file_format'].lower() == 'csv' or file['path'].lower().endswith('.csv'):
if file.get('file_format', '').lower() == 'csv' or file['path'].lower().endswith('.csv'):
delimiter = ','
return csv.reader(fd, delimiter=delimiter)
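
As a quick sketch of the delimiter selection above (the manifest entries here
are made-up examples, not files from the real dataset):

    # Hypothetical manifest entries; only 'file_format' and 'path' are consulted.
    tsv_entry = {'path': 'edges/edge_data.tsv'}
    csv_entry = {'path': 'clusters/markov_clusters.csv', 'file_format': 'CSV'}

    def pick_delimiter(file):
        # Comma-delimited when the manifest says 'csv' or the path ends in '.csv';
        # otherwise the reader falls back to tab-delimited input.
        if file.get('file_format', '').lower() == 'csv' or file['path'].lower().endswith('.csv'):
            return ','
        return '\t'

    assert pick_delimiter(tsv_entry) == '\t'
    assert pick_delimiter(csv_entry) == ','
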

@@ -176,8 +181,8 @@ def check_headers(self, headers, validator=None):
:return header_errs: (dict) dict of header errors:
'missing': required headers that are missing from the input
'invalid': additional headers that should not be in the input
'duplicate': duplicated headers (content would be overwritten)
'invalid': headers that should not be in the input
'duplicate': duplicated headers (data would be overwritten)
If the list of headers supplied is valid--i.e. it
contains all the fields marked as required in the validator
schema--or no validator has been supplied, the method
@@ -207,7 +212,7 @@ def check_headers(self, headers, validator=None):
if missing_headers:
header_errs['missing'] = missing_headers

if 'additionalProperties' in validator.schema and validator.schema['additionalProperties'] is False:
if not validator.schema.get('additionalProperties', True):
all_props = validator.schema['properties'].keys()
extra_headers = [i for i in headers if i not in all_props]
if extra_headers:
@@ -268,11 +273,16 @@ def process_file(self, file, remap_fn, store_fn, err_list, validator=None):
"""
print("Parsing " + file['data_type'] + " file " + file['file_path'])
file_parser = self.parser_gen(file)

def add_error(error):
print(error)
err_list.append(error)

try:
(line_no, cols, err_str) = next(file_parser)
except StopIteration:
# no valid lines found in the file
err_list.append(f"{file['path']}: no header line found")
add_error(f"{file['path']}: no header line found")
return

header_errors = self.check_headers(cols, validator)
@@ -284,7 +294,7 @@ def process_file(self, file, remap_fn, store_fn, err_list, validator=None):
}
for err_type in ['missing', 'invalid', 'duplicate']:
if err_type in header_errors:
err_list.append(
add_error(
f"{file['path']}: {err_str[err_type]} headers: "
+ ", ".join(sorted(header_errors[err_type]))
)
@@ -295,7 +305,7 @@ def process_file(self, file, remap_fn, store_fn, err_list, validator=None):
for (line_no, cols, err_str) in file_parser:
# mismatch in number of cols
if cols is None:
err_list.append(err_str)
add_error(err_str)
continue

# merge headers with cols to create an object
@@ -308,15 +318,15 @@ def process_file(self, file, remap_fn, store_fn, err_list, validator=None):
f"{file['path']} line {line_no}: " + e.message
for e in sorted(validator.iter_errors(row_object), key=str)
)
err_list.append(err_msg)
add_error(err_msg)
continue

try:
# transform it using the remap_functions
datum = self.remap_object(row_object, remap_fn)
except Exception as err:
err_type = type(err)
err_list.append(
add_error(
f"{file['path']} line {line_no}: error remapping data: {err_type} {err}"
)
continue
@@ -326,16 +336,16 @@ def process_file(self, file, remap_fn, store_fn, err_list, validator=None):
if storage_error is None:
n_stored += 1
else:
err_list.append(f"{file['path']} line {line_no}: " + storage_error)
add_error(f"{file['path']} line {line_no}: " + storage_error)

if not n_stored:
err_list.append(f"{file['path']}: no valid data found")
add_error(f"{file['path']}: no valid data found")

def store_parsed_edge_data(self, datum):
"""
store node and edge data in the node (node_ix) and edge (edge_ix) indexes respectively
Nodes are indexed by the '_key' attribute. Parsed edge data only contains node '_key' values.
Nodes are indexed by the '_key' attribute.
Parsed edge data only contains node '_key' values.
Edges are indexed by the unique combination of the two node IDs and the edge type. It is
assumed that if there is more than one score for a given combination of node IDs and edge
@@ -380,7 +390,9 @@ def load_edges(self):
# can do so because that key is in a 'required' property in the CSV spec file
remap_functions = {
# create a unique key for each record
'_key': lambda row: '__'.join([row[_] for _ in ['node1', 'node2', 'edge_type', 'score']]),
'_key': lambda row: '__'.join(
[row[_] for _ in ['node1', 'node2', 'edge_type', 'score']]
),
'node1': None, # this will be deleted in the 'store' step
'node2': None, # as will this
'_from': lambda row: node_name + '/' + row['node1'],
@@ -399,8 +411,8 @@ def load_edges(self):
)

return {
'nodes': self.node_ix.values(),
'edges': self.edge_ix.values(),
'nodes': list(self.node_ix.values()),
'edges': list(self.edge_ix.values()),
'err_list': err_list,
}

@@ -431,7 +443,10 @@ def _try_node_merge(self, existing_node, new_node, path=[]):
merge = {**existing_node, **new_node}

# find the shared keys -- keys in both existing and new nodes where the values differ
shared_keys = [i for i in new_node if i in existing_node and new_node[i] != existing_node[i]]
shared_keys = [
i for i in new_node
if i in existing_node and new_node[i] != existing_node[i]
]

# if there were no shared keys, return the merged list
if not shared_keys:
@@ -589,7 +604,9 @@ def load_clusters(self):

for file in self.config('cluster_files'):
prefix = file['cluster_prefix']
remap_functions['cluster_id'] = lambda row: prefix + ':' + row['cluster_id'].replace('Cluster', '')
remap_functions['cluster_id'] = (
lambda row: prefix + ':' + row['cluster_id'].replace('Cluster', '')
)

self.process_file(
file=file,
@@ -646,22 +663,18 @@ def load_data(self, dry_run=False):
if output['err_list']:
all_errs = all_errs + output['err_list']

if all_errs:
raise RuntimeError("\n".join(all_errs))

if dry_run:
# report stats on the data that has been gathered
return self.summarise_dataset()
# if there are no errors then save the dataset unless this is a dry run
if len(all_errs) == 0 and not dry_run:
self.save_dataset()

# otherwise, save the dataset
self.save_dataset()
return True
# report stats on the data that has been gathered
return self.summarise_dataset(all_errs)

def summarise_dataset(self):
def summarise_dataset(self, errs):
"""summarise the data that has been loaded"""

# go through the node index, checking for nodes that only have one attribute ('_key') or
# were loaded from the clusters files, with their only attributes being '_key' and 'clusters'
# were loaded from the clusters files, with only '_key' and 'clusters' attributes

node_type_ix = {
'__NO_TYPE__': 0
@@ -709,13 +722,86 @@ def summarise_dataset(self):
'cluster': len(node_data['cluster']),
'full': len(node_data['full'])
},
'errors_total': len(errs),
'errors': errs
}


if __name__ == '__main__':
def format_summary(summary, output):
if output == 'json':
return json.dumps(summary)
node_type_counts = list(summary['node_type_count'].values())
edge_type_counts = list(summary['edge_type_count'].values())
values = [
summary['nodes_total'],
summary['edges_total'],
summary['nodes_in_edge'],
summary['node_data_available']['key_only'],
summary['node_data_available']['cluster'],
summary['node_data_available']['full'],
summary.get('errors_total'),
] + node_type_counts + edge_type_counts
value_width = max([len(str(value)) for value in values])
node_type_names = dict(__NO_TYPE__="No type")
node_types = "\n".join([(
f"{count:{value_width}} {node_type_names.get(ntype, ntype)}"
.format(value_width)
)
for ntype, count in summary['node_type_count'].items()
])
edge_type_names = dict()
edge_types = "\n".join([(
f"{count:{value_width}} {edge_type_names.get(etype, etype)}"
.format(value_width)
)
for etype, count in summary['edge_type_count'].items()
])
text_summary = f"""
{summary['nodes_total']:{value_width}} Total nodes
{summary['edges_total']:{value_width}} Total edges
{summary['nodes_in_edge']:{value_width}} Nodes in edge
---
Node Types
{node_types:{value_width}}
---
Edge Types
{edge_types:{value_width}}
---
Node data available
{summary['node_data_available']['key_only']:{value_width}} Key only
{summary['node_data_available']['cluster']:{value_width}} Cluster
{summary['node_data_available']['full']:{value_width}} Full
---
{summary.get('errors_total'):{value_width}} Errors
""".format(value_width)
return text_summary


def main():
argparser = argparse.ArgumentParser(description='Load DJORNL data')
argparser.add_argument(
'--dry-run', dest='dry', action='store_true',
help='Perform all actions of the parser, except loading the data.'
)
argparser.add_argument(
'--output', default='text',
help='Specify the format of any output generated. (text or json)'
)
args = argparser.parse_args()
parser = DJORNL_Parser()
summary = dict()
try:
parser.load_data()
summary = parser.load_data(dry_run=args.dry)
except Exception as err:
print(err)
print('Unhandled exception', err)
exit(1)
errors = summary.get('errors')
if summary:
print(format_summary(summary, args.output))
if errors:
error_output = f'Aborted with {len(errors)} errors.\n'
raise RuntimeError(error_output)


if __name__ == '__main__':
main()
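
For context, the summary dict returned by load_data(dry_run=True) and consumed
by format_summary has roughly the shape sketched below. The keys mirror
summarise_dataset; the counts and the edge type name are illustrative values
rather than real dataset figures:

    summary = {
        'nodes_total': 14,
        'edges_total': 18,
        'nodes_in_edge': 10,
        'node_type_count': {'__NO_TYPE__': 0, 'gene': 10, 'pheno': 4},
        'edge_type_count': {'protein-protein-interaction': 18},  # hypothetical edge type
        'node_data_available': {'key_only': 0, 'cluster': 0, 'full': 14},
        'errors_total': 0,
        'errors': [],
    }
    print(format_summary(summary, 'json'))  # or 'text' for the aligned plain-text report
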
12 changes: 6 additions & 6 deletions importers/test/test_djornl_parser.py
@@ -56,11 +56,9 @@ def test_errors(self, parser=None, errs={}):

with self.subTest(data_type="all types"):
# test all errors
with self.assertRaisesRegex(RuntimeError, all_errs[0]) as cm:
parser.load_data()
exception = cm.exception
err_list = exception.split("\n")
self.assertEqual(err_list, all_errs)
summary = parser.load_data(dry_run=True)
err_list = summary['errors']
self.assertEqual(err_list, all_errs)

def test_missing_required_env_var(self):
'''test that the parser exits with code 1 if the RES_ROOT_DATA_PATH env var is not set'''
@@ -303,7 +301,9 @@ def test_dry_run(self):
'node_data_available': {'cluster': 0, 'full': 14, 'key_only': 0},
'node_type_count': {'__NO_TYPE__': 0, 'gene': 10, 'pheno': 4},
'nodes_in_edge': 10,
'nodes_total': 14
'nodes_total': 14,
'errors_total': 0,
'errors': []
},
output
)
2 changes: 1 addition & 1 deletion importers/test/test_djornl_parser_integration.py
@@ -24,4 +24,4 @@ def test_the_full_shebang(self):
with modified_environ(RES_ROOT_DATA_PATH=os.path.join(_TEST_DIR, 'djornl', 'test_data')):
parser = DJORNL_Parser()
parser.load_data()
self.assertEqual(True, parser.load_data())
self.assertTrue(bool(parser.load_data()))
