Skip to content

Commit

Permalink
Allow multiple override flags
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 9, 2024
1 parent 0744fdd commit 05ff1a7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
7 changes: 3 additions & 4 deletions docs/src/getting-started/override.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,12 @@ possibility. The changes above can be achieved by typing:
.. code-block:: bash
mtt train options.yaml \
-r architecture.model.soap.cutoff=7.0,architecture.training.num_epochs=200
-r architecture.model.soap.cutoff=7.0 -r architecture.training.num_epochs=200
Here, the ``-r`` or equivalent ``--override`` flag is used to parse the override flags.
The syntax follows a dotlist-style string format where each level of the options is
seperated by a ``.`` and each separate option to override is separated by a comma.
As a further example, to use single precision for your training you can add
``-r base_precision=32``.
seperated by a ``.``. As a further example, to use single precision for your training
you can add ``-r base_precision=32``.

.. note::
Command line overrides allow adding new values to your training parameters and
Expand Down
8 changes: 4 additions & 4 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
"-r",
"--override",
dest="override_options",
type=lambda string: OmegaConf.from_dotlist(string.split(",")),
help="Command line override flags.",
action="append",
help="Command-line override flags.",
default=[],
)


Expand All @@ -93,8 +94,7 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
args.options = OmegaConf.load(args.options)
# merge/override file options with command line options
override_options = args.__dict__.pop("override_options")
if override_options is None:
override_options = {}
override_options = OmegaConf.from_dotlist(override_options)

args.options = OmegaConf.merge(args.options, override_options)

Expand Down
10 changes: 6 additions & 4 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def test_train(capfd, monkeypatch, tmp_path, output):
@pytest.mark.parametrize(
"overrides",
[
"architecture.training.num_epochs=2",
"architecture.training.num_epochs=2,architecture.training.batch_size=3",
["architecture.training.num_epochs=2"],
["architecture.training.num_epochs=2", "architecture.training.batch_size=3"],
],
)
def test_command_line_override(monkeypatch, tmp_path, overrides):
Expand All @@ -116,7 +116,9 @@ def test_command_line_override(monkeypatch, tmp_path, overrides):
shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
shutil.copy(OPTIONS_PATH, "options.yaml")

command = ["mtt", "train", "options.yaml", "-r", overrides]
command = ["mtt", "train", "options.yaml"]
for override in overrides:
command += ["-r", override]

subprocess.check_call(command)

Expand All @@ -126,7 +128,7 @@ def test_command_line_override(monkeypatch, tmp_path, overrides):
restart_options = OmegaConf.load(restart_glob[0])
assert restart_options["architecture"]["training"]["num_epochs"] == 2

if len(overrides.split()) == 2:
if len(overrides) == 2:
assert restart_options["architecture"]["training"]["batch_size"] == 3


Expand Down

0 comments on commit 05ff1a7

Please sign in to comment.