Skip to content

Commit

Permalink
Parallel validation (#739)
Browse files Browse the repository at this point in the history
* Parallel validation

* error kind

* better formatting
  • Loading branch information
skearnes authored Jul 16, 2024
1 parent f5cfc5b commit bc14c24
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
31 changes: 25 additions & 6 deletions ord_schema/scripts/validate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@
"""Validates a set of Dataset protocol buffers.
Usage:
validate.py --input=<str> [--filter=<str> --n_jobs=<int>]
Options:
--input=<str> Input pattern for Dataset protos
--filter=<str> Regex filename filter
--n_jobs=<int> Number of parallel workers [default: 1]
"""
import glob
import re
from collections.abc import Iterable
from concurrent.futures import ProcessPoolExecutor, as_completed

import docopt
from rdkit import RDLogger
from tqdm import tqdm

from ord_schema import message_helpers, validations
from ord_schema.logging import get_logger
Expand All @@ -43,18 +46,34 @@ def filter_filenames(filenames: Iterable[str], pattern: str) -> list[str]:
return filtered_filenames


def run(filename: str) -> None:
    """Loads one Dataset proto from disk and validates it.

    Designed to run inside a worker process; see main().

    Args:
        filename: Path to a serialized Dataset proto.

    Raises:
        validations.ValidationError: If the dataset fails validation.
    """
    # Each worker process must disable RDKit logging separately.
    RDLogger.DisableLog("rdApp.*")
    loaded = message_helpers.load_message(filename, dataset_pb2.Dataset)
    validations.validate_datasets({filename: loaded})


def main(kwargs):
    """Validates every dataset matching the input pattern, in parallel.

    Args:
        kwargs: Parsed docopt command-line arguments; uses --input,
            --filter, and --n_jobs.

    Raises:
        validations.ValidationError: If any dataset fails validation; the
            message lists each failing filename with its error.
    """
    filenames = sorted(glob.glob(kwargs["--input"], recursive=True))
    logger.info("Found %d datasets", len(filenames))
    if kwargs["--filter"]:
        filenames = filter_filenames(filenames, kwargs["--filter"])
        logger.info("Filtered to %d datasets", len(filenames))
    failures = []
    with ProcessPoolExecutor(int(kwargs["--n_jobs"])) as executor:
        # Map each future back to its filename for error reporting.
        futures = {executor.submit(run, filename=filename): filename for filename in filenames}
        for future in tqdm(as_completed(futures), total=len(futures)):
            try:
                future.result()
            except validations.ValidationError as error:
                failures.append(f"{futures[future]}: {error}")
    if failures:
        text = "\n".join(failures)
        raise validations.ValidationError(f"Dataset(s) failed validation:\n{text}")


# Script entry point: options are parsed from the module docstring by docopt.
if __name__ == "__main__":
    RDLogger.DisableLog("rdApp.*")  # Disable RDKit logging in the main process; workers do this in run().
    main(docopt.docopt(__doc__))
9 changes: 2 additions & 7 deletions ord_schema/validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,12 @@ def _validate_datasets(
num_bad_reactions += 1
for error in reaction_output.errors:
errors.append(error)
logger.warning(f"Validation error for {label}[{i}]: {error}")
num_successful = (len(dataset.reactions) - num_bad_reactions,)
logger.info(
f"Validation summary for {label}: {num_successful}/{len(dataset.reactions)} successful "
f"({num_bad_reactions} failures)"
)
logger.error(f"Validation error for {label}[{i}]: {error}")
# Dataset-level validation of cross-references.
dataset_output = validate_message(dataset, raise_on_error=False, recurse=False, options=options)
for error in dataset_output.errors:
errors.append(error)
logger.warning(f"Validation error for {label}: {error}")
logger.error(f"Validation error for {label}: {error}")

return errors

Expand Down

0 comments on commit bc14c24

Please sign in to comment.