Skip to content

Commit

Permalink
introduce cli neps (init, run)
Browse files Browse the repository at this point in the history
  • Loading branch information
danrgll committed Jul 30, 2024
1 parent 82b4890 commit de21ef8
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 0 deletions.
1 change: 1 addition & 0 deletions neps/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from neps.utils.common import get_searcher_data, get_value
from neps.optimizers.info import SearcherConfigs


def run(
run_pipeline: Callable | None = Default(None),
root_directory: str | Path | None = Default(None),
Expand Down
221 changes: 221 additions & 0 deletions neps/utils/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
"""This module provides a command-line interface (CLI) for NePS."""

from __future__ import annotations

import argparse
import logging
from pathlib import Path

import neps
from neps.api import Default


def init_config(args: argparse.Namespace) -> None:
"""Creates a 'run_args' configuration YAML file template if it does not already
exist.
"""
config_path = Path(args.config_path) if args.config_path else Path("config.yaml")
if not config_path.exists():
with config_path.open("w") as file:
file.write("# Add your NEPS configuration settings here\n")
else:
pass


def parse_kv_pairs(kv_list: list[str]) -> dict:
"""Parse a list of key=value strings into a dictionary."""
result = {}
for item in kv_list:
if "=" in item:
key, value = item.split("=", 1)
result[key] = value
else:
raise ValueError("Each kwarg must be in key=value format.")
return result


def run_optimization(args: argparse.Namespace) -> None:
"""Collects arguments from the parser and runs the NePS optimization.
Args: args (argparse.Namespace): Parsed command-line arguments.
"""
kwargs = {}
if args.searcher_kwargs:
kwargs = parse_kv_pairs(args.searcher_kwargs) # convert kwargs

# Collect arguments from args and prepare them for neps.run
options = {
"run_args": args.run_args,
"root_directory": args.root_directory,
"overwrite_working_directory": args.overwrite_working_directory,
"post_run_summary": args.post_run_summary,
"development_stage_id": args.development_stage_id,
"task_id": args.task_id,
"max_evaluations_total": args.max_evaluations_total,
"max_evaluations_per_run": args.max_evaluations_per_run,
"continue_until_max_evaluation_completed": (
args.continue_until_max_evaluation_completed
),
"max_cost_total": args.max_cost_total,
"ignore_errors": args.ignore_errors,
"loss_value_on_error": args.loss_value_on_error,
"cost_value_on_error": args.cost_value_on_error,
"searcher": args.searcher,
**kwargs,
}
logging.basicConfig(level=logging.INFO)
neps.run(**options)


def main() -> None:
"""CLI entry point.
This function sets up the command-line interface (CLI) for NePS using argparse.
It defines the available subcommands and their respective arguments.
Available commands:
- init: Generates a 'run_args' YAML template file.
- run: Runs the optimization with specified configuration.
"""
parser = argparse.ArgumentParser(description="NePS Command Line Interface")
subparsers = parser.add_subparsers(
dest="command", help="Available commands: init, run"
)

# Subparser for "init" command
parser_init = subparsers.add_parser("init", help="Generate 'run_args' YAML file")
parser_init.add_argument(
"--config-path",
type=str,
default=None,
help="Optional custom path for generating the configuration file. "
"Default is 'config.yaml'.",
)
parser_init.set_defaults(func=init_config)

# Subparser for "run" command
parser_run = subparsers.add_parser("run", help="Run a neural pipeline search.")
# Adding arguments to the 'run' subparser with defaults
parser_run.add_argument(
"run_args",
type=str,
help="Path to the YAML configuration file. This file must include the "
"'run_pipeline' settings.",
)
parser_run.add_argument(
"--root-directory",
type=str,
default=Default(None),
help="The directory to save progress to. This is also used to synchronize "
"multiple calls for parallelization. (default: %(default)s)",
)
parser_run.add_argument(
"--overwrite-working-directory",
action="store_true",
default=Default(False), # noqa: FBT003
help="If set, deletes the working directory at the start of the run. "
"This is useful, for example, when debugging a run_pipeline function. "
"(default: %(default)s)",
)
parser_run.add_argument(
"--development-stage-id",
type=int,
default=Default(None),
help="Identifier for the current development stage, used in multi-stage "
"projects. (default: %(default)s)",
)
parser_run.add_argument(
"--task-id",
type=int,
default=Default(None),
help="Identifier for the current task, useful in projects with multiple tasks. "
"(default: %(default)s)",
)
# Create a mutually exclusive group for post-run summary flags
summary_group = parser_run.add_mutually_exclusive_group(required=False)
summary_group.add_argument(
"--post-run-summary",
action="store_true",
default=Default(True), # noqa: FBT003
help="Provide a summary of the results after running. (default: %(default)s)",
)
summary_group.add_argument(
"--no-post-run-summary",
action="store_false",
dest="post_run_summary",
help="Do not provide a summary of the results after running.",
)
parser_run.add_argument(
"--max-evaluations-total",
type=int,
default=Default(None),
help="Total number of evaluations to run. (default: %(default)s)",
)
parser_run.add_argument(
"--max-evaluations-per-run",
type=int,
default=Default(None),
help="Number of evaluations a specific call should maximally do. "
"(default: %(default)s)",
)
parser_run.add_argument(
"--continue-until-max-evaluation-completed",
action="store_true",
default=Default(False), # noqa: FBT003
help="If set, only stop after max-evaluations-total have been completed. This "
"is only relevant in the parallel setting. (default: %(default)s)",
)
parser_run.add_argument(
"--max-cost-total",
type=float,
default=Default(None),
help="No new evaluations will start when this cost is exceeded. Requires "
"returning a cost in the run_pipeline function, e.g., `return dict("
"loss=loss, cost=cost)`. (default: %(default)s)",
)
parser_run.add_argument(
"--ignore-errors",
action="store_true",
default=Default(False), # noqa: FBT003
help="If set, ignore errors during the optimization process. (default: %("
"default)s)",
)
parser_run.add_argument(
"--loss-value-on-error",
type=float,
default=Default(None),
help="Loss value to assume on error. (default: %(default)s)",
)
parser_run.add_argument(
"--cost-value-on-error",
type=float,
default=Default(None),
help="Cost value to assume on error. (default: %(default)s)",
)

parser_run.add_argument(
"--searcher",
type=str,
default=Default("default"),
help="String key of searcher algorithm to use for optimization. (default: %("
"default)s)",
)

parser_run.add_argument(
"--searcher-kwargs",
type=str,
nargs="+",
help="Additional keyword arguments as key=value pairs for the searcher.",
)

parser_run.set_defaults(func=run_optimization)

args = parser.parse_args()

if hasattr(args, "func"):
args.func(args)
else:
parser.print_help()


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,6 @@ module = [
"neps.utils.run_args_from_yaml",
]
ignore_errors = true

[tool.poetry.scripts]
neps = "neps.utils.cli:main"

0 comments on commit de21ef8

Please sign in to comment.