Skip to content

Commit

Permalink
Update readers to load dataset with arbitrary targets (#376)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Nov 21, 2024
1 parent 5b9a8b0 commit 444fb72
Show file tree
Hide file tree
Showing 54 changed files with 2,012 additions and 858 deletions.
1 change: 1 addition & 0 deletions docs/src/advanced-concepts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ such as output naming, auxiliary outputs, and wrapper models.
multi-gpu
auto-restarting
fine-tuning
preparing-generic-targets
112 changes: 112 additions & 0 deletions docs/src/advanced-concepts/preparing-generic-targets.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
Preparing generic targets for reading by metatrain
==================================================

Besides energy-like targets, the library also supports reading (and training on)
more generic targets.

Input file
----------

In order to read a generic target, you will have to specify its layout in the input
file. Suppose you want to learn a target named ``mtt::my_target``, which is
represented as a set of 10 independent per-atom 3D Cartesian vector (we need to
learn 3x10 values for each atom). The ``target`` section in the input file
should look
like this:

.. code-block:: yaml
targets:
mtt::my_target:
read_from: dataset.xyz
key: my_target
quantity: ""
unit: ""
per_atom: True
type:
cartiesian:
rank: 1
num_subtargets: 10
The crucial fields here are:

- ``per_atom``: This field should be set to ``True`` if the target is a per-atom
property. Otherwise, it should be set to ``False``.
- ``type``: This field specifies the type of the target. In this case, the target is
a Cartesian vector. The ``rank`` field specifies the rank of the target. For
Cartesian vectors, the rank is 1. Other possibilities for the ``type`` are
``scalar`` (for a scalar target) and ``spherical`` (for a spherical tensor).
- ``num_subtargets``: This field specifies the number of sub-targets that need to be
learned as part of this target. They are treated as entirely equivalent by models in
metatrain and will often be represented as outputs of the same neural network layer.
A common use case for this field is when you are learning a discretization of a
continuous target, such as the grid points of a band structure. In the example
above, there are 10 sub-targets. In ``metatensor``, these correspond to the number
of ``properties`` of the target.

A few more words should be spent on ``spherical`` targets. These should be made of a
certain number of irreducible spherical tensors. For example, if you are learning a
property that can be decomposed into two proper spherical tensors with L=0 and L=2,
the target section should would look like this:

.. code-block:: yaml
targets:
mtt::my_target:
quantity: ""
read_from: dataset.xyz
key: energy
unit: ""
per_atom: True
type:
spherical:
irreps:
- {o3_lambda: 0, o3_sigma: 1}
- {o3_lambda: 2, o3_sigma: 1}
num_subtargets: 10
where ``o3_lambda`` specifies the L value of the spherical tensor and ``o3_sigma`` its
parity with respect to inversion (1 for proper tensors, -1 for pseudo-tensors).

Preparing your targets -- ASE
-----------------------------

If you are using the ASE readers to read your targets, you will have to save them
either in the ``.info`` (if the target is per structure, i.e. not per atom) or in the
``.arrays`` (if the target is per atom) attributes of the ASE atoms object. Then you can
dump the atoms object to a file using ``ase.io.write``.

The ASE reader will automatically try to reshape the target data to the format expected
given the target section in the input file. In case your target data is invalid, an
error will be raised.

Reading targets with more than one spherical tensor is not supported by the ASE reader.
In that case, you should use the metatensor reader.

Preparing your targets -- metatensor
------------------------------------

If you are using the metatensor readers to read your targets, you will have to save them
as a ``metatensor.torch.TensorMap`` object with ``metatensor.torch.TensorMap.save()``
into a file with the ``.npz`` extension.

The metatensor reader will verify that the target data in the input files corresponds to
the metadata in the provided ``TensorMap`` objects. In case of a mismatch, errors will
be raised.

In particular:

- if the target is per atom, the samples should have the [``system``, ``atom``] names,
otherwise the [``system``] name.
- if the target is a ``scalar``, only one ``TensorBlock`` should be present, the keys
of the ``TensorMap`` should be a ``Labels.single()`` object, and there should be no
components.
- if the target is a ``cartesian`` tensor, only one ``TensorBlock`` should be present,
the keys of the ``TensorMap`` should be a ``Labels.single()`` object, and there should
be one components, with names [``xyz``] for a rank-1 tensor,
[``xyz_1``, ``xyz_2``, etc.] for higher rank tensors.
- if the target is a ``spherical`` tensor, the ``TensorMap`` can contain multiple
``TensorBlock``, each corresponding to one irreducible spherical tensor. The keys of
the ``TensorMap`` should have the ``o3_lambda`` and ``o3_sigma`` names, corresponding
to the values provided in the input file, and each ``TensorBlock`` should be one
component label, with name ``o3_mu`` and values going from -L to L.
46 changes: 34 additions & 12 deletions docs/src/dev-docs/utils/data/readers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,52 @@ Parsers for obtaining *system* and *target* information from files. Currently,
* - ``ase``
- system, energy, forces, stress, virials
- ``.xyz``, ``.extxyz``
* - ``metatensor``
- system, energy, forces, stress, virials
- ``.npz``


If the ``reader`` parameter is not set the library is determined from the file
extension. Overriding this behavior is in particular useful, if a file format is not
If the ``reader`` parameter is not set, the library is determined from the file
extension. Overriding this behavior is in particular useful if a file format is not
listed here but might be supported by a library.

Below the synopsis of the reader functions in details.

System and target data readers
==============================

The main entry point for reading system and target information are the reader functions
The main entry point for reading system and target information are the reader functions.

.. autofunction:: metatrain.utils.data.read_systems
.. autofunction:: metatrain.utils.data.read_targets

Target type specific readers
----------------------------
These functions dispatch the reading of the system and target information to the
appropriate readers, based on the file extension or the user-provided library.

In addition, the read_targets function uses the user-provided information about the
targets to call the appropriate target reader function (for energy targets or generic
targets).

ASE
---

This section describes the parsers for the ASE library.

.. autofunction:: metatrain.utils.data.readers.ase.read_systems
.. autofunction:: metatrain.utils.data.readers.ase.read_energy
.. autofunction:: metatrain.utils.data.readers.ase.read_generic

It should be noted that :func:`metatrain.utils.data.readers.ase.read_energy` currently
uses sub-functions to parse the energy and its gradients like ``forces``, ``virial``
and ``stress``.

Metatensor
----------

:func:`metatrain.utils.data.read_targets` uses sub-functions to parse supported
target properties like the ``energy`` or ``forces``. Currently we support reading the
following target properties via
This section describes the parsers for the ``metatensor`` library. As the systems and/or
targets are already stored in the ``metatensor`` format, these reader functions mainly
perform checks and return the data.

.. autofunction:: metatrain.utils.data.read_energy
.. autofunction:: metatrain.utils.data.read_forces
.. autofunction:: metatrain.utils.data.read_virial
.. autofunction:: metatrain.utils.data.read_stress
.. autofunction:: metatrain.utils.data.readers.metatensor.read_systems
.. autofunction:: metatrain.utils.data.readers.metatensor.read_energy
.. autofunction:: metatrain.utils.data.readers.metatensor.read_generic
3 changes: 3 additions & 0 deletions examples/programmatic/llpr/llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@
"reader": "ase",
"key": "energy",
"unit": "kcal/mol",
"type": "scalar",
"per_atom": False,
"num_subtargets": 1,
"forces": False,
"stress": False,
"virial": False,
Expand Down
9 changes: 9 additions & 0 deletions examples/programmatic/llpr_forces/force_llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
"reader": "ase",
"key": "energy",
"unit": "kcal/mol",
"type": "scalar",
"per_atom": False,
"num_subtargets": 1,
"forces": {
"read_from": "train.xyz",
"file_format": ".xyz",
Expand All @@ -53,6 +56,9 @@
"reader": "ase",
"key": "energy",
"unit": "kcal/mol",
"type": "scalar",
"per_atom": False,
"num_subtargets": 1,
"forces": {
"read_from": "valid.xyz",
"file_format": ".xyz",
Expand All @@ -79,6 +85,9 @@
"reader": "ase",
"key": "energy",
"unit": "kcal/mol",
"type": "scalar",
"per_atom": False,
"num_subtargets": 1,
"forces": {
"read_from": "test.xyz",
"file_format": ".xyz",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

Expand All @@ -23,9 +23,7 @@ def test_to(device, dtype):
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
targets={"energy": get_energy_target_info({"unit": "eV"})},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype)
exported = model.export()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS

Expand All @@ -19,9 +19,7 @@ def test_prediction_subset_elements():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
targets={"energy": get_energy_target_info({"unit": "eV"})},
)

model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from metatensor.torch.atomistic import ModelEvaluationOptions, systems_to_torch

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, MODEL_HYPERS

Expand All @@ -21,9 +21,7 @@ def test_rotational_invariance():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
targets={"energy": get_energy_target_info({"unit": "eV"})},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@
from omegaconf import OmegaConf

from metatrain.experimental.alchemical_model import AlchemicalModel, Trainer
from metatrain.utils.data import (
Dataset,
DatasetInfo,
TargetInfo,
read_systems,
read_targets,
)
from metatrain.utils.data import Dataset, DatasetInfo, read_systems, read_targets
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from metatrain.utils.testing import energy_layout

from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS

Expand All @@ -32,7 +26,7 @@ def test_regression_init():
"""Perform a regression test on the model at initialization"""

targets = {}
targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
targets["mtt::U0"] = get_energy_target_info({"unit": "eV"})

dataset_info = DatasetInfo(
length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets
Expand Down Expand Up @@ -90,6 +84,9 @@ def test_regression_train():
"reader": "ase",
"key": "U0",
"unit": "eV",
"type": "scalar",
"per_atom": False,
"num_subtargets": 1,
"forces": False,
"stress": False,
"virial": False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from metatrain.experimental.alchemical_model.utils import (
systems_to_torch_alchemical_batch,
)
from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems
from metatrain.utils.data import DatasetInfo, read_systems
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.testing import energy_layout

from . import MODEL_HYPERS, QM9_DATASET_PATH

Expand Down Expand Up @@ -72,9 +72,7 @@ def test_alchemical_model_inference():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=unique_numbers,
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
targets={"energy": get_energy_target_info({"unit": "eV"})},
)

alchemical_model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo
from metatrain.utils.testing import energy_layout
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import get_energy_target_info

from . import MODEL_HYPERS

Expand All @@ -13,9 +13,7 @@ def test_torchscript():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
targets={"energy": get_energy_target_info({"unit": "eV"})},
)

model = AlchemicalModel(MODEL_HYPERS, dataset_info)
Expand All @@ -28,9 +26,7 @@ def test_torchscript_save_load():
dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout)
},
targets={"energy": get_energy_target_info({"unit": "eV"})},
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)
torch.jit.save(
Expand Down
Loading

0 comments on commit 444fb72

Please sign in to comment.