From bc14c24a30c786183f544acdf0dd3f457b5c3312 Mon Sep 17 00:00:00 2001 From: Steven Kearnes Date: Tue, 16 Jul 2024 10:36:45 -0400 Subject: [PATCH] Parallel validation (#739) * Parallel validation * error kind * better formatting --- ord_schema/scripts/validate_dataset.py | 31 +++++++++++++++++++++----- ord_schema/validations.py | 9 ++------ 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/ord_schema/scripts/validate_dataset.py b/ord_schema/scripts/validate_dataset.py index ce51513d..f091aba3 100644 --- a/ord_schema/scripts/validate_dataset.py +++ b/ord_schema/scripts/validate_dataset.py @@ -14,18 +14,21 @@ """Validates a set of Dataset protocol buffers. Usage: - validate.py --input= [--filter=] + validate.py --input= [--filter= --n_jobs=] Options: --input= Input pattern for Dataset protos --filter= Regex filename filter + --n_jobs= Number of parallel workers [default: 1] """ import glob import re from collections.abc import Iterable +from concurrent.futures import ProcessPoolExecutor, as_completed import docopt from rdkit import RDLogger +from tqdm import tqdm from ord_schema import message_helpers, validations from ord_schema.logging import get_logger @@ -43,18 +46,34 @@ def filter_filenames(filenames: Iterable[str], pattern: str) -> list[str]: return filtered_filenames +def run(filename: str) -> None: + """Validates a single dataset.""" + RDLogger.DisableLog("rdApp.*") # Disable RDKit logging. + dataset = message_helpers.load_message(filename, dataset_pb2.Dataset) + validations.validate_datasets({filename: dataset}) + + def main(kwargs): filenames = sorted(glob.glob(kwargs["--input"], recursive=True)) logger.info("Found %d datasets", len(filenames)) if kwargs["--filter"]: filenames = filter_filenames(filenames, kwargs["--filter"]) logger.info("Filtered to %d datasets", len(filenames)) - for filename in filenames: - logger.info("Validating %s", filename) - dataset = message_helpers.load_message(filename, dataset_pb2.Dataset) - validations.validate_datasets({filename: dataset}) + futures = {} + failures = [] + with ProcessPoolExecutor(int(kwargs["--n_jobs"])) as executor: + for filename in filenames: + future = executor.submit(run, filename=filename) + futures[future] = filename + for future in tqdm(as_completed(futures), total=len(futures)): + try: + future.result() + except validations.ValidationError as error: + failures.append(f"{futures[future]}: {error}") + if failures: + text = "\n".join(failures) + raise validations.ValidationError(f"Dataset(s) failed validation:\n{text}") if __name__ == "__main__": - RDLogger.DisableLog("rdApp.*") # Disable RDKit logging. main(docopt.docopt(__doc__)) diff --git a/ord_schema/validations.py b/ord_schema/validations.py index 87e7783a..9f9e565c 100644 --- a/ord_schema/validations.py +++ b/ord_schema/validations.py @@ -117,17 +117,12 @@ def _validate_datasets( num_bad_reactions += 1 for error in reaction_output.errors: errors.append(error) - logger.warning(f"Validation error for {label}[{i}]: {error}") - num_successful = (len(dataset.reactions) - num_bad_reactions,) - logger.info( - f"Validation summary for {label}: {num_successful}/{len(dataset.reactions)} successful " - f"({num_bad_reactions} failures)" - ) + logger.error(f"Validation error for {label}[{i}]: {error}") # Dataset-level validation of cross-references. dataset_output = validate_message(dataset, raise_on_error=False, recurse=False, options=options) for error in dataset_output.errors: errors.append(error) - logger.warning(f"Validation error for {label}: {error}") + logger.error(f"Validation error for {label}: {error}") return errors