diff --git a/docs/src/getting-started/override.rst b/docs/src/getting-started/override.rst index 872e8c8b5..bfabaf387 100644 --- a/docs/src/getting-started/override.rst +++ b/docs/src/getting-started/override.rst @@ -52,13 +52,12 @@ possibility. The changes above can be achieved by typing: .. code-block:: bash mtt train options.yaml \ - -r architecture.model.soap.cutoff=7.0,architecture.training.num_epochs=200 + -r architecture.model.soap.cutoff=7.0 -r architecture.training.num_epochs=200 Here, the ``-r`` or equivalent ``--override`` flag is used to parse the override flags. The syntax follows a dotlist-style string format where each level of the options is -seperated by a ``.`` and each separate option to override is separated by a comma. -As a further example, to use single precision for your training you can add -``-r base_precision=32``. +seperated by a ``.``. As a further example, to use single precision for your training +you can add ``-r base_precision=32``. .. note:: Command line overrides allow adding new values to your training parameters and diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index 4284637ad..56f4caace 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -83,8 +83,9 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None: "-r", "--override", dest="override_options", - type=lambda string: OmegaConf.from_dotlist(string.split(",")), - help="Command line override flags.", + action="append", + help="Command-line override flags.", + default=[], ) @@ -93,8 +94,7 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None: args.options = OmegaConf.load(args.options) # merge/override file options with command line options override_options = args.__dict__.pop("override_options") - if override_options is None: - override_options = {} + override_options = OmegaConf.from_dotlist(override_options) args.options = OmegaConf.merge(args.options, override_options) diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index 17bf0e319..e9764006b 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -106,8 +106,8 @@ def test_train(capfd, monkeypatch, tmp_path, output): @pytest.mark.parametrize( "overrides", [ - "architecture.training.num_epochs=2", - "architecture.training.num_epochs=2,architecture.training.batch_size=3", + ["architecture.training.num_epochs=2"], + ["architecture.training.num_epochs=2", "architecture.training.batch_size=3"], ], ) def test_command_line_override(monkeypatch, tmp_path, overrides): @@ -116,7 +116,9 @@ def test_command_line_override(monkeypatch, tmp_path, overrides): shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz") shutil.copy(OPTIONS_PATH, "options.yaml") - command = ["mtt", "train", "options.yaml", "-r", overrides] + command = ["mtt", "train", "options.yaml"] + for override in overrides: + command += ["-r", override] subprocess.check_call(command) @@ -126,7 +128,7 @@ def test_command_line_override(monkeypatch, tmp_path, overrides): restart_options = OmegaConf.load(restart_glob[0]) assert restart_options["architecture"]["training"]["num_epochs"] == 2 - if len(overrides.split()) == 2: + if len(overrides) == 2: assert restart_options["architecture"]["training"]["batch_size"] == 3