diff --git a/.gitignore b/.gitignore index 3b9d80dbe..379bd2972 100644 --- a/.gitignore +++ b/.gitignore @@ -30,7 +30,6 @@ pip-log.txt # Unit test / coverage reports .coverage .tox -nosetests.xml # Translations *.mo diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eaf5a9501..c4227fd18 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,8 @@ +exclude: ^docs + repos: - repo: https://github.com/asottile/pyupgrade - rev: v2.29.0 + rev: v2.31.0 hooks: - id: pyupgrade args: [--py37-plus] @@ -12,18 +14,18 @@ repos: args: [--in-place, --remove-all-unused-imports, --remove-unused-variable, --ignore-init-module-imports] - repo: https://github.com/psf/black - rev: 21.4b2 + rev: 22.1.0 hooks: - id: black args: [--line-length, '120'] - repo: https://github.com/PyCQA/isort - rev: 5.9.3 + rev: 5.10.1 hooks: - id: isort - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.1.0 hooks: - id: check-case-conflict - id: check-symlinks diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 2ab9dd589..9e6ffbd16 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -182,7 +182,7 @@ def __init__( uri_mode (bool): if set True, all Mongo connection parameters occur through a MongoDB URI string (set as the host). mongoclient_kwargs (dict): A list of any other custom keyword arguments to be - passed into the MongoClient connection (non-URI mode only). Use these kwargs to specify SSL/TLS + passed into the MongoClient connection. Use these kwargs to specify SSL/TLS or serverSelectionTimeoutMS arguments. Note these arguments are different depending on the major pymongo version used; see pymongo documentation for more details. """ @@ -206,7 +206,7 @@ def __init__( # get connection if uri_mode: - self.connection = MongoClient(host) + self.connection = MongoClient(host, **self.mongoclient_kwargs) dbname = host.split("/")[-1].split("?")[0] # parse URI to extract dbname self.db = self.connection[dbname] else: diff --git a/fireworks/fw_config.py b/fireworks/fw_config.py index b5d8666df..bfdf62cb3 100644 --- a/fireworks/fw_config.py +++ b/fireworks/fw_config.py @@ -17,7 +17,7 @@ NEGATIVE_FWID_CTR = 0 # this is where load_object() looks for serialized objects -USER_PACKAGES = ["fireworks.user_objects", "fireworks.utilities.tests", "fw_tutorials", "fireworks.features"] +USER_PACKAGES = ["fireworks.user_objects", "fw_tutorials", "fireworks.features"] # if you update a _fw_name, you can use this to record the change and maintain deserialization FW_NAME_UPDATES = { diff --git a/fireworks/scripts/_helpers.py b/fireworks/scripts/_helpers.py new file mode 100644 index 000000000..cbad25c8b --- /dev/null +++ b/fireworks/scripts/_helpers.py @@ -0,0 +1,41 @@ +import os +from argparse import Namespace +from typing import Sequence, Tuple, Union + +cfg_file_vldtor = Tuple[str, str, bool, Union[str, None]] + + +def _validate_config_file_paths(args: Namespace, cfg_files_to_validate: Sequence[cfg_file_vldtor]) -> None: + """Validate the CLI config files. + + Args: + args (argparse.Namespace): The parsed arguments from the CLI. + cfg_files_to_validate (list[tuple[str, str, bool, str | None]]): config files to validate. + Tuple is (config filename, CLI flag, is filepath required, default config file location). + + Raises: + ValueError: If a path to a required config file is not provided. + FileNotFoundError: If a config file is provided but does not exist. + """ + for filename, cli_flag, required, default_loc in cfg_files_to_validate: + + attr_name = f"{filename}_file" + file_path = getattr(args, attr_name) + + # args.config_dir defaults to '.' if not specified + file_in_config_dir = os.path.join(args.config_dir, f"my_{filename}.yaml") + if file_path is None and os.path.exists(file_in_config_dir): + setattr(args, attr_name, file_in_config_dir) + elif file_path is None: + setattr(args, attr_name, default_loc) + + file_path = getattr(args, attr_name, None) + + # throw on missing config files + if file_path is None and required: + raise ValueError( + f"No path specified for {attr_name}. Use the {cli_flag} flag to specify or check the value " + f"of CONFIG_FILE_DIR and make sure it points at where all your config files are." + ) + if file_path is not None and not os.path.exists(file_path): + raise FileNotFoundError(f"{attr_name} '{file_path}' does not exist!") diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index d224e5fad..0545189c4 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -10,9 +10,8 @@ import re import sys import time -import traceback -from argparse import ArgumentParser, ArgumentTypeError -from typing import Optional, Sequence +from argparse import ArgumentParser, ArgumentTypeError, Namespace +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import ruamel.yaml as yaml from pymongo import ASCENDING, DESCENDING @@ -37,10 +36,14 @@ from fireworks.user_objects.firetasks.script_task import ScriptTask from fireworks.utilities.fw_serializers import DATETIME_HANDLER, recursive_dict +from ._helpers import _validate_config_file_paths + if sys.version_info < (3, 8): import importlib_metadata as metadata + from typing_extensions import Literal else: from importlib import metadata + from typing import Literal __author__ = "Anubhav Jain" __credits__ = "Shyue Ping Ong" @@ -52,7 +55,7 @@ DEFAULT_LPAD_YAML = "my_launchpad.yaml" -def pw_check(ids, args, skip_pw=False): +def pw_check(ids: List[int], args: Namespace, skip_pw: bool = False) -> List[int]: if len(ids) > PW_CHECK_NUM and not skip_pw: m_password = datetime.datetime.now().strftime("%Y-%m-%d") if not args.password: @@ -68,17 +71,17 @@ def pw_check(ids, args, skip_pw=False): return ids -def parse_helper(lp, args, wf_mode=False, skip_pw=False): +def parse_helper(lp: LaunchPad, args: Namespace, wf_mode: bool = False, skip_pw: bool = False) -> List[int]: """ Helper method to parse args that can take either id, name, state or query. Args: - args - wf_mode (bool) - skip_pw (bool) + args: Namespace of parsed CLI arguments. + wf_mode (bool): If True, will query lp for workflow instead of fireworks IDs. + skip_pw (bool): If True, skip PW check. Defaults to False. Returns: - list of ids + list[int]: Firework or Workflow IDs. """ if args.fw_id and sum(bool(x) for x in [args.name, args.state, args.query]) >= 1: raise ValueError("Cannot specify both fw_id and name/state/query)") @@ -108,23 +111,25 @@ def parse_helper(lp, args, wf_mode=False, skip_pw=False): return pw_check(lp.get_fw_ids(query, sort=sort, limit=max, launches_mode=args.launches_mode), args, skip_pw) -def get_lp(args): +def get_lp(args: Namespace) -> LaunchPad: try: - if not args.launchpad_file: - if os.path.exists(os.path.join(args.config_dir, DEFAULT_LPAD_YAML)): - args.launchpad_file = os.path.join(args.config_dir, DEFAULT_LPAD_YAML) - else: - args.launchpad_file = LAUNCHPAD_LOC - if args.launchpad_file: - return LaunchPad.from_file(args.launchpad_file) + lp = LaunchPad.from_file(args.launchpad_file) else: + args.loglvl = "CRITICAL" if args.silencer else args.loglvl - return LaunchPad(logdir=args.logdir, strm_lvl=args.loglvl) + # no lpad file means we try connect to localhost which is fast so use small timeout + # (default 30s) for quick response to user if no DB is running + mongo_kwds = {"serverSelectionTimeoutMS": 500} + lp = LaunchPad(logdir=args.logdir, strm_lvl=args.loglvl, mongoclient_kwargs=mongo_kwds) + + # make sure we can connect to DB, raises pymongo.errors.ServerSelectionTimeoutError if not + lp.connection.admin.command("ping") + return lp + except Exception: - traceback.print_exc() err_message = ( - "FireWorks was not able to connect to MongoDB. Is the server running? " + f"FireWorks was not able to connect to MongoDB at {lp.host}:{lp.port}. Is the server running? " f"The database file specified was {args.launchpad_file}." ) if not args.launchpad_file: @@ -133,10 +138,11 @@ def get_lp(args): "location and credentials of your Mongo database (otherwise use default " "localhost configuration)." ) - raise ValueError(err_message) + # use from None to hide the pymongo ServerSelectionTimeoutError that otherwise clutters up the stack trace + raise ValueError(err_message) from None -def init_yaml(args): +def init_yaml(args: Namespace) -> None: if args.uri_mode: fields = ( ("host", None, "Example: mongodb+srv://USER:PASSWORD@CLUSTERNAME.mongodb.net/fireworks"), @@ -164,19 +170,19 @@ def init_yaml(args): ), ) - doc = {} + doc: Dict[str, Union[str, int, bool, None]] = {} if args.uri_mode: print( "Note 1: You are in URI format mode. This means that all database parameters (username, password, host, " "port, database name, etc.) must be present in the URI. See: " "https://docs.mongodb.com/manual/reference/connection-string/ for details." ) - print("(Enter your connection URI in under the 'host' parameter)") + print("(Enter your connection URI through the 'host' parameter)") print("Please supply the following configuration values") print("(press Enter if you want to accept the defaults)\n") for k, default, helptext in fields: val = input(f"Enter {k} parameter. (default: {default}). {helptext}: ") - doc[k] = val if val else default + doc[k] = val or default if "port" in doc: doc["port"] = int(doc["port"]) # enforce the port as an int if args.uri_mode: @@ -187,7 +193,7 @@ def init_yaml(args): print(f"\nConfiguration written to {args.config_file}!") -def reset(args): +def reset(args: Namespace) -> None: lp = get_lp(args) if not args.password: if ( @@ -202,7 +208,7 @@ def reset(args): lp.reset(args.password) -def add_wf(args): +def add_wf(args: Namespace) -> None: lp = get_lp(args) if args.dir: files = [] @@ -219,31 +225,31 @@ def add_wf(args): lp.add_wf(fwf) -def append_wf(args): +def append_wf(args: Namespace) -> None: lp = get_lp(args) lp.append_wf(Workflow.from_file(args.wf_file), args.fw_id, detour=args.detour, pull_spec_mods=args.pull_spec_mods) -def dump_wf(args): +def dump_wf(args: Namespace) -> None: lp = get_lp(args) lp.get_wf_by_fw_id(args.fw_id).to_file(args.wf_file) -def check_wf(args): +def check_wf(args: Namespace) -> None: from fireworks.utilities.dagflow import DAGFlow lp = get_lp(args) DAGFlow.from_fireworks(lp.get_wf_by_fw_id(args.fw_id)).check() -def add_wf_dir(args): +def add_wf_dir(args: Namespace) -> None: lp = get_lp(args) for filename in os.listdir(args.wf_dir): fwf = Workflow.from_file(filename) lp.add_wf(fwf) -def print_fws(ids, lp, args): +def print_fws(ids, lp, args: Namespace) -> None: """Prints results of some FireWorks query to stdout.""" fws = [] if args.display_format == "ids": @@ -269,7 +275,7 @@ def print_fws(ids, lp, args): print(args.output(fws)) -def get_fw_ids_helper(lp, args, count_only=None): +def get_fw_ids_helper(lp: LaunchPad, args: Namespace, count_only: Union[bool, None] = None) -> Union[List[int], int]: """Build fws query from command line options and submit. Parameters: @@ -278,7 +284,7 @@ def get_fw_ids_helper(lp, args, count_only=None): count_only (bool): if None, then looked up in args. Returns: - [int]: resulting fw_ids or count of fws in query. + list[int] | int: resulting fw_ids or count of fws in query. """ if sum(bool(x) for x in [args.fw_id, args.name, args.state, args.query]) > 1: raise ValueError("Please specify exactly one of (fw_id, name, state, query)") @@ -321,7 +327,9 @@ def get_fw_ids_helper(lp, args, count_only=None): return ids -def get_fws_helper(lp, ids, args): +def get_fws_helper( + lp: LaunchPad, ids: List[int], args: Namespace +) -> Union[List[int], int, List[Dict[str, Union[str, int, bool]]], Union[str, int, bool]]: """Get fws from ids in a representation according to args.display_format.""" fws = [] if args.display_format == "ids": @@ -341,19 +349,17 @@ def get_fws_helper(lp, ids, args): if "launches" in d: del d["launches"] fws.append(d) - if len(fws) == 1: - fws = fws[0] - return fws + return fws[0] if len(fws) == 1 else fws -def get_fws(args): +def get_fws(args: Namespace) -> None: lp = get_lp(args) ids = get_fw_ids_helper(lp, args) fws = get_fws_helper(lp, ids, args) print(args.output(fws)) -def get_fws_in_wfs(args): +def get_fws_in_wfs(args: Namespace) -> None: # get_wfs lp = get_lp(args) if sum(bool(x) for x in [args.wf_fw_id, args.wf_name, args.wf_state, args.wf_query]) > 1: @@ -419,13 +425,13 @@ def get_fws_in_wfs(args): print_fws(ids, lp, args) -def update_fws(args): +def update_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) lp.update_spec(fw_ids, json.loads(args.update), args.mongo) -def get_wfs(args): +def get_wfs(args: Namespace) -> None: lp = get_lp(args) if sum(bool(x) for x in [args.fw_id, args.name, args.state, args.query]) > 1: raise ValueError("Please specify exactly one of (fw_id, name, state, query)") @@ -478,7 +484,7 @@ def get_wfs(args): print(args.output(wfs)) -def delete_wfs(args): +def delete_wfs(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -498,7 +504,7 @@ def get_children(links, start, max_depth): return data -def detect_lostruns(args): +def detect_lostruns(args: Namespace) -> None: lp = get_lp(args) query = ast.literal_eval(args.query) if args.query else None launch_query = ast.literal_eval(args.launch_query) if args.launch_query else None @@ -525,7 +531,7 @@ def detect_lostruns(args): print("You can fix inconsistent FWs using the --refresh argument to the detect_lostruns command") -def detect_unreserved(args): +def detect_unreserved(args: Namespace) -> None: lp = get_lp(args) if args.display_format is not None and args.display_format != "none": unreserved = lp.detect_unreserved(expiration_secs=args.time, rerun=False) @@ -538,12 +544,12 @@ def detect_unreserved(args): print(lp.detect_unreserved(expiration_secs=args.time, rerun=args.rerun)) -def tuneup(args): +def tuneup(args: Namespace) -> None: lp = get_lp(args) lp.tuneup(bkground=not args.full) -def defuse_wfs(args): +def defuse_wfs(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -557,7 +563,7 @@ def defuse_wfs(args): ) -def pause_wfs(args): +def pause_wfs(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -566,7 +572,7 @@ def pause_wfs(args): lp.m_logger.info(f"Finished defusing {len(fw_ids)} FWs.") -def archive(args): +def archive(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -575,7 +581,7 @@ def archive(args): lp.m_logger.info(f"Finished archiving {len(fw_ids)} WFs") -def reignite_wfs(args): +def reignite_wfs(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -584,7 +590,7 @@ def reignite_wfs(args): lp.m_logger.info(f"Finished reigniting {len(fw_ids)} Workflows") -def defuse_fws(args): +def defuse_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -593,7 +599,7 @@ def defuse_fws(args): lp.m_logger.info(f"Finished defusing {len(fw_ids)} FWs") -def pause_fws(args): +def pause_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -602,7 +608,7 @@ def pause_fws(args): lp.m_logger.info(f"Finished pausing {len(fw_ids)} FWs") -def reignite_fws(args): +def reignite_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -611,7 +617,7 @@ def reignite_fws(args): lp.m_logger.info(f"Finished reigniting {len(fw_ids)} FWs") -def resume_fws(args): +def resume_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -620,7 +626,7 @@ def resume_fws(args): lp.m_logger.info(f"Finished resuming {len(fw_ids)} FWs") -def rerun_fws(args): +def rerun_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) if args.task_level: @@ -637,7 +643,7 @@ def rerun_fws(args): lp.m_logger.info(f"Finished setting {len(fw_ids)} FWs to rerun") -def refresh(args): +def refresh(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -648,7 +654,7 @@ def refresh(args): lp.m_logger.info(f"Finished refreshing {len(fw_ids)} Workflows") -def unlock(args): +def unlock(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) for f in fw_ids: @@ -658,13 +664,13 @@ def unlock(args): lp.m_logger.info(f"Finished unlocking {len(fw_ids)} Workflows") -def get_qid(args): +def get_qid(args: Namespace) -> None: lp = get_lp(args) for f in args.fw_id: print(lp.get_reservation_id_from_fw_id(f)) -def cancel_qid(args): +def cancel_qid(args: Namespace) -> None: lp = get_lp(args) lp.m_logger.warning( "WARNING: cancel_qid does not actually remove jobs from the queue " @@ -673,7 +679,7 @@ def cancel_qid(args): lp.cancel_reservation_by_reservation_id(args.qid) -def set_priority(args): +def set_priority(args: Namespace) -> None: wf_mode = args.wf lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=wf_mode) @@ -697,7 +703,7 @@ def _open_webbrowser(url): webbrowser.open(url) -def webgui(args): +def webgui(args: Namespace) -> None: from fireworks.flask_site.app import app app.lp = get_lp(args) @@ -737,7 +743,7 @@ def webgui(args): StandaloneApplication(app, options).run() -def add_scripts(args): +def add_scripts(args: Namespace) -> None: lp = get_lp(args) args.names = args.names if args.names else [None] * len(args.scripts) args.wf_name = args.wf_name if args.wf_name else args.names[0] @@ -751,7 +757,7 @@ def add_scripts(args): lp.add_wf(Workflow(fws, links, args.wf_name)) -def recover_offline(args): +def recover_offline(args: Namespace) -> None: lp = get_lp(args) fworker_name = FWorker.from_file(args.fworker_file).name if args.fworker_file else None failed_fws = [] @@ -771,7 +777,7 @@ def recover_offline(args): lp.m_logger.info(f"FAILED to recover offline fw_ids: {failed_fws}") -def forget_offline(args): +def forget_offline(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args) for f in fw_ids: @@ -780,7 +786,7 @@ def forget_offline(args): lp.m_logger.info(f"Finished forget_offine, processed {len(fw_ids)} FWs") -def report(args): +def report(args: Namespace) -> None: lp = get_lp(args) query = ast.literal_eval(args.query) if args.query else None fwr = FWReport(lp) @@ -795,7 +801,7 @@ def report(args): print(fwr.get_stats_str(stats)) -def introspect(args): +def introspect(args: Namespace) -> None: print("NOTE: This feature is in beta mode...") lp = get_lp(args) isp = Introspector(lp) @@ -807,13 +813,13 @@ def introspect(args): print("") -def get_launchdir(args): +def get_launchdir(args: Namespace) -> None: lp = get_lp(args) ld = lp.get_launchdir(args.fw_id, args.launch_idx) print(ld) -def track_fws(args): +def track_fws(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, skip_pw=True) include = args.include @@ -837,12 +843,12 @@ def track_fws(args): print("\n".join(output)) -def maintain(args): +def maintain(args: Namespace) -> None: lp = get_lp(args) lp.maintain(args.infinite, args.maintain_interval) -def orphaned(args): +def orphaned(args: Namespace) -> None: # get_fws lp = get_lp(args) fw_ids = get_fw_ids_helper(lp, args, count_only=False) @@ -863,14 +869,14 @@ def orphaned(args): print(args.output(fws)) -def get_output_func(format): +def get_output_func(format: Literal["json", "yaml"]) -> Callable[[str], Any]: if format == "json": return lambda x: json.dumps(x, default=DATETIME_HANDLER, indent=4) else: return lambda x: yaml.safe_dump(recursive_dict(x, preserve_unicode=False), default_flow_style=False) -def arg_positive_int(value): +def arg_positive_int(value: str) -> int: try: ivalue = int(value) except ValueError: @@ -1524,6 +1530,11 @@ def lpad(argv: Optional[Sequence[str]] = None) -> int: args = parser.parse_args(argv) + cfg_files_to_check = [("launchpad", "-l", False, LAUNCHPAD_LOC)] + if hasattr(args, "fworker_file"): + cfg_files_to_check.append(("fworker", "-w", False, FWORKER_LOC)) + _validate_config_file_paths(args, cfg_files_to_check) + args.output = get_output_func(args.output) if args.command is None: diff --git a/fireworks/scripts/mlaunch_run.py b/fireworks/scripts/mlaunch_run.py index 0b1f075db..d0505f570 100644 --- a/fireworks/scripts/mlaunch_run.py +++ b/fireworks/scripts/mlaunch_run.py @@ -12,6 +12,8 @@ from fireworks.features.multi_launcher import launch_multiprocess from fireworks.fw_config import CONFIG_FILE_DIR, FWORKER_LOC, LAUNCHPAD_LOC +from ._helpers import _validate_config_file_paths + if sys.version_info < (3, 8): import importlib_metadata as metadata else: @@ -81,15 +83,11 @@ def mlaunch(argv: Optional[Sequence[str]] = None) -> int: args = parser.parse_args(argv) - if ( - not args.launchpad_file - and args.config_dir - and os.path.exists(os.path.join(args.config_dir, "my_launchpad.yaml")) - ): - args.launchpad_file = os.path.join(args.config_dir, "my_launchpad.yaml") - - if not args.fworker_file and args.config_dir and os.path.exists(os.path.join(args.config_dir, "my_fworker.yaml")): - args.fworker_file = os.path.join(args.config_dir, "my_fworker.yaml") + cfg_files_to_check = [ + ("launchpad", "-l", False, LAUNCHPAD_LOC), + ("fworker", "-w", False, FWORKER_LOC), + ] + _validate_config_file_paths(args, cfg_files_to_check) args.loglvl = "CRITICAL" if args.silencer else args.loglvl diff --git a/fireworks/scripts/qlaunch_run.py b/fireworks/scripts/qlaunch_run.py index 379974a35..f61b1a62b 100644 --- a/fireworks/scripts/qlaunch_run.py +++ b/fireworks/scripts/qlaunch_run.py @@ -28,6 +28,8 @@ from fireworks.queue.queue_launcher import launch_rocket_to_queue, rapidfire from fireworks.utilities.fw_serializers import load_object_from_file +from ._helpers import _validate_config_file_paths + if sys.version_info < (3, 8): import importlib_metadata as metadata else: @@ -41,24 +43,17 @@ def do_launch(args): - if not args.launchpad_file and os.path.exists(os.path.join(args.config_dir, "my_launchpad.yaml")): - args.launchpad_file = os.path.join(args.config_dir, "my_launchpad.yaml") - elif not args.launchpad_file: - args.launchpad_file = LAUNCHPAD_LOC - - if not args.fworker_file and os.path.exists(os.path.join(args.config_dir, "my_fworker.yaml")): - args.fworker_file = os.path.join(args.config_dir, "my_fworker.yaml") - elif not args.fworker_file: - args.fworker_file = FWORKER_LOC - if not args.queueadapter_file and os.path.exists(os.path.join(args.config_dir, "my_qadapter.yaml")): - args.queueadapter_file = os.path.join(args.config_dir, "my_qadapter.yaml") - elif not args.queueadapter_file: - args.queueadapter_file = QUEUEADAPTER_LOC + cfg_files_to_check = [ + ("launchpad", "-l", False, LAUNCHPAD_LOC), + ("fworker", "-w", False, FWORKER_LOC), + ("qadapter", "-q", True, QUEUEADAPTER_LOC), + ] + _validate_config_file_paths(args, cfg_files_to_check) launchpad = LaunchPad.from_file(args.launchpad_file) if args.launchpad_file else LaunchPad(strm_lvl=args.loglvl) fworker = FWorker.from_file(args.fworker_file) if args.fworker_file else FWorker() - queueadapter = load_object_from_file(args.queueadapter_file) + queueadapter = load_object_from_file(args.qadapter_file) args.loglvl = "CRITICAL" if args.silencer else args.loglvl if args.command == "rapidfire": @@ -174,7 +169,7 @@ def qlaunch(argv: Optional[Sequence[str]] = None) -> int: parser.add_argument("-r", "--reserve", help="reserve a fw", action="store_true") parser.add_argument("-l", "--launchpad_file", help="path to launchpad file") parser.add_argument("-w", "--fworker_file", help="path to fworker file") - parser.add_argument("-q", "--queueadapter_file", help="path to queueadapter file") + parser.add_argument("-q", "--queueadapter_file", help="path to qadapter file") parser.add_argument( "-c", "--config_dir", @@ -213,10 +208,12 @@ def qlaunch(argv: Optional[Sequence[str]] = None) -> int: pass args = parser.parse_args(argv) + if hasattr(args, "queueadapter_file"): + args.qadapter_file = args.queueadapter_file if args.remote_host and not HAS_FABRIC: print("Remote options require the Fabric package v2+ to be installed!") - sys.exit(-1) + raise SystemExit(-1) if args.remote_setup and args.remote_host: for h in args.remote_host: @@ -276,7 +273,7 @@ def qlaunch(argv: Optional[Sequence[str]] = None) -> int: else: break - return 0 + return 0 if __name__ == "__main__": diff --git a/fireworks/scripts/rlaunch_run.py b/fireworks/scripts/rlaunch_run.py index 4f7e321c0..d51c0478e 100644 --- a/fireworks/scripts/rlaunch_run.py +++ b/fireworks/scripts/rlaunch_run.py @@ -15,6 +15,8 @@ from fireworks.fw_config import CONFIG_FILE_DIR, FWORKER_LOC, LAUNCHPAD_LOC from fireworks.utilities.fw_utilities import get_fw_logger, get_my_host, get_my_ip +from ._helpers import _validate_config_file_paths + if sys.version_info < (3, 8): import importlib_metadata as metadata else: @@ -130,15 +132,11 @@ def rlaunch(argv: Optional[Sequence[str]] = None) -> int: signal.signal(signal.SIGINT, handle_interrupt) # graceful exit on ^C - if not args.launchpad_file and os.path.exists(os.path.join(args.config_dir, "my_launchpad.yaml")): - args.launchpad_file = os.path.join(args.config_dir, "my_launchpad.yaml") - elif not args.launchpad_file: - args.launchpad_file = LAUNCHPAD_LOC - - if not args.fworker_file and os.path.exists(os.path.join(args.config_dir, "my_fworker.yaml")): - args.fworker_file = os.path.join(args.config_dir, "my_fworker.yaml") - elif not args.fworker_file: - args.fworker_file = FWORKER_LOC + cfg_files_to_check = [ + ("launchpad", "-l", False, LAUNCHPAD_LOC), + ("fworker", "-w", False, FWORKER_LOC), + ] + _validate_config_file_paths(args, cfg_files_to_check) args.loglvl = "CRITICAL" if args.silencer else args.loglvl diff --git a/fireworks/scripts/tests/__init__.py b/fireworks/scripts/tests/__init__.py index b99e5b3ad..acecacdb5 100644 --- a/fireworks/scripts/tests/__init__.py +++ b/fireworks/scripts/tests/__init__.py @@ -1 +1,5 @@ __author__ = "Janosh Riebesell " + +from os.path import abspath, dirname + +module_dir = dirname(abspath(__file__)) diff --git a/fireworks/scripts/tests/test_lpad_run.py b/fireworks/scripts/tests/test_lpad_run.py index 76c3bf814..9152aac84 100644 --- a/fireworks/scripts/tests/test_lpad_run.py +++ b/fireworks/scripts/tests/test_lpad_run.py @@ -57,3 +57,13 @@ def test_lpad_report_version(capsys, arg): assert stdout.startswith("lpad v") assert stderr == "" + + +def test_lpad_config_file_flags(): + """Test lpad CLI throws errors on missing config file flags.""" + + with pytest.raises(FileNotFoundError, match="launchpad_file '' does not exist!"): + lpad(["-l", "", "get_fws"]) + + with pytest.raises(FileNotFoundError, match="fworker_file 'missing_file' does not exist!"): + lpad(["recover_offline", "-w", "missing_file"]) diff --git a/fireworks/scripts/tests/test_mlaunch_run.py b/fireworks/scripts/tests/test_mlaunch_run.py index 3d4a0951a..07b1e1009 100644 --- a/fireworks/scripts/tests/test_mlaunch_run.py +++ b/fireworks/scripts/tests/test_mlaunch_run.py @@ -17,3 +17,15 @@ def test_mlaunch_report_version(capsys, arg): assert stdout.startswith("mlaunch v") assert stderr == "" + + +def test_mlaunch_config_file_flags(): + """Test mlaunch CLI throws errors on missing config file flags.""" + + num_jobs = "1" + + with pytest.raises(FileNotFoundError, match="launchpad_file '' does not exist!"): + mlaunch([num_jobs, "-l", ""]) + + with pytest.raises(FileNotFoundError, match="fworker_file 'missing_file' does not exist!"): + mlaunch([num_jobs, "-w", "missing_file"]) diff --git a/fireworks/scripts/tests/test_qlaunch_run.py b/fireworks/scripts/tests/test_qlaunch_run.py index 44b166d94..ef2e902b1 100644 --- a/fireworks/scripts/tests/test_qlaunch_run.py +++ b/fireworks/scripts/tests/test_qlaunch_run.py @@ -2,6 +2,8 @@ from fireworks.scripts.qlaunch_run import qlaunch +from . import module_dir + __author__ = "Janosh Riebesell " @@ -10,10 +12,28 @@ def test_qlaunch_report_version(capsys, arg): """Test qlaunch CLI version flag.""" with pytest.raises(SystemExit): - ret_code = qlaunch([arg]) - assert ret_code == 0 + qlaunch([arg]) stdout, stderr = capsys.readouterr() assert stdout.startswith("qlaunch v") assert stderr == "" + + +def test_qlaunch_config_file_flags(): + """Test qlaunch CLI throws errors on missing config file flags.""" + + # qadapter.yaml is mandatory, test for ValueError if missing + with pytest.raises(ValueError, match="No path specified for qadapter_file."): + qlaunch([]) + + # qadapter.yaml is mandatory, test for ValueError if missing + with pytest.raises(FileNotFoundError, match="qadapter_file '' does not exist!"): + qlaunch(["-q", ""]) + + with pytest.raises(FileNotFoundError, match="qadapter_file 'missing_file' does not exist!"): + qlaunch(["-q", "missing_file"]) + + with pytest.raises(FileNotFoundError, match="launchpad_file '' does not exist!"): + qadapter_file = f"{module_dir}/__init__.py" # just any file that passes os.path.exists() + qlaunch(["-q", qadapter_file, "-l", ""]) diff --git a/fireworks/scripts/tests/test_rlaunch_run.py b/fireworks/scripts/tests/test_rlaunch_run.py index a043336d7..f13e95ec9 100644 --- a/fireworks/scripts/tests/test_rlaunch_run.py +++ b/fireworks/scripts/tests/test_rlaunch_run.py @@ -17,3 +17,13 @@ def test_rlaunch_report_version(capsys, arg): assert stdout.startswith("rlaunch v") assert stderr == "" + + +def test_rlaunch_config_file_flags(): + """Test rlaunch CLI throws errors on missing config file flags.""" + + with pytest.raises(FileNotFoundError, match="launchpad_file '' does not exist!"): + rlaunch(["-l", ""]) + + with pytest.raises(FileNotFoundError, match="fworker_file 'missing_file' does not exist!"): + rlaunch(["-w", "missing_file"]) diff --git a/requirements-ci.txt b/requirements-ci.txt index 75f3c2915..af91fdd03 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,6 +1,6 @@ -pytest==5.3.2 -pytest-cov==2.8.1 -coverage==5.0.1 -pycodestyle==2.5.0 matplotlib -graphviz \ No newline at end of file +graphviz +pytest +pytest-cov +coverage +pycodestyle diff --git a/requirements.txt b/requirements.txt index bdf751800..0aea8ed6d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ ruamel.yaml==0.16.5 pymongo==3.10.0 -Jinja2==2.11.3 +Jinja2 monty==3.0.2 python-dateutil==2.8.1 tabulate==0.8.6 flask==1.1.1 flask-paginate==0.5.5 gunicorn==20.0.4 -tqdm==4.41.0 \ No newline at end of file +tqdm==4.41.0 diff --git a/setup.py b/setup.py index 2a0acfa5e..27eb85287 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,8 @@ "flask-paginate>=0.4.5", "gunicorn>=19.6.0", "tqdm>=4.8.4", - "importlib-metadata>=4.8.2;python_version<'3.8'", + "importlib-metadata>=4.8.2; python_version<'3.8'", + "typing-extensions; python_version<'3.8'", ], extras_require={ "rtransfer": ["paramiko>=2.4.2"], diff --git a/tasks.py b/tasks.py index d7894af95..8901e930f 100644 --- a/tasks.py +++ b/tasks.py @@ -76,9 +76,7 @@ def release_github(ctx): @task -def release(ctx, nosetest=False): - if nosetest: - ctx.run("nosetests") +def release(ctx): publish(ctx) update_doc(ctx) release_github(ctx)