Skip to content

Commit

Permalink
Merge branch 'main' into heads-and-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Dec 11, 2024
2 parents 2175d90 + 887e66a commit 39cfe9d
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 18 deletions.
8 changes: 4 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,16 @@ Contributors
Thanks goes to all people that make ``metatrain`` possible:

.. image:: https://contrib.rocks/image?repo=lab-cosmo/metatrain
:target: https://github.com/lab-cosmo/metatrain/graphs/contributors
:target: https://github.com/metatensor/metatrain/graphs/contributors

.. |tests| image:: https://github.com/lab-cosmo/metatrain/workflows/Tests/badge.svg
:alt: Github Actions Tests Job Status
:target: https://github.com/lab-cosmo/metatrain/actions?query=branch%3Amain
:target: https://github.com/metatensor/metatrain/actions?query=branch%3Amain

.. |codecov| image:: https://codecov.io/gh/lab-cosmo/metatrain/branch/main/graph/badge.svg
:alt: Code coverage
:target: https://codecov.io/gh/lab-cosmo/metatrain
:target: https://codecov.io/gh/metatensor/metatrain

.. |docs| image:: https://img.shields.io/badge/documentation-latest-success
:alt: Documentation
:target: https://lab-cosmo.github.io/metatrain/latest/
:target: https://metatensor.github.io/metatrain/latest
6 changes: 3 additions & 3 deletions docs/src/getting-started/override.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ possibility. The changes above can be achieved by typing:
.. code-block:: bash
mtt train options.yaml \
-r architecture.model.soap.cutoff=7.0 architecture.training.num_epochs=200
-r architecture.model.soap.cutoff=7.0 -r architecture.training.num_epochs=200
Here, the ``-r`` or equivalent ``--override`` flag is used to parse the override flags.
The syntax follows a dotlist-style string format where each level of the options is
separated by a ``.``. For example to use single precision as the base precision for your
training use ``-r base_precision=32``
separated by a ``.``. As a further example, to use single precision for your training
you can add ``-r base_precision=32``.

.. note::
Command line overrides allow adding new values to your training parameters and
Expand Down
8 changes: 4 additions & 4 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
"-r",
"--override",
dest="override_options",
type=lambda string: OmegaConf.from_dotlist(string.split()),
help="Command line override flags.",
action="append",
help="Command-line override flags.",
default=[],
)


Expand All @@ -93,8 +94,7 @@ def _prepare_train_model_args(args: argparse.Namespace) -> None:
args.options = OmegaConf.load(args.options)
# merge/override file options with command line options
override_options = args.__dict__.pop("override_options")
if override_options is None:
override_options = {}
override_options = OmegaConf.from_dotlist(override_options)

args.options = OmegaConf.merge(args.options, override_options)

Expand Down
7 changes: 5 additions & 2 deletions src/metatrain/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,13 @@ def update(self, other: "DatasetInfo") -> None:

intersecting_target_keys = self.targets.keys() & other.targets.keys()
for key in intersecting_target_keys:
if self.targets[key] != other.targets[key]:
if not self.targets[key].is_compatible_with(other.targets[key]):
raise ValueError(
f"Can't update DatasetInfo with different target information for "
f"target '{key}': {self.targets[key]} != {other.targets[key]}"
f"target '{key}': {self.targets[key]} is not compatible with "
f"{other.targets[key]}. If the units, quantity and keys of the two "
"targets are the same, this must be due to a mismatch in the "
"internal metadata of the layout."
)
self.targets.update(other.targets)

Expand Down
31 changes: 31 additions & 0 deletions src/metatrain/utils/data/target_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,37 @@ def _check_layout(self, layout: TensorMap) -> None:
"Gradients of spherical tensor targets are not supported."
)

def is_compatible_with(self, other: "TargetInfo") -> bool:
    """Check if two targets are compatible.

    Two target infos are compatible if they have the same quantity, unit,
    and layout, except for gradients. This method can be used to check if
    two target infos with the same name can correspond to the same output
    in a model.

    :param other: The target info to compare with.
    :return: :py:obj:`True` if the two target infos are compatible,
        :py:obj:`False` otherwise.
    """
    if self.quantity != other.quantity:
        return False
    if self.unit != other.unit:
        return False
    if self.layout.keys.names != other.layout.keys.names:
        return False
    # The two layouts must have exactly the same key set: without this
    # check, ``other`` containing extra keys would still be reported as
    # compatible, because the loop below only inspects ``self``'s keys.
    if len(self.layout.keys) != len(other.layout.keys):
        return False
    for key, block in self.layout.items():
        if key not in other.layout.keys:
            return False
        other_block = other.layout[key]
        # Per-block metadata (samples, components, properties) must match.
        if block.samples != other_block.samples:
            return False
        if block.components != other_block.components:
            return False
        if block.properties != other_block.properties:
            return False
    # gradients are not checked on purpose: targets are allowed to differ
    # in which gradients (e.g. forces, virials) they carry
    return True


def get_energy_target_info(
target: DictConfig,
Expand Down
11 changes: 6 additions & 5 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def test_train(capfd, monkeypatch, tmp_path, output):
@pytest.mark.parametrize(
"overrides",
[
"architecture.training.num_epochs=2",
"architecture.training.num_epochs=2 architecture.training.batch_size=3",
["architecture.training.num_epochs=2"],
["architecture.training.num_epochs=2", "architecture.training.batch_size=3"],
],
)
def test_command_line_override(monkeypatch, tmp_path, overrides):
Expand All @@ -116,7 +116,9 @@ def test_command_line_override(monkeypatch, tmp_path, overrides):
shutil.copy(DATASET_PATH_QM9, "qm9_reduced_100.xyz")
shutil.copy(OPTIONS_PATH, "options.yaml")

command = ["mtt", "train", "options.yaml", "-r", overrides]
command = ["mtt", "train", "options.yaml"]
for override in overrides:
command += ["-r", override]

subprocess.check_call(command)

Expand All @@ -126,7 +128,7 @@ def test_command_line_override(monkeypatch, tmp_path, overrides):
restart_options = OmegaConf.load(restart_glob[0])
assert restart_options["architecture"]["training"]["num_epochs"] == 2

if len(overrides.split()) == 2:
if len(overrides) == 2:
assert restart_options["architecture"]["training"]["batch_size"] == 3


Expand Down Expand Up @@ -487,7 +489,6 @@ def test_continue_different_dataset(options, monkeypatch, tmp_path):

options["training_set"]["systems"]["read_from"] = "ethanol_reduced_100.xyz"
options["training_set"]["targets"]["energy"]["key"] = "energy"
options["training_set"]["targets"]["energy"]["forces"] = False

train_model(options, continue_from=MODEL_PATH_64_BIT)

Expand Down
14 changes: 14 additions & 0 deletions tests/utils/data/test_target_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,17 @@ def test_is_auxiliary_output():
assert is_auxiliary_output("features")
assert is_auxiliary_output("energy_ensemble")
assert is_auxiliary_output("mtt::aux::energy_ensemble")


def test_is_compatible_with(energy_target_config, spherical_target_config):
    """``TargetInfo.is_compatible_with`` accepts targets differing only in
    gradients and rejects targets with different layouts."""
    energy_target_info = get_energy_target_info(energy_target_config)
    # use a distinct name for the built TargetInfo instead of rebinding the
    # ``spherical_target_config`` fixture (which holds a config, not an info)
    spherical_target_info = get_generic_target_info(spherical_target_config)
    energy_target_info_with_forces = get_energy_target_info(
        energy_target_config, add_position_gradients=True
    )
    # a target is always compatible with itself
    assert energy_target_info.is_compatible_with(energy_target_info)
    # extra gradients (forces) do not break compatibility
    assert energy_target_info_with_forces.is_compatible_with(energy_target_info)
    # different layouts (scalar vs spherical) are incompatible either way
    assert not energy_target_info.is_compatible_with(spherical_target_info)
    assert not (
        energy_target_info_with_forces.is_compatible_with(spherical_target_info)
    )

0 comments on commit 39cfe9d

Please sign in to comment.