Skip to content

Commit

Permalink
Fix command-line overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 9, 2024
1 parent 3a89558 commit 0744fdd
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions docs/src/getting-started/override.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ possibility. The changes above can be achieved by typing:
.. code-block:: bash
mtt train options.yaml \
-r architecture.model.soap.cutoff=7.0 architecture.training.num_epochs=200
-r architecture.model.soap.cutoff=7.0,architecture.training.num_epochs=200
Here, the ``-r`` or equivalent ``--override`` flag is used to parse the override flags.
The syntax follows a dotlist-style string format where each level of the options is
seperated by a ``.``. For example to use single precision as the base precision for your
training use ``-r base_precision=32``
seperated by a ``.`` and each separate option to override is separated by a comma.
As a further example, to use single precision for your training you can add
``-r base_precision=32``.

.. note::
Command line overrides allow adding new values to your training parameters and
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
"-r",
"--override",
dest="override_options",
type=lambda string: OmegaConf.from_dotlist(string.split()),
type=lambda string: OmegaConf.from_dotlist(string.split(",")),
help="Command line override flags.",
)

Expand Down
2 changes: 1 addition & 1 deletion tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_train(capfd, monkeypatch, tmp_path, output):
"overrides",
[
"architecture.training.num_epochs=2",
"architecture.training.num_epochs=2 architecture.training.batch_size=3",
"architecture.training.num_epochs=2,architecture.training.batch_size=3",
],
)
def test_command_line_override(monkeypatch, tmp_path, overrides):
Expand Down

0 comments on commit 0744fdd

Please sign in to comment.