diff --git a/fedn/cli/run_cmd.py b/fedn/cli/run_cmd.py index 20836cf96..9177296a1 100644 --- a/fedn/cli/run_cmd.py +++ b/fedn/cli/run_cmd.py @@ -4,7 +4,6 @@ import click import yaml - from fedn.common.exceptions import InvalidClientConfig from fedn.common.log_config import logger from fedn.network.clients.client import Client @@ -44,7 +43,38 @@ def run_cmd(ctx): """:param ctx: """ pass +@run_cmd.command("train") +@click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml") +@click.option("-i", "--input", required=True, help="Path to input model parameters" ) +@click.option("-o", "--output", required=True,help="Path to write the updated model parameters ") +@click.pass_context +def train_cmd(ctx, path,input,output): + """Execute 'train' entrypoint in fedn.yaml. + :param ctx: + :param path: Path to folder containing fedn.yaml + :type path: str + """ + path = os.path.abspath(path) + yaml_file = os.path.join(path, "fedn.yaml") + if not os.path.exists(yaml_file): + logger.error(f"Could not find fedn.yaml in {path}") + exit(-1) + + config = _read_yaml_file(yaml_file) + # Check that train is defined in fedn.yaml under entry_points + if "train" not in config["entry_points"]: + logger.error("No train command defined in fedn.yaml") + exit(-1) + + dispatcher = Dispatcher(config, path) + _ = dispatcher._get_or_create_python_env() + dispatcher.run_cmd("train {} {}".format(input, output)) + + # delete the virtualenv + if dispatcher.python_env_path: + logger.info(f"Removing virtualenv {dispatcher.python_env_path}") + shutil.rmtree(dispatcher.python_env_path) @run_cmd.command("startup") @click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml") @click.pass_context