From 0744fdd7055f87609d8caeac4f60431f19537ecf Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 9 Dec 2024 07:57:22 +0100 Subject: [PATCH] Fix command-line overrides --- docs/src/getting-started/override.rst | 7 ++++--- src/metatrain/cli/train.py | 2 +- tests/cli/test_train_model.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/src/getting-started/override.rst b/docs/src/getting-started/override.rst index 45a9f36d0..872e8c8b5 100644 --- a/docs/src/getting-started/override.rst +++ b/docs/src/getting-started/override.rst @@ -52,12 +52,13 @@ possibility. The changes above can be achieved by typing: .. code-block:: bash mtt train options.yaml \ - -r architecture.model.soap.cutoff=7.0 architecture.training.num_epochs=200 + -r architecture.model.soap.cutoff=7.0,architecture.training.num_epochs=200 Here, the ``-r`` or equivalent ``--override`` flag is used to parse the override flags. The syntax follows a dotlist-style string format where each level of the options is -seperated by a ``.``. For example to use single precision as the base precision for your -training use ``-r base_precision=32`` +seperated by a ``.`` and each separate option to override is separated by a comma. +As a further example, to use single precision for your training you can add +``-r base_precision=32``. .. note:: Command line overrides allow adding new values to your training parameters and diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index 53e235c4f..4284637ad 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -83,7 +83,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None: "-r", "--override", dest="override_options", - type=lambda string: OmegaConf.from_dotlist(string.split()), + type=lambda string: OmegaConf.from_dotlist(string.split(",")), help="Command line override flags.", ) diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index a17831e36..17bf0e319 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -107,7 +107,7 @@ def test_train(capfd, monkeypatch, tmp_path, output): "overrides", [ "architecture.training.num_epochs=2", - "architecture.training.num_epochs=2 architecture.training.batch_size=3", + "architecture.training.num_epochs=2,architecture.training.batch_size=3", ], ) def test_command_line_override(monkeypatch, tmp_path, overrides):