diff --git a/docs/src/advanced-concepts/index.rst b/docs/src/advanced-concepts/index.rst index cd5fdb04d..dd5f98f01 100644 --- a/docs/src/advanced-concepts/index.rst +++ b/docs/src/advanced-concepts/index.rst @@ -13,3 +13,4 @@ such as output naming, auxiliary outputs, and wrapper models. multi-gpu auto-restarting fine-tuning + preparing-generic-targets diff --git a/docs/src/advanced-concepts/preparing-generic-targets.rst b/docs/src/advanced-concepts/preparing-generic-targets.rst new file mode 100644 index 000000000..949e8f9f8 --- /dev/null +++ b/docs/src/advanced-concepts/preparing-generic-targets.rst @@ -0,0 +1,112 @@ +Preparing generic targets for reading by metatrain +================================================== + +Besides energy-like targets, the library also supports reading (and training on) +more generic targets. + +Input file +---------- + +In order to read a generic target, you will have to specify its layout in the input +file. Suppose you want to learn a target named ``mtt::my_target``, which is +represented as a set of 10 independent per-atom 3D Cartesian vectors (we need to +learn 3x10 values for each atom). The ``targets`` section in the input file +should look +like this: + +.. code-block:: yaml + + targets: + mtt::my_target: + read_from: dataset.xyz + key: my_target + quantity: "" + unit: "" + per_atom: True + type: + cartesian: + rank: 1 + num_subtargets: 10 + +The crucial fields here are: + +- ``per_atom``: This field should be set to ``True`` if the target is a per-atom + property. Otherwise, it should be set to ``False``. +- ``type``: This field specifies the type of the target. In this case, the target is + a Cartesian vector. The ``rank`` field specifies the rank of the target. For + Cartesian vectors, the rank is 1. Other possibilities for the ``type`` are + ``scalar`` (for a scalar target) and ``spherical`` (for a spherical tensor). 
+- ``num_subtargets``: This field specifies the number of sub-targets that need to be + learned as part of this target. They are treated as entirely equivalent by models in + metatrain and will often be represented as outputs of the same neural network layer. + A common use case for this field is when you are learning a discretization of a + continuous target, such as the grid points of a band structure. In the example + above, there are 10 sub-targets. In ``metatensor``, these correspond to the number + of ``properties`` of the target. + +A few more words should be spent on ``spherical`` targets. These should be made of a +certain number of irreducible spherical tensors. For example, if you are learning a +property that can be decomposed into two proper spherical tensors with L=0 and L=2, +the target section should look like this: + +.. code-block:: yaml + + targets: + mtt::my_target: + quantity: "" + read_from: dataset.xyz + key: energy + unit: "" + per_atom: True + type: + spherical: + irreps: + - {o3_lambda: 0, o3_sigma: 1} + - {o3_lambda: 2, o3_sigma: 1} + num_subtargets: 10 + +where ``o3_lambda`` specifies the L value of the spherical tensor and ``o3_sigma`` its +parity with respect to inversion (1 for proper tensors, -1 for pseudo-tensors). + +Preparing your targets -- ASE +----------------------------- + +If you are using the ASE readers to read your targets, you will have to save them +either in the ``.info`` (if the target is per structure, i.e. not per atom) or in the +``.arrays`` (if the target is per atom) attributes of the ASE atoms object. Then you can +dump the atoms object to a file using ``ase.io.write``. + +The ASE reader will automatically try to reshape the target data to the format expected +given the target section in the input file. In case your target data is invalid, an +error will be raised. + +Reading targets with more than one spherical tensor is not supported by the ASE reader. +In that case, you should use the metatensor reader. 
+ +Preparing your targets -- metatensor +------------------------------------ + +If you are using the metatensor readers to read your targets, you will have to save them +as a ``metatensor.torch.TensorMap`` object with ``metatensor.torch.TensorMap.save()`` +into a file with the ``.npz`` extension. + +The metatensor reader will verify that the target data in the input files corresponds to +the metadata in the provided ``TensorMap`` objects. In case of a mismatch, errors will +be raised. + +In particular: + +- if the target is per atom, the samples should have the [``system``, ``atom``] names, + otherwise the [``system``] name. +- if the target is a ``scalar``, only one ``TensorBlock`` should be present, the keys + of the ``TensorMap`` should be a ``Labels.single()`` object, and there should be no + components. +- if the target is a ``cartesian`` tensor, only one ``TensorBlock`` should be present, + the keys of the ``TensorMap`` should be a ``Labels.single()`` object, and there should + be one or more components, with names [``xyz``] for a rank-1 tensor, and + [``xyz_1``, ``xyz_2``, etc.] for higher-rank tensors. +- if the target is a ``spherical`` tensor, the ``TensorMap`` can contain multiple + ``TensorBlock``, each corresponding to one irreducible spherical tensor. The keys of + the ``TensorMap`` should have the ``o3_lambda`` and ``o3_sigma`` names, corresponding + to the values provided in the input file, and each ``TensorBlock`` should have one + component label, with name ``o3_mu`` and values going from -L to L. diff --git a/docs/src/dev-docs/utils/data/readers.rst b/docs/src/dev-docs/utils/data/readers.rst index 17a173b72..b6dc726bb 100644 --- a/docs/src/dev-docs/utils/data/readers.rst +++ b/docs/src/dev-docs/utils/data/readers.rst @@ -13,10 +13,13 @@ Parsers for obtaining *system* and *target* information from files. 
Currently, * - ``ase`` - system, energy, forces, stress, virials - ``.xyz``, ``.extxyz`` + * - ``metatensor`` + - system, energy, forces, stress, virials + - ``.npz`` -If the ``reader`` parameter is not set the library is determined from the file -extension. Overriding this behavior is in particular useful, if a file format is not +If the ``reader`` parameter is not set, the library is determined from the file +extension. Overriding this behavior is in particular useful if a file format is not listed here but might be supported by a library. Below the synopsis of the reader functions in details. @@ -24,19 +27,38 @@ Below the synopsis of the reader functions in details. System and target data readers ============================== -The main entry point for reading system and target information are the reader functions +The main entry point for reading system and target information are the reader functions. .. autofunction:: metatrain.utils.data.read_systems .. autofunction:: metatrain.utils.data.read_targets -Target type specific readers ----------------------------- +These functions dispatch the reading of the system and target information to the +appropriate readers, based on the file extension or the user-provided library. + +In addition, the read_targets function uses the user-provided information about the +targets to call the appropriate target reader function (for energy targets or generic +targets). + +ASE +--- + +This section describes the parsers for the ASE library. + +.. autofunction:: metatrain.utils.data.readers.ase.read_systems +.. autofunction:: metatrain.utils.data.readers.ase.read_energy +.. autofunction:: metatrain.utils.data.readers.ase.read_generic + +It should be noted that :func:`metatrain.utils.data.readers.ase.read_energy` currently +uses sub-functions to parse the energy and its gradients like ``forces``, ``virial`` +and ``stress``. 
+ +Metatensor +---------- -:func:`metatrain.utils.data.read_targets` uses sub-functions to parse supported -target properties like the ``energy`` or ``forces``. Currently we support reading the -following target properties via +This section describes the parsers for the ``metatensor`` library. As the systems and/or +targets are already stored in the ``metatensor`` format, these reader functions mainly +perform checks and return the data. -.. autofunction:: metatrain.utils.data.read_energy -.. autofunction:: metatrain.utils.data.read_forces -.. autofunction:: metatrain.utils.data.read_virial -.. autofunction:: metatrain.utils.data.read_stress +.. autofunction:: metatrain.utils.data.readers.metatensor.read_systems +.. autofunction:: metatrain.utils.data.readers.metatensor.read_energy +.. autofunction:: metatrain.utils.data.readers.metatensor.read_generic diff --git a/examples/programmatic/llpr/llpr.py b/examples/programmatic/llpr/llpr.py index e17c15148..7a82ec25f 100644 --- a/examples/programmatic/llpr/llpr.py +++ b/examples/programmatic/llpr/llpr.py @@ -64,6 +64,9 @@ "reader": "ase", "key": "energy", "unit": "kcal/mol", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, diff --git a/examples/programmatic/llpr_forces/force_llpr.py b/examples/programmatic/llpr_forces/force_llpr.py index 155e33527..1c2f8b86b 100644 --- a/examples/programmatic/llpr_forces/force_llpr.py +++ b/examples/programmatic/llpr_forces/force_llpr.py @@ -27,6 +27,9 @@ "reader": "ase", "key": "energy", "unit": "kcal/mol", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": { "read_from": "train.xyz", "file_format": ".xyz", @@ -53,6 +56,9 @@ "reader": "ase", "key": "energy", "unit": "kcal/mol", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": { "read_from": "valid.xyz", "file_format": ".xyz", @@ -79,6 +85,9 @@ "reader": "ase", "key": "energy", "unit": "kcal/mol", + "type": "scalar", + 
"per_atom": False, + "num_subtargets": 1, "forces": { "read_from": "test.xyz", "file_format": ".xyz", diff --git a/src/metatrain/experimental/alchemical_model/tests/test_exported.py b/src/metatrain/experimental/alchemical_model/tests/test_exported.py index 04874e6c3..3a900b57b 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_exported.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_exported.py @@ -3,12 +3,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, System from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) -from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS @@ -23,9 +23,7 @@ def test_to(device, dtype): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info).to(dtype=dtype) exported = model.export() diff --git a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py index 6b2eae8df..e46ccac73 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py @@ -2,12 +2,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, System from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info 
from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) -from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS @@ -19,9 +19,7 @@ def test_prediction_subset_elements(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py index c8a42676d..601604f0a 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py @@ -5,12 +5,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, systems_to_torch from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) -from metatrain.utils.testing import energy_layout from . 
import DATASET_PATH, MODEL_HYPERS @@ -21,9 +21,7 @@ def test_rotational_invariance(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_regression.py b/src/metatrain/experimental/alchemical_model/tests/test_regression.py index 0d82379c3..ab94cf2c6 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_regression.py @@ -6,18 +6,12 @@ from omegaconf import OmegaConf from metatrain.experimental.alchemical_model import AlchemicalModel, Trainer -from metatrain.utils.data import ( - Dataset, - DatasetInfo, - TargetInfo, - read_systems, - read_targets, -) +from metatrain.utils.data import Dataset, DatasetInfo, read_systems, read_targets +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) -from metatrain.utils.testing import energy_layout from . 
import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -32,7 +26,7 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" targets = {} - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + targets["mtt::U0"] = get_energy_target_info({"unit": "eV"}) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets @@ -90,6 +84,9 @@ def test_regression_train(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py index 9f17dd9d0..f9ceee10e 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py @@ -14,9 +14,9 @@ from metatrain.experimental.alchemical_model.utils import ( systems_to_torch_alchemical_batch, ) -from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems +from metatrain.utils.data import DatasetInfo, read_systems +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists -from metatrain.utils.testing import energy_layout from . 
import MODEL_HYPERS, QM9_DATASET_PATH @@ -72,9 +72,7 @@ def test_alchemical_model_inference(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=unique_numbers, - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) alchemical_model = AlchemicalModel(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py index 30d5511d7..fbac2fe6b 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torchscript.py @@ -1,8 +1,8 @@ import torch from metatrain.experimental.alchemical_model import AlchemicalModel -from metatrain.utils.data import DatasetInfo, TargetInfo -from metatrain.utils.testing import energy_layout +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from . 
import MODEL_HYPERS @@ -13,9 +13,7 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) @@ -28,9 +26,7 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) torch.jit.save( diff --git a/src/metatrain/experimental/gap/tests/test_errors.py b/src/metatrain/experimental/gap/tests/test_errors.py index 791e061c6..3f959a869 100644 --- a/src/metatrain/experimental/gap/tests/test_errors.py +++ b/src/metatrain/experimental/gap/tests/test_errors.py @@ -8,14 +8,8 @@ from omegaconf import OmegaConf from metatrain.experimental.gap import GAP, Trainer -from metatrain.utils.data import ( - Dataset, - DatasetInfo, - TargetInfo, - read_systems, - read_targets, -) -from metatrain.utils.testing import energy_force_layout +from metatrain.utils.data import Dataset, DatasetInfo, read_systems, read_targets +from metatrain.utils.data.target_info import get_energy_target_info from . 
import DATASET_ETHANOL_PATH, DEFAULT_HYPERS @@ -43,6 +37,9 @@ def test_ethanol_regression_train_and_invariance(): "reader": "ase", "key": "energy", "unit": "kcal/mol", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": { "read_from": DATASET_ETHANOL_PATH, "reader": "ase", @@ -62,8 +59,8 @@ def test_ethanol_regression_train_and_invariance(): hypers["model"]["krr"]["num_sparse_points"] = 30 target_info_dict = { - "energy": TargetInfo( - quantity="energy", unit="kcal/mol", layout=energy_force_layout + "energy": get_energy_target_info( + {"unit": "kcal/mol"}, add_position_gradients=True ) } diff --git a/src/metatrain/experimental/gap/tests/test_regression.py b/src/metatrain/experimental/gap/tests/test_regression.py index a3d703e76..5eef83612 100644 --- a/src/metatrain/experimental/gap/tests/test_regression.py +++ b/src/metatrain/experimental/gap/tests/test_regression.py @@ -8,9 +8,9 @@ from omegaconf import OmegaConf from metatrain.experimental.gap import GAP, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo +from metatrain.utils.data import Dataset, DatasetInfo from metatrain.utils.data.readers import read_systems, read_targets -from metatrain.utils.testing import energy_force_layout, energy_layout +from metatrain.utils.data.target_info import get_energy_target_info from . 
import DATASET_ETHANOL_PATH, DATASET_PATH, DEFAULT_HYPERS @@ -27,7 +27,7 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" targets = {} - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + targets["mtt::U0"] = get_energy_target_info({"unit": "eV"}) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets @@ -50,6 +50,9 @@ def test_regression_train_and_invariance(): "reader": "ase", "key": "U0", "unit": "kcal/mol", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -59,9 +62,7 @@ def test_regression_train_and_invariance(): dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) target_info_dict = {} - target_info_dict["mtt::U0"] = TargetInfo( - quantity="energy", unit="eV", layout=energy_layout - ) + target_info_dict["mtt::U0"] = get_energy_target_info({"unit": "eV"}) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict @@ -124,6 +125,9 @@ def test_ethanol_regression_train_and_invariance(): "read_from": DATASET_ETHANOL_PATH, "reader": "ase", "key": "energy", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": { "read_from": DATASET_ETHANOL_PATH, "reader": "ase", @@ -142,9 +146,7 @@ def test_ethanol_regression_train_and_invariance(): hypers["model"]["krr"]["num_sparse_points"] = 900 target_info_dict = { - "energy": TargetInfo( - quantity="energy", unit="kcal/mol", layout=energy_force_layout - ) + "energy": get_energy_target_info({"unit": "eV"}, add_position_gradients=True) } dataset_info = DatasetInfo( diff --git a/src/metatrain/experimental/gap/tests/test_torchscript.py b/src/metatrain/experimental/gap/tests/test_torchscript.py index e7d1cbc2f..04465366d 100644 --- a/src/metatrain/experimental/gap/tests/test_torchscript.py +++ b/src/metatrain/experimental/gap/tests/test_torchscript.py @@ 
-2,9 +2,9 @@ from omegaconf import OmegaConf from metatrain.experimental.gap import GAP, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo +from metatrain.utils.data import Dataset, DatasetInfo from metatrain.utils.data.readers import read_systems, read_targets -from metatrain.utils.testing import energy_layout +from metatrain.utils.data.target_info import get_energy_target_info from . import DATASET_PATH, DEFAULT_HYPERS @@ -15,9 +15,7 @@ def test_torchscript(): """Tests that the model can be jitted.""" target_info_dict = {} - target_info_dict["mtt::U0"] = TargetInfo( - quantity="energy", unit="eV", layout=energy_layout - ) + target_info_dict["mtt::U0"] = get_energy_target_info({"unit": "eV"}) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict @@ -29,6 +27,9 @@ def test_torchscript(): "reader": "ase", "key": "U0", "unit": "kcal/mol", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -66,7 +67,7 @@ def test_torchscript(): def test_torchscript_save(): """Tests that the model can be jitted and saved.""" targets = {} - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + targets["mtt::U0"] = get_energy_target_info({"unit": "eV"}) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets diff --git a/src/metatrain/experimental/pet/tests/test_exported.py b/src/metatrain/experimental/pet/tests/test_exported.py index 08bdca71c..18cd5d5ad 100644 --- a/src/metatrain/experimental/pet/tests/test_exported.py +++ b/src/metatrain/experimental/pet/tests/test_exported.py @@ -11,12 +11,12 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import 
get_energy_target_info from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) -from metatrain.utils.testing import energy_layout DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -32,9 +32,7 @@ def test_to(device): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index 2ebd3841f..e0c43446d 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -18,13 +18,13 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.jsonschema import validate from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) -from metatrain.utils.testing import energy_layout DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -66,9 +66,7 @@ def test_prediction(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) @@ -119,9 +117,7 @@ def test_per_atom_predictions_functionality(): dataset_info = DatasetInfo( length_unit="Angstrom", 
atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) @@ -173,9 +169,7 @@ def test_selected_atoms_functionality(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) diff --git a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py index 1a83371af..6684482e0 100644 --- a/src/metatrain/experimental/pet/tests/test_pet_compatibility.py +++ b/src/metatrain/experimental/pet/tests/test_pet_compatibility.py @@ -17,9 +17,9 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.experimental.pet.utils import systems_to_batch_dict from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists -from metatrain.utils.testing import energy_layout from . 
import DATASET_PATH @@ -98,9 +98,7 @@ def test_predictions_compatibility(cutoff): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=structure.numbers, - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) capabilities = ModelCapabilities( length_unit="Angstrom", diff --git a/src/metatrain/experimental/pet/tests/test_torchscript.py b/src/metatrain/experimental/pet/tests/test_torchscript.py index 15adc95a8..6f5623932 100644 --- a/src/metatrain/experimental/pet/tests/test_torchscript.py +++ b/src/metatrain/experimental/pet/tests/test_torchscript.py @@ -4,8 +4,8 @@ from metatrain.experimental.pet import PET as WrappedPET from metatrain.utils.architectures import get_default_hypers -from metatrain.utils.data import DatasetInfo, TargetInfo -from metatrain.utils.testing import energy_layout +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -17,9 +17,7 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) @@ -34,9 +32,7 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) ARCHITECTURAL_HYPERS = Hypers(model.hypers) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py index 
50fa7118a..96eceebf8 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_continue.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_continue.py @@ -6,9 +6,9 @@ from omegaconf import OmegaConf from metatrain.experimental.soap_bpnn import SoapBpnn, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo +from metatrain.utils.data import Dataset, DatasetInfo from metatrain.utils.data.readers import read_systems, read_targets -from metatrain.utils.testing import energy_layout +from metatrain.utils.data.target_info import get_energy_target_info from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -24,9 +24,7 @@ def test_continue(monkeypatch, tmp_path): systems = [system.to(torch.float32) for system in systems] target_info_dict = {} - target_info_dict["mtt::U0"] = TargetInfo( - quantity="energy", unit="eV", layout=energy_layout - ) + target_info_dict["mtt::U0"] = get_energy_target_info({"unit": "eV"}) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict @@ -41,6 +39,9 @@ def test_continue(monkeypatch, tmp_path): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py index 421cc422a..b3e496c43 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py @@ -3,12 +3,12 @@ from metatensor.torch.atomistic import ModelEvaluationOptions, System from metatrain.experimental.soap_bpnn import SoapBpnn -from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) 
-from metatrain.utils.testing import energy_layout from . import MODEL_HYPERS @@ -23,9 +23,7 @@ def test_to(device, dtype): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = SoapBpnn(MODEL_HYPERS, dataset_info).to(dtype=dtype) exported = model.export() diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py index 2e7cde5e7..98336336f 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py @@ -7,8 +7,8 @@ from metatrain.experimental.soap_bpnn import SoapBpnn from metatrain.utils.architectures import check_architecture_options -from metatrain.utils.data import DatasetInfo, TargetInfo -from metatrain.utils.testing import energy_layout +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from . 
import DEFAULT_HYPERS, MODEL_HYPERS @@ -20,9 +20,7 @@ def test_prediction_subset_elements(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -46,9 +44,7 @@ def test_prediction_subset_atoms(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -116,9 +112,7 @@ def test_output_last_layer_features(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) @@ -190,9 +184,7 @@ def test_output_per_atom(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py b/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py index 92c96767d..28a8e8007 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_invariance.py @@ -5,8 +5,8 @@ from metatensor.torch.atomistic import systems_to_torch from metatrain.experimental.soap_bpnn import SoapBpnn -from metatrain.utils.data import DatasetInfo, TargetInfo -from metatrain.utils.testing import energy_layout +from metatrain.utils.data import DatasetInfo +from 
metatrain.utils.data.target_info import get_energy_target_info from . import DATASET_PATH, MODEL_HYPERS @@ -17,9 +17,7 @@ def test_rotational_invariance(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py index 35f691051..61d08c11b 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_regression.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_regression.py @@ -6,9 +6,9 @@ from omegaconf import OmegaConf from metatrain.experimental.soap_bpnn import SoapBpnn, Trainer -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo +from metatrain.utils.data import Dataset, DatasetInfo from metatrain.utils.data.readers import read_systems, read_targets -from metatrain.utils.testing import energy_layout +from metatrain.utils.data.target_info import get_energy_target_info from . 
import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -23,7 +23,7 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" targets = {} - targets["mtt::U0"] = TargetInfo(quantity="energy", unit="eV", layout=energy_layout) + targets["mtt::U0"] = get_energy_target_info({"unit": "eV"}) dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets @@ -71,6 +71,9 @@ def test_regression_train(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py index 53a7e161e..2e16ba26a 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_torchscript.py @@ -4,8 +4,8 @@ from metatensor.torch.atomistic import System from metatrain.experimental.soap_bpnn import SoapBpnn -from metatrain.utils.data import DatasetInfo, TargetInfo -from metatrain.utils.testing import energy_layout +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from . 
import MODEL_HYPERS @@ -16,9 +16,7 @@ def test_torchscript(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) model = torch.jit.script(model) @@ -43,9 +41,7 @@ def test_torchscript_with_identity(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) hypers = copy.deepcopy(MODEL_HYPERS) hypers["bpnn"]["layernorm"] = False @@ -72,9 +68,7 @@ def test_torchscript_save_load(): dataset_info = DatasetInfo( length_unit="Angstrom", atomic_types=[1, 6, 7, 8], - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = SoapBpnn(MODEL_HYPERS, dataset_info) torch.jit.save( diff --git a/src/metatrain/share/schema-dataset.json b/src/metatrain/share/schema-dataset.json index ef33134dd..ac42264c9 100644 --- a/src/metatrain/share/schema-dataset.json +++ b/src/metatrain/share/schema-dataset.json @@ -132,6 +132,61 @@ } ] }, + "per_atom": { + "type": "boolean" + }, + "num_subtargets": { + "type": "integer" + }, + "type": { + "oneOf": [ + { + "type": "string", + "enum": ["scalar"] + }, + { + "type": "object", + "properties": { + "cartesian": { + "type": "object", + "properties": { + "rank": { "type": "integer" } + }, + "required": ["rank"], + "additionalProperties": false + } + }, + "required": ["cartesian"], + "additionalProperties": false + }, + { + "type": "object", + "properties": { + "spherical": { + "type": "object", + "properties": { + "irreps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "o3_lambda": { "type": "integer" }, + "o3_sigma": { 
"type": "number" } + }, + "required": ["o3_lambda", "o3_sigma"], + "additionalProperties": false + } + } + }, + "required": ["irreps"], + "additionalProperties": false + } + }, + "required": ["spherical"], + "additionalProperties": false + } + ] + }, "forces": { "$ref": "#/$defs/gradient_section" }, diff --git a/src/metatrain/utils/data/__init__.py b/src/metatrain/utils/data/__init__.py index e93621947..a4886839d 100644 --- a/src/metatrain/utils/data/__init__.py +++ b/src/metatrain/utils/data/__init__.py @@ -1,6 +1,5 @@ from .dataset import ( # noqa: F401 Dataset, - TargetInfo, DatasetInfo, get_atomic_types, get_all_targets, @@ -8,15 +7,8 @@ check_datasets, get_stats, ) -from .readers import ( # noqa: F401 - read_energy, - read_forces, - read_stress, - read_systems, - read_targets, - read_virial, -) - +from .target_info import TargetInfo # noqa: F401 +from .readers import read_systems, read_targets # noqa: F401 from .writers import write_predictions # noqa: F401 from .combine_dataloaders import CombinedDataLoader # noqa: F401 from .system_to_ase import system_to_ase # noqa: F401 diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 20c192503..673d28e43 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -2,7 +2,6 @@ import warnings from typing import Any, Dict, List, Tuple, Union -import metatensor.torch import numpy as np from metatensor.learn.data import Dataset, group_and_join from metatensor.torch import TensorMap @@ -10,202 +9,7 @@ from ..external_naming import to_external_name from ..units import get_gradient_units - - -class TargetInfo: - """A class that contains information about a target. - - :param quantity: The physical quantity of the target (e.g., "energy"). - :param layout: The layout of the target, as a ``TensorMap`` with 0 samples. 
- This ``TensorMap`` will be used to retrieve the names of - the ``samples``, as well as the ``components`` and ``properties`` of the - target and their gradients. For example, this allows to infer the type of - the target (scalar, Cartesian tensor, spherical tensor), whether it is per - atom, the names of its gradients, etc. - :param unit: The unit of the target. If :py:obj:`None` the ``unit`` will be set to - an empty string ``""``. - """ - - def __init__( - self, - quantity: str, - layout: TensorMap, - unit: Union[None, str] = "", - ): - # one of these will be set to True inside the _check_layout method - self._is_scalar = False - self._is_cartesian = False - self._is_spherical = False - - self._check_layout(layout) - - self.quantity = quantity # float64: otherwise metatensor can't serialize - self.layout = layout - self.unit = unit if unit is not None else "" - - @property - def is_scalar(self) -> bool: - """Whether the target is a scalar.""" - return self._is_scalar - - @property - def is_cartesian(self) -> bool: - """Whether the target is a Cartesian tensor.""" - return self._is_cartesian - - @property - def is_spherical(self) -> bool: - """Whether the target is a spherical tensor.""" - return self._is_spherical - - @property - def gradients(self) -> List[str]: - """Sorted and unique list of gradient names.""" - if self._is_scalar: - return sorted(self.layout.block().gradients_list()) - else: - return [] - - @property - def per_atom(self) -> bool: - """Whether the target is per atom.""" - return "atom" in self.layout.block(0).samples.names - - def __repr__(self): - return ( - f"TargetInfo(quantity={self.quantity!r}, unit={self.unit!r}, " - f"layout={self.layout!r})" - ) - - def __eq__(self, other): - if not isinstance(other, TargetInfo): - raise NotImplementedError( - "Comparison between a TargetInfo instance and a " - f"{type(other).__name__} instance is not implemented." 
- ) - return ( - self.quantity == other.quantity - and self.unit == other.unit - and metatensor.torch.equal(self.layout, other.layout) - ) - - def _check_layout(self, layout: TensorMap) -> None: - """Check that the layout is a valid layout.""" - - # examine basic properties of all blocks - for block in layout.blocks(): - for sample_name in block.samples.names: - if sample_name not in ["system", "atom"]: - raise ValueError( - "The layout ``TensorMap`` of a target should only have samples " - "named 'system' or 'atom', but found " - f"'{sample_name}' instead." - ) - if len(block.values) != 0: - raise ValueError( - "The layout ``TensorMap`` of a target should have 0 " - f"samples, but found {len(block.values)} samples." - ) - - # examine the components of the first block to decide whether this is - # a scalar, a Cartesian tensor or a spherical tensor - - if len(layout) == 0: - raise ValueError( - "The layout ``TensorMap`` of a target should have at least one " - "block, but found 0 blocks." - ) - components_first_block = layout.block(0).components - if len(components_first_block) == 0: - self._is_scalar = True - elif components_first_block[0].names[0].startswith("xyz"): - self._is_cartesian = True - elif ( - len(components_first_block) == 1 - and components_first_block[0].names[0] == "o3_mu" - ): - self._is_spherical = True - else: - raise ValueError( - "The layout ``TensorMap`` of a target should be " - "either scalars, Cartesian tensors or spherical tensors. The type of " - "the target could not be determined." - ) - - if self._is_scalar: - if layout.keys.names != ["_"]: - raise ValueError( - "The layout ``TensorMap`` of a scalar target should have " - "a single key sample named '_'." - ) - if len(layout.blocks()) != 1: - raise ValueError( - "The layout ``TensorMap`` of a scalar target should have " - "a single block." 
- ) - gradients_names = layout.block(0).gradients_list() - for gradient_name in gradients_names: - if gradient_name not in ["positions", "strain"]: - raise ValueError( - "Only `positions` and `strain` gradients are supported for " - "scalar targets. " - f"Found '{gradient_name}' instead." - ) - if self._is_cartesian: - if layout.keys.names != ["_"]: - raise ValueError( - "The layout ``TensorMap`` of a Cartesian tensor target should have " - "a single key sample named '_'." - ) - if len(layout.blocks()) != 1: - raise ValueError( - "The layout ``TensorMap`` of a Cartesian tensor target should have " - "a single block." - ) - if len(layout.block(0).gradients_list()) > 0: - raise ValueError( - "Gradients of Cartesian tensor targets are not supported." - ) - - if self._is_spherical: - if layout.keys.names != ["o3_lambda", "o3_sigma"]: - raise ValueError( - "The layout ``TensorMap`` of a spherical tensor target " - "should have two keys named 'o3_lambda' and 'o3_sigma'." - f"Found '{layout.keys.names}' instead." - ) - for key, block in layout.items(): - o3_lambda, o3_sigma = int(key.values[0].item()), int( - key.values[1].item() - ) - if o3_sigma not in [-1, 1]: - raise ValueError( - "The layout ``TensorMap`` of a spherical tensor target should " - "have a key sample 'o3_sigma' that is either -1 or 1." - f"Found '{o3_sigma}' instead." - ) - if o3_lambda < 0: - raise ValueError( - "The layout ``TensorMap`` of a spherical tensor target should " - "have a key sample 'o3_lambda' that is non-negative." - f"Found '{o3_lambda}' instead." - ) - components = block.components - if len(components) != 1: - raise ValueError( - "The layout ``TensorMap`` of a spherical tensor target should " - "have a single component." - ) - if len(components[0]) != 2 * o3_lambda + 1: - raise ValueError( - "Each ``TensorBlock`` of a spherical tensor target should have " - "a component with 2*o3_lambda + 1 elements." - f"Found '{len(components[0])}' elements instead." 
- ) - if len(block.gradients_list()) > 0: - raise ValueError( - "Gradients of spherical tensor targets are not supported." - ) +from .target_info import TargetInfo class DatasetInfo: diff --git a/src/metatrain/utils/data/get_dataset.py b/src/metatrain/utils/data/get_dataset.py index 502aea40f..e022027e7 100644 --- a/src/metatrain/utils/data/get_dataset.py +++ b/src/metatrain/utils/data/get_dataset.py @@ -2,8 +2,9 @@ from omegaconf import DictConfig -from .dataset import Dataset, TargetInfo +from .dataset import Dataset from .readers import read_systems, read_targets +from .target_info import TargetInfo def get_dataset(options: DictConfig) -> Tuple[Dataset, Dict[str, TargetInfo]]: diff --git a/src/metatrain/utils/data/readers/__init__.py b/src/metatrain/utils/data/readers/__init__.py index 39ca25823..6694cd3bf 100644 --- a/src/metatrain/utils/data/readers/__init__.py +++ b/src/metatrain/utils/data/readers/__init__.py @@ -1,8 +1,4 @@ from .readers import ( # noqa: F401 - read_energy, - read_forces, - read_stress, read_systems, read_targets, - read_virial, ) diff --git a/src/metatrain/utils/data/readers/ase.py b/src/metatrain/utils/data/readers/ase.py index 400e81e15..ac3abe568 100644 --- a/src/metatrain/utils/data/readers/ase.py +++ b/src/metatrain/utils/data/readers/ase.py @@ -1,10 +1,17 @@ +import logging import warnings -from typing import List +from typing import List, Tuple import ase.io import torch -from metatensor.torch import Labels, TensorBlock +from metatensor.torch import Labels, TensorBlock, TensorMap from metatensor.torch.atomistic import System, systems_to_torch +from omegaconf import DictConfig + +from ..target_info import TargetInfo, get_energy_target_info, get_generic_target_info + + +logger = logging.getLogger(__name__) def _wrapped_ase_io_read(filename): @@ -14,7 +21,7 @@ def _wrapped_ase_io_read(filename): raise ValueError(f"Failed to read '{filename}' with ASE: {e}") from e -def read_systems_ase(filename: str) -> List[System]: +def 
read_systems(filename: str) -> List[System]: """Store system informations using ase. :param filename: name of the file to read @@ -23,7 +30,7 @@ def read_systems_ase(filename: str) -> List[System]: return systems_to_torch(_wrapped_ase_io_read(filename), dtype=torch.float64) -def read_energy_ase(filename: str, key: str) -> List[TensorBlock]: +def _read_energy_ase(filename: str, key: str) -> List[TensorBlock]: """Store energy information in a List of :class:`metatensor.TensorBlock`. :param filename: name of the file to read @@ -56,7 +63,7 @@ def read_energy_ase(filename: str, key: str) -> List[TensorBlock]: return blocks -def read_forces_ase(filename: str, key: str = "energy") -> List[TensorBlock]: +def _read_forces_ase(filename: str, key: str = "energy") -> List[TensorBlock]: """Store force information in a List of :class:`metatensor.TensorBlock` which can be used as ``position`` gradients. @@ -99,7 +106,7 @@ def read_forces_ase(filename: str, key: str = "energy") -> List[TensorBlock]: return blocks -def read_virial_ase(filename: str, key: str = "virial") -> List[TensorBlock]: +def _read_virial_ase(filename: str, key: str = "virial") -> List[TensorBlock]: """Store virial information in a List of :class:`metatensor.TensorBlock` which can be used as ``strain`` gradients. @@ -110,7 +117,7 @@ def read_virial_ase(filename: str, key: str = "virial") -> List[TensorBlock]: return _read_virial_stress_ase(filename=filename, key=key, is_virial=True) -def read_stress_ase(filename: str, key: str = "stress") -> List[TensorBlock]: +def _read_stress_ase(filename: str, key: str = "stress") -> List[TensorBlock]: """Store stress information in a List of :class:`metatensor.TensorBlock` which can be used as ``strain`` gradients. 
@@ -182,3 +189,162 @@ def _read_virial_stress_ase( blocks.append(block) return blocks + + +def read_energy(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]: + target_key = target["key"] + + blocks = _read_energy_ase( + filename=target["read_from"], + key=target["key"], + ) + + add_position_gradients = False + if target["forces"]: + try: + position_gradients = _read_forces_ase( + filename=target["forces"]["read_from"], + key=target["forces"]["key"], + ) + except Exception: + logger.warning(f"No forces found in section {target_key!r}.") + add_position_gradients = False + else: + logger.info( + f"Forces found in section {target_key!r}, " + "we will use this gradient to train the model" + ) + for block, position_gradient in zip(blocks, position_gradients): + block.add_gradient(parameter="positions", gradient=position_gradient) + add_position_gradients = True + + if target["stress"] and target["virial"]: + raise ValueError("Cannot use stress and virial at the same time") + + add_strain_gradients = False + + if target["stress"]: + try: + strain_gradients = _read_stress_ase( + filename=target["stress"]["read_from"], + key=target["stress"]["key"], + ) + except Exception: + logger.warning(f"No stress found in section {target_key!r}.") + add_strain_gradients = False + else: + logger.info( + f"Stress found in section {target_key!r}, " + "we will use this gradient to train the model" + ) + for block, strain_gradient in zip(blocks, strain_gradients): + block.add_gradient(parameter="strain", gradient=strain_gradient) + add_strain_gradients = True + + if target["virial"]: + try: + strain_gradients = _read_virial_ase( + filename=target["virial"]["read_from"], + key=target["virial"]["key"], + ) + except Exception: + logger.warning(f"No virial found in section {target_key!r}.") + add_strain_gradients = False + else: + logger.info( + f"Virial found in section {target_key!r}, " + "we will use this gradient to train the model" + ) + for block, strain_gradient in zip(blocks, 
strain_gradients): + block.add_gradient(parameter="strain", gradient=strain_gradient) + add_strain_gradients = True + tensor_map_list = [ + TensorMap( + keys=Labels(["_"], torch.tensor([[0]])), + blocks=[block], + ) + for block in blocks + ] + target_info = get_energy_target_info( + target, add_position_gradients, add_strain_gradients + ) + return tensor_map_list, target_info + + +def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]: + filename = target["read_from"] + frames = _wrapped_ase_io_read(filename) + + # we don't allow ASE to read spherical tensors with more than one irrep, + # otherwise it's a mess + if ( + isinstance(target["type"], DictConfig) + and next(iter(target["type"].keys())) == "spherical" + ): + irreps = target["type"]["spherical"]["irreps"] + if len(irreps) > 1: + raise ValueError( + "The metatrain ASE reader does not support reading " + "spherical tensors with more than one irreducible " + "representation. Please use the metatensor reader." + ) + + target_info = get_generic_target_info(target) + components = target_info.layout.block().components + properties = target_info.layout.block().properties + shape_after_samples = target_info.layout.block().shape[1:] + per_atom = target_info.per_atom + keys = target_info.layout.keys + + target_key = target["key"] + + tensor_maps = [] + for i_system, atoms in enumerate(frames): + + if not per_atom and target_key not in atoms.info: + raise ValueError( + f"Target key {target_key!r} was not found in system {filename!r} at " + f"index {i_system}" + ) + if per_atom and target_key not in atoms.arrays: + raise ValueError( + f"Target key {target_key!r} was not found in system {filename!r} at " + f"index {i_system}" + ) + + # here we reshape to allow for more flexibility; this is actually + # necessary for the `arrays`, which are stored in a 2D array + if per_atom: + values = torch.tensor( + atoms.arrays[target_key], dtype=torch.float64 + ).reshape([-1] + shape_after_samples) + else: + 
values = torch.tensor(atoms.info[target_key], dtype=torch.float64).reshape( + [-1] + shape_after_samples + ) + + samples = ( + Labels( + ["system", "atom"], + torch.tensor([[i_system, a] for a in range(len(values))]), + ) + if per_atom + else Labels( + ["system"], + torch.tensor([[i_system]]), + ) + ) + + block = TensorBlock( + values=values, + samples=samples, + components=components, + properties=properties, + ) + tensor_map = TensorMap( + keys=keys, + blocks=[block], + ) + tensor_maps.append(tensor_map) + + return tensor_maps, target_info diff --git a/src/metatrain/utils/data/readers/metatensor.py b/src/metatrain/utils/data/readers/metatensor.py new file mode 100644 index 000000000..2c35f4708 --- /dev/null +++ b/src/metatrain/utils/data/readers/metatensor.py @@ -0,0 +1,168 @@ +import logging +from typing import List, Tuple + +import metatensor.torch +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import System +from omegaconf import DictConfig + +from ..target_info import TargetInfo, get_energy_target_info, get_generic_target_info + + +logger = logging.getLogger(__name__) + + +def read_systems(filename: str) -> List[System]: + """Read system information using metatensor. + + :param filename: name of the file to read + + :raises NotImplementedError: Serialization of systems is not yet + available in metatensor. 
+ """ + raise NotImplementedError("Reading metatensor systems is not yet implemented.") + + +def _wrapped_metatensor_read(filename) -> TensorMap: + try: + return metatensor.torch.load(filename) + except Exception as e: + raise ValueError(f"Failed to read '{filename}' with torch: {e}") from e + + +def read_energy(target: DictConfig) -> Tuple[TensorMap, TargetInfo]: + tensor_map = _wrapped_metatensor_read(target["read_from"]) + + if len(tensor_map) != 1: + raise ValueError("Energy TensorMaps should have exactly one block.") + + add_position_gradients = target["forces"] + add_strain_gradients = target["stress"] or target["virial"] + target_info = get_energy_target_info( + target, add_position_gradients, add_strain_gradients + ) + + # now check all the expected metadata (from target_info.layout) matches + # the actual metadata in the tensor maps + _check_tensor_map_metadata(tensor_map, target_info.layout) + + selections = [ + Labels( + names=["system"], + values=torch.tensor([[int(i)]]), + ) + for i in torch.unique( + torch.concatenate( + [block.samples.column("system") for block in tensor_map.blocks()] + ) + ) + ] + tensor_maps = metatensor.torch.split(tensor_map, "samples", selections) + return tensor_maps, target_info + + +def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]: + tensor_map = _wrapped_metatensor_read(target["read_from"]) + + for block in tensor_map.blocks(): + if len(block.gradients_list()) > 0: + raise ValueError("Only energy targets can have gradient blocks.") + + target_info = get_generic_target_info(target) + _check_tensor_map_metadata(tensor_map, target_info.layout) + + # make sure that the properties of the target_info.layout also match the + # actual properties of the tensor maps + target_info.layout = _empty_tensor_map_like(tensor_map) + + selections = [ + Labels( + names=["system"], + values=torch.tensor([[int(i)]]), + ) + for i in torch.unique(tensor_map.block(0).samples.column("system")) + ] + tensor_maps = 
metatensor.torch.split(tensor_map, "samples", selections) + return tensor_maps, target_info + + +def _check_tensor_map_metadata(tensor_map: TensorMap, layout: TensorMap): + if tensor_map.keys != layout.keys: + raise ValueError( + f"Unexpected keys in metatensor targets: " + f"expected: {layout.keys} " + f"actual: {tensor_map.keys}" + ) + for key in layout.keys: + block = tensor_map.block(key) + block_from_layout = layout.block(key) + if block.samples.names != block_from_layout.samples.names: + raise ValueError( + f"Unexpected samples in metatensor targets: " + f"expected: {block_from_layout.samples.names} " + f"actual: {block.samples.names}" + ) + if block.components != block_from_layout.components: + raise ValueError( + f"Unexpected components in metatensor targets: " + f"expected: {block_from_layout.components} " + f"actual: {block.components}" + ) + # the properties can be different from those of the default `TensorMap` + # given by `get_generic_target_info`, so we don't check them + if set(block.gradients_list()) != set(block_from_layout.gradients_list()): + raise ValueError( + f"Unexpected gradients in metatensor targets: " + f"expected: {block_from_layout.gradients_list()} " + f"actual: {block.gradients_list()}" + ) + for name in block_from_layout.gradients_list(): + gradient_block = block.gradient(name) + gradient_block_from_layout = block_from_layout.gradient(name) + if gradient_block.labels.names != gradient_block_from_layout.labels.names: + raise ValueError( + f"Unexpected samples in metatensor targets " + f"for `{name}` gradient block: " + f"expected: {gradient_block_from_layout.labels.names} " + f"actual: {gradient_block.labels.names}" + ) + if gradient_block.components != gradient_block_from_layout.components: + raise ValueError( + f"Unexpected components in metatensor targets " + f"for `{name}` gradient block: " + f"expected: {gradient_block_from_layout.components} " + f"actual: {gradient_block.components}" + ) + + +def 
_empty_tensor_map_like(tensor_map: TensorMap) -> TensorMap: + new_keys = tensor_map.keys + new_blocks: List[TensorBlock] = [] + for block in tensor_map.blocks(): + new_block = _empty_tensor_block_like(block) + new_blocks.append(new_block) + return TensorMap(keys=new_keys, blocks=new_blocks) + + +def _empty_tensor_block_like(tensor_block: TensorBlock) -> TensorBlock: + new_block = TensorBlock( + values=torch.empty( + (0,) + tensor_block.values.shape[1:], + dtype=torch.float64, # metatensor can't serialize otherwise + device=tensor_block.values.device, + ), + samples=Labels( + names=tensor_block.samples.names, + values=torch.empty( + (0, tensor_block.samples.values.shape[1]), + dtype=tensor_block.samples.values.dtype, + device=tensor_block.samples.values.device, + ), + ), + components=tensor_block.components, + properties=tensor_block.properties, + ) + for gradient_name, gradient in tensor_block.gradients(): + new_block.add_gradient(gradient_name, _empty_tensor_block_like(gradient)) + return new_block diff --git a/src/metatrain/utils/data/readers/readers.py b/src/metatrain/utils/data/readers/readers.py index 7a6076143..1e7c1b8ae 100644 --- a/src/metatrain/utils/data/readers/readers.py +++ b/src/metatrain/utils/data/readers/readers.py @@ -1,34 +1,39 @@ import importlib import logging +import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple -import torch -from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch import TensorMap from metatensor.torch.atomistic import System from omegaconf import DictConfig -from ..dataset import TargetInfo +from ..target_info import TargetInfo logger = logging.getLogger(__name__) -AVAILABLE_READERS = ["ase"] +AVAILABLE_READERS = ["ase", "metatensor"] """:py:class:`list`: list containing all implemented reader libraries""" -DEFAULT_READER = { - ".xyz": "ase", - ".extxyz": "ase", -} +DEFAULT_READER = {".xyz": "ase", ".extxyz": 
"ase", ".npz": "metatensor"} """:py:class:`dict`: dictionary mapping file extensions to a default reader""" -def _base_reader( - target: str, +def read_systems( filename: str, reader: Optional[str] = None, - **reader_kwargs, -) -> List[Any]: +) -> List[System]: + """Read system informations from a file. + + :param filename: name of the file to read + :param reader: reader library for parsing the file. If :py:obj:`None` the library is + is tried to determined from the file extension. + :param dtype: desired data type of returned tensor + :returns: list of systems + determined from the file extension. + :returns: list of systems stored in double precision + """ if reader is None: try: file_suffix = Path(filename).suffix @@ -51,112 +56,22 @@ def _base_reader( ) try: - reader_met = getattr(reader_mod, f"read_{target}_{reader}") + reader_met = reader_mod.read_systems except AttributeError: - raise ValueError(f"Reader library {reader!r} can't read {target!r}.") + raise ValueError( + f"Reader library {reader!r} cannot read systems." + f"You can try with other readers: {AVAILABLE_READERS}" + ) - data = reader_met(filename, **reader_kwargs) + systems = reader_met(filename) # elements in data are `torch.ScriptObject`s and their `dtype` is an integer. # A C++ double/torch.float64 is `7` according to # https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/c10/core/ScalarType.h#L54-L93 - assert all(d.dtype == 7 for d in data) - - return data - - -def read_energy( - filename: str, - target_value: str = "energy", - reader: Optional[str] = None, -) -> List[TensorBlock]: - """Read energy informations from a file. - - :param filename: name of the file to read - :param target_value: target value key name to be parsed from the file. - :param reader: reader library for parsing the file. If :py:obj:`None` the library is - is tried to determined from the file extension. 
- :returns: energy stored stored in double precision as a - :class:`metatensor.TensorBlock` - """ - return _base_reader( - target="energy", filename=filename, reader=reader, key=target_value - ) + if not all(s.dtype == 7 for s in systems): + raise ValueError("The loaded systems are not in double precision.") - -def read_forces( - filename: str, - target_value: str = "forces", - reader: Optional[str] = None, -) -> List[TensorBlock]: - """Read force informations from a file. - - :param filename: name of the file to read - :param target_value: target value key name to be parsed from the file - :param reader: reader library for parsing the file. If :py:obj:`None` the library is - is tried to determined from the file extension. - :returns: forces stored in double precision stored as a - :class:`metatensor.TensorBlock` - """ - return _base_reader( - target="forces", filename=filename, reader=reader, key=target_value - ) - - -def read_stress( - filename: str, - target_value: str = "stress", - reader: Optional[str] = None, -) -> List[TensorBlock]: - """Read stress informations from a file. - - :param filename: name of the file to read - :param target_value: target value key name to be parsed from the file. - :param reader: reader library for parsing the file. If :py:obj:`None` the library is - is tried to determined from the file extension. - :returns: stress stored in double precision as a :class:`metatensor.TensorBlock` - """ - return _base_reader( - target="stress", filename=filename, reader=reader, key=target_value - ) - - -def read_systems( - filename: str, - reader: Optional[str] = None, -) -> List[System]: - """Read system informations from a file. - - :param filename: name of the file to read - :param reader: reader library for parsing the file. If :py:obj:`None` the library is - is tried to determined from the file extension. - :param dtype: desired data type of returned tensor - :returns: list of systems - determined from the file extension. 
- :returns: list of systems stored in double precision - """ - return _base_reader(target="systems", filename=filename, reader=reader) - - -def read_virial( - filename: str, - target_value: str = "virial", - reader: Optional[str] = None, -) -> List[TensorBlock]: - """Read virial informations from a file. - - :param filename: name of the file to read - :param target_value: target value key name to be parsed from the file. - :param reader: reader library for parsing the file. If :py:obj:`None` the library is - is tried to determined from the file extension. - :returns: virial stored in double precision as a :class:`metatensor.TensorBlock` - """ - return _base_reader( - target="virial", - filename=filename, - reader=reader, - key=target_value, - ) + return systems def read_targets( @@ -188,138 +103,82 @@ def read_targets( standard_outputs_list = ["energy"] for target_key, target in conf.items(): - target_info_gradients: List[str] = [] is_standard_target = target_key in standard_outputs_list if not is_standard_target and not target_key.startswith("mtt::"): if target_key.lower() in ["force", "forces", "virial", "stress"]: - raise ValueError( - f"{target_key!r} should not be it's own top-level target, " - "but rather a sub-section of the 'energy' target" + warnings.warn( + f"{target_key!r} should not be its own top-level target, " + "but rather a sub-section of the 'energy' target", + stacklevel=2, ) else: raise ValueError( f"Target name ({target_key}) must either be one of " f"{standard_outputs_list} or start with `mtt::`." 
) - - if target["quantity"] == "energy": - blocks = read_energy( - filename=target["read_from"], - target_value=target["key"], - reader=target["reader"], + if ( + "force" in target_key.lower() + or "virial" in target_key.lower() + or "stress" in target_key.lower() + ): + warnings.warn( + f"the name of {target_key!r} resembles to a gradient of " + "energies; it should probably not be its own top-level target, " + "but rather a gradient sub-section of a target with the " + "`energy` quantity", + stacklevel=2, ) - if target["forces"]: - try: - position_gradients = read_forces( - filename=target["forces"]["read_from"], - target_value=target["forces"]["key"], - reader=target["forces"]["reader"], - ) - except Exception: - logger.warning(f"No forces found in section {target_key!r}.") - else: - logger.info( - f"Forces found in section {target_key!r}, " - "we will use this gradient to train the model" - ) - for block, position_gradient in zip(blocks, position_gradients): - block.add_gradient( - parameter="positions", gradient=position_gradient - ) - - target_info_gradients.append("positions") - - if target["stress"] and target["virial"]: - raise ValueError("Cannot use stress and virial at the same time") - - if target["stress"]: - try: - strain_gradients = read_stress( - filename=target["stress"]["read_from"], - target_value=target["stress"]["key"], - reader=target["stress"]["reader"], - ) - except Exception: - logger.warning(f"No stress found in section {target_key!r}.") - else: - logger.info( - f"Stress found in section {target_key!r}, " - "we will use this gradient to train the model" - ) - for block, strain_gradient in zip(blocks, strain_gradients): - block.add_gradient(parameter="strain", gradient=strain_gradient) + is_energy = ( + (target["quantity"] == "energy") + and (not target["per_atom"]) + and target["num_subtargets"] == 1 + and target["type"] == "scalar" + ) + energy_or_generic = "energy" if is_energy else "generic" - target_info_gradients.append("strain") + 
reader = target["reader"] + filename = target["read_from"] - if target["virial"]: - try: - strain_gradients = read_virial( - filename=target["virial"]["read_from"], - target_value=target["virial"]["key"], - reader=target["virial"]["reader"], - ) - except Exception: - logger.warning(f"No virial found in section {target_key!r}.") - else: - logger.info( - f"Virial found in section {target_key!r}, " - "we will use this gradient to train the model" - ) - for block, strain_gradient in zip(blocks, strain_gradients): - block.add_gradient(parameter="strain", gradient=strain_gradient) + if reader is None: + try: + file_suffix = Path(filename).suffix + reader = DEFAULT_READER[file_suffix] + except KeyError: + raise ValueError( + f"File extension {file_suffix!r} is not linked to a default reader " + "library. You can try reading it by setting a specific 'reader' " + f"from the known ones: {', '.join(AVAILABLE_READERS)} " + ) - target_info_gradients.append("strain") - else: + try: + reader_mod = importlib.import_module( + name=f".{reader}", package="metatrain.utils.data.readers" + ) + except ImportError: raise ValueError( - f"Quantity: {target['quantity']!r} is not supported. Choose 'energy'." + f"Reader library {reader!r} not supported. Choose from " + f"{', '.join(AVAILABLE_READERS)}" ) - target_dictionary[target_key] = [ - TensorMap( - keys=Labels(["_"], torch.tensor([[0]])), - blocks=[block], + try: + reader_met = getattr(reader_mod, f"read_{energy_or_generic}") + except AttributeError: + raise ValueError( + f"Reader library {reader!r} cannot read {target!r}." 
+ f"You can try with other readers: {AVAILABLE_READERS}" ) - for block in blocks - ] - target_info_dictionary[target_key] = TargetInfo( - quantity=target["quantity"], - unit=target["unit"], - layout=_empty_tensor_map_like(target_dictionary[target_key][0]), - ) - - return target_dictionary, target_info_dictionary + targets_as_list_of_tensor_maps, target_info = reader_met(target) + # elements in data are `torch.ScriptObject`s and their `dtype` is an integer. + # A C++ double/torch.float64 is `7` according to + # https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/c10/core/ScalarType.h#L54-L93 + if not all(t.dtype == 7 for t in targets_as_list_of_tensor_maps): + raise ValueError("The loaded targets are not in double precision.") -def _empty_tensor_map_like(tensor_map: TensorMap) -> TensorMap: - new_keys = tensor_map.keys - new_blocks: List[TensorBlock] = [] - for block in tensor_map.blocks(): - new_block = _empty_tensor_block_like(block) - new_blocks.append(new_block) - return TensorMap(keys=new_keys, blocks=new_blocks) + target_dictionary[target_key] = targets_as_list_of_tensor_maps + target_info_dictionary[target_key] = target_info - -def _empty_tensor_block_like(tensor_block: TensorBlock) -> TensorBlock: - new_block = TensorBlock( - values=torch.empty( - (0,) + tensor_block.values.shape[1:], - dtype=torch.float64, # metatensor can't serialize otherwise - device=tensor_block.values.device, - ), - samples=Labels( - names=tensor_block.samples.names, - values=torch.empty( - (0, tensor_block.samples.values.shape[1]), - dtype=tensor_block.samples.values.dtype, - device=tensor_block.samples.values.device, - ), - ), - components=tensor_block.components, - properties=tensor_block.properties, - ) - for gradient_name, gradient in tensor_block.gradients(): - new_block.add_gradient(gradient_name, _empty_tensor_block_like(gradient)) - return new_block + return target_dictionary, target_info_dictionary diff --git 
a/src/metatrain/utils/data/target_info.py b/src/metatrain/utils/data/target_info.py new file mode 100644 index 000000000..6511fd277 --- /dev/null +++ b/src/metatrain/utils/data/target_info.py @@ -0,0 +1,392 @@ +from typing import List, Union + +import metatensor.torch +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from omegaconf import DictConfig + + +class TargetInfo: + """A class that contains information about a target. + + :param quantity: The physical quantity of the target (e.g., "energy"). + :param layout: The layout of the target, as a ``TensorMap`` with 0 samples. + This ``TensorMap`` will be used to retrieve the names of + the ``samples``, as well as the ``components`` and ``properties`` of the + target and their gradients. For example, this allows to infer the type of + the target (scalar, Cartesian tensor, spherical tensor), whether it is per + atom, the names of its gradients, etc. + :param unit: The unit of the target. If :py:obj:`None` the ``unit`` will be set to + an empty string ``""``. 
+ """ + + def __init__( + self, + quantity: str, + layout: TensorMap, + unit: Union[None, str] = "", + ): + # one of these will be set to True inside the _check_layout method + self.is_scalar = False + self.is_cartesian = False + self.is_spherical = False + + self._check_layout(layout) + + self.quantity = quantity # float64: otherwise metatensor can't serialize + self.layout = layout + self.unit = unit if unit is not None else "" + + @property + def gradients(self) -> List[str]: + """Sorted and unique list of gradient names.""" + if self.is_scalar: + return sorted(self.layout.block().gradients_list()) + else: + return [] + + @property + def per_atom(self) -> bool: + """Whether the target is per atom.""" + return "atom" in self.layout.block(0).samples.names + + def __repr__(self): + return ( + f"TargetInfo(quantity={self.quantity!r}, unit={self.unit!r}, " + f"layout={self.layout!r})" + ) + + def __eq__(self, other): + if not isinstance(other, TargetInfo): + raise NotImplementedError( + "Comparison between a TargetInfo instance and a " + f"{type(other).__name__} instance is not implemented." + ) + return ( + self.quantity == other.quantity + and self.unit == other.unit + and metatensor.torch.equal(self.layout, other.layout) + ) + + def _check_layout(self, layout: TensorMap) -> None: + """Check that the layout is a valid layout.""" + + # examine basic properties of all blocks + for block in layout.blocks(): + for sample_name in block.samples.names: + if sample_name not in ["system", "atom"]: + raise ValueError( + "The layout ``TensorMap`` of a target should only have samples " + "named 'system' or 'atom', but found " + f"'{sample_name}' instead." + ) + if len(block.values) != 0: + raise ValueError( + "The layout ``TensorMap`` of a target should have 0 " + f"samples, but found {len(block.values)} samples." 
+ ) + + # examine the components of the first block to decide whether this is + # a scalar, a Cartesian tensor or a spherical tensor + + if len(layout) == 0: + raise ValueError( + "The layout ``TensorMap`` of a target should have at least one " + "block, but found 0 blocks." + ) + components_first_block = layout.block(0).components + if len(components_first_block) == 0: + self.is_scalar = True + elif components_first_block[0].names[0].startswith("xyz"): + self.is_cartesian = True + elif ( + len(components_first_block) == 1 + and components_first_block[0].names[0] == "o3_mu" + ): + self.is_spherical = True + else: + raise ValueError( + "The layout ``TensorMap`` of a target should be " + "either scalars, Cartesian tensors or spherical tensors. The type of " + "the target could not be determined." + ) + + if self.is_scalar: + if layout.keys.names != ["_"]: + raise ValueError( + "The layout ``TensorMap`` of a scalar target should have " + "a single key sample named '_'." + ) + if len(layout.blocks()) != 1: + raise ValueError( + "The layout ``TensorMap`` of a scalar target should have " + "a single block." + ) + gradients_names = layout.block(0).gradients_list() + for gradient_name in gradients_names: + if gradient_name not in ["positions", "strain"]: + raise ValueError( + "Only `positions` and `strain` gradients are supported for " + "scalar targets. " + f"Found '{gradient_name}' instead." + ) + if self.is_cartesian: + if layout.keys.names != ["_"]: + raise ValueError( + "The layout ``TensorMap`` of a Cartesian tensor target should have " + "a single key sample named '_'." + ) + if len(layout.blocks()) != 1: + raise ValueError( + "The layout ``TensorMap`` of a Cartesian tensor target should have " + "a single block." + ) + if len(layout.block(0).gradients_list()) > 0: + raise ValueError( + "Gradients of Cartesian tensor targets are not supported." 
+ ) + + if self.is_spherical: + if layout.keys.names != ["o3_lambda", "o3_sigma"]: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target " + "should have two keys named 'o3_lambda' and 'o3_sigma'." + f"Found '{layout.keys.names}' instead." + ) + for key, block in layout.items(): + o3_lambda, o3_sigma = int(key.values[0].item()), int( + key.values[1].item() + ) + if o3_sigma not in [-1, 1]: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target should " + "have a key sample 'o3_sigma' that is either -1 or 1." + f"Found '{o3_sigma}' instead." + ) + if o3_lambda < 0: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target should " + "have a key sample 'o3_lambda' that is non-negative." + f"Found '{o3_lambda}' instead." + ) + components = block.components + if len(components) != 1: + raise ValueError( + "The layout ``TensorMap`` of a spherical tensor target should " + "have a single component." + ) + if len(components[0]) != 2 * o3_lambda + 1: + raise ValueError( + "Each ``TensorBlock`` of a spherical tensor target should have " + "a component with 2*o3_lambda + 1 elements." + f"Found '{len(components[0])}' elements instead." + ) + if len(block.gradients_list()) > 0: + raise ValueError( + "Gradients of spherical tensor targets are not supported." 
+ ) + + +def get_energy_target_info( + target: DictConfig, + add_position_gradients: bool = False, + add_strain_gradients: bool = False, +) -> TargetInfo: + + block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 1, dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.empty((0, 1), dtype=torch.int32), + ), + components=[], + properties=Labels.range("energy", 1), + ) + + if add_position_gradients: + position_gradient_block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 1, dtype=torch.float64), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), + ) + block.add_gradient("positions", position_gradient_block) + + if add_strain_gradients: + strain_gradient_block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, 3, 3, 1, dtype=torch.float64), + samples=Labels( + names=["sample", "atom"], + values=torch.empty((0, 2), dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz_1"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + names=["xyz_2"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("energy", 1), + ) + block.add_gradient("strain", strain_gradient_block) + + layout = TensorMap( + keys=Labels.single(), + blocks=[block], + ) + + target_info = TargetInfo( + quantity="energy", + unit=target["unit"], + layout=layout, + ) + return target_info + + +def get_generic_target_info(target: DictConfig) -> TargetInfo: + if target["type"] == "scalar": + return _get_scalar_target_info(target) + elif len(target["type"]) == 1 and next(iter(target["type"])).lower() == "cartesian": + return _get_cartesian_target_info(target) + elif len(target["type"]) == 1 and 
next(iter(target["type"])) == "spherical": + return _get_spherical_target_info(target) + else: + raise ValueError( + f"Target type {target['type']} is not supported. " + "Supported types are 'scalar', 'cartesian' and 'spherical'." + ) + + +def _get_scalar_target_info(target: DictConfig) -> TargetInfo: + sample_names = ["system"] + if target["per_atom"]: + sample_names.append("atom") + + block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty(0, target["num_subtargets"], dtype=torch.float64), + samples=Labels( + names=sample_names, + values=torch.empty((0, len(sample_names)), dtype=torch.int32), + ), + components=[], + properties=Labels.range("properties", target["num_subtargets"]), + ) + layout = TensorMap( + keys=Labels.single(), + blocks=[block], + ) + + target_info = TargetInfo( + quantity=target["quantity"], + unit=target["unit"], + layout=layout, + ) + return target_info + + +def _get_cartesian_target_info(target: DictConfig) -> TargetInfo: + sample_names = ["system"] + if target["per_atom"]: + sample_names.append("atom") + + cartesian_key = next(iter(target["type"])) + + if target["type"][cartesian_key]["rank"] == 1: + components = [Labels(["xyz"], torch.arange(3).reshape(-1, 1))] + else: + components = [] + for component in range(1, target["type"][cartesian_key]["rank"] + 1): + components.append( + Labels( + names=[f"xyz_{component}"], + values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), + ) + ) + + block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty( + [0] + [3] * len(components) + [target["num_subtargets"]], + dtype=torch.float64, + ), + samples=Labels( + names=sample_names, + values=torch.empty((0, len(sample_names)), dtype=torch.int32), + ), + components=components, + properties=Labels.range("properties", target["num_subtargets"]), + ) + layout = TensorMap( + keys=Labels.single(), + blocks=[block], + ) + + target_info = TargetInfo( + quantity=target["quantity"], + 
unit=target["unit"], + layout=layout, + ) + return target_info + + +def _get_spherical_target_info(target: DictConfig) -> TargetInfo: + sample_names = ["system"] + if target["per_atom"]: + sample_names.append("atom") + + irreps = target["type"]["spherical"]["irreps"] + keys = [] + blocks = [] + for irrep in irreps: + components = [ + Labels( + names=["o3_mu"], + values=torch.arange( + -irrep["o3_lambda"], irrep["o3_lambda"] + 1, dtype=torch.int32 + ).reshape(-1, 1), + ) + ] + block = TensorBlock( + # float64: otherwise metatensor can't serialize + values=torch.empty( + 0, + 2 * irrep["o3_lambda"] + 1, + target["num_subtargets"], + dtype=torch.float64, + ), + samples=Labels( + names=sample_names, + values=torch.empty((0, len(sample_names)), dtype=torch.int32), + ), + components=components, + properties=Labels.range("properties", target["num_subtargets"]), + ) + keys.append([irrep["o3_lambda"], irrep["o3_sigma"]]) + blocks.append(block) + + layout = TensorMap( + keys=Labels(["o3_lambda", "o3_sigma"], torch.tensor(keys, dtype=torch.int32)), + blocks=blocks, + ) + + target_info = TargetInfo( + quantity=target["quantity"], + unit=target["unit"], + layout=layout, + ) + return target_info diff --git a/src/metatrain/utils/data/writers/xyz.py b/src/metatrain/utils/data/writers/xyz.py index da1478cbf..421a72d3a 100644 --- a/src/metatrain/utils/data/writers/xyz.py +++ b/src/metatrain/utils/data/writers/xyz.py @@ -57,13 +57,16 @@ def write_xyz( block = target_map.block() if "atom" in block.samples.names: # save inside arrays - arrays[target_name] = block.values.detach().cpu().numpy() + values = block.values.detach().cpu().numpy() + arrays[target_name] = values.reshape(values.shape[0], -1) + # reshaping here is necessary because `arrays` only accepts 2D arrays else: # save inside info if block.values.numel() == 1: info[target_name] = block.values.item() else: - info[target_name] = block.values.detach().cpu().numpy() + info[target_name] = 
block.values.detach().cpu().numpy().squeeze(0) + # squeeze the sample dimension, which corresponds to the system for gradient_name, gradient_block in block.gradients(): # here, we assume that gradients are always an array, and never a scalar diff --git a/src/metatrain/utils/omegaconf.py b/src/metatrain/utils/omegaconf.py index 3a223df5d..d704e3acd 100644 --- a/src/metatrain/utils/omegaconf.py +++ b/src/metatrain/utils/omegaconf.py @@ -96,6 +96,9 @@ def _resolve_single_str(config: str) -> DictConfig: "reader": None, "key": None, "unit": None, + "per_atom": False, + "type": "scalar", + "num_subtargets": 1, } ) @@ -108,7 +111,7 @@ def _resolve_single_str(config: str) -> DictConfig: } ) -KNWON_GRADIENTS = list(CONF_GRADIENTS.keys()) +KNOWN_GRADIENTS = list(CONF_GRADIENTS.keys()) # Merge configs to get default configs for energies and other targets CONF_TARGET = OmegaConf.merge(CONF_TARGET_FIELDS, CONF_GRADIENTS) @@ -253,7 +256,7 @@ def expand_dataset_config(conf: Union[str, DictConfig, ListConfig]) -> ListConfi for gradient_key, gradient_conf in conf_element["targets"][ target_key ].items(): - if gradient_key in KNWON_GRADIENTS: + if gradient_key in KNOWN_GRADIENTS: if gradient_conf is True: gradient_conf = CONF_GRADIENT.copy() elif type(gradient_conf) is str: diff --git a/src/metatrain/utils/testing.py b/src/metatrain/utils/testing.py deleted file mode 100644 index faedbdb00..000000000 --- a/src/metatrain/utils/testing.py +++ /dev/null @@ -1,76 +0,0 @@ -# This file contains some example TensorMap layouts that can be -# used for testing purposes. 
- -import torch -from metatensor.torch import Labels, TensorBlock, TensorMap - - -block = TensorBlock( - # float64: otherwise metatensor can't serialize - values=torch.empty(0, 1, dtype=torch.float64), - samples=Labels( - names=["system"], - values=torch.empty((0, 1), dtype=torch.int32), - ), - components=[], - properties=Labels.range("energy", 1), -) -energy_layout = TensorMap( - keys=Labels.single(), - blocks=[block], -) - -block_with_position_gradients = block.copy() -position_gradient_block = TensorBlock( - # float64: otherwise metatensor can't serialize - values=torch.empty(0, 3, 1, dtype=torch.float64), - samples=Labels( - names=["sample", "atom"], - values=torch.empty((0, 2), dtype=torch.int32), - ), - components=[ - Labels( - names=["xyz"], - values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), - ), - ], - properties=Labels.range("energy", 1), -) -block_with_position_gradients.add_gradient("positions", position_gradient_block) -energy_force_layout = TensorMap( - keys=Labels.single(), - blocks=[block_with_position_gradients], -) - -block_with_position_and_strain_gradients = block_with_position_gradients.copy() -strain_gradient_block = TensorBlock( - # float64: otherwise metatensor can't serialize - values=torch.empty(0, 3, 3, 1, dtype=torch.float64), - samples=Labels( - names=["sample", "atom"], - values=torch.empty((0, 2), dtype=torch.int32), - ), - components=[ - Labels( - names=["xyz_1"], - values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), - ), - Labels( - names=["xyz_2"], - values=torch.arange(3, dtype=torch.int32).reshape(-1, 1), - ), - ], - properties=Labels.range("energy", 1), -) -block_with_position_and_strain_gradients.add_gradient("strain", strain_gradient_block) -energy_force_stress_layout = TensorMap( - keys=Labels.single(), - blocks=[block_with_position_and_strain_gradients], -) - -block_with_strain_gradients = block.copy() -block_with_strain_gradients.add_gradient("strain", strain_gradient_block) -energy_stress_layout = TensorMap( 
- keys=Labels.single(), - blocks=[block_with_strain_gradients], -) diff --git a/tests/cli/test_eval_model.py b/tests/cli/test_eval_model.py index 6da30fb5c..d9eaf7827 100644 --- a/tests/cli/test_eval_model.py +++ b/tests/cli/test_eval_model.py @@ -10,8 +10,8 @@ from metatrain.cli.eval import eval_model from metatrain.experimental.soap_bpnn import __model__ -from metatrain.utils.data import DatasetInfo, TargetInfo -from metatrain.utils.testing import energy_layout +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from . import EVAL_OPTIONS_PATH, MODEL_HYPERS, MODEL_PATH, RESOURCES_PATH @@ -84,9 +84,7 @@ def test_eval_export(monkeypatch, tmp_path, options): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1, 6, 7, 8}, - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) diff --git a/tests/cli/test_export_model.py b/tests/cli/test_export_model.py index d9515d1b1..10feab713 100644 --- a/tests/cli/test_export_model.py +++ b/tests/cli/test_export_model.py @@ -14,9 +14,9 @@ from metatrain.cli.export import export_model from metatrain.experimental.soap_bpnn import __model__ from metatrain.utils.architectures import find_all_architectures -from metatrain.utils.data import DatasetInfo, TargetInfo +from metatrain.utils.data import DatasetInfo +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.io import load_model -from metatrain.utils.testing import energy_layout from . 
import MODEL_HYPERS, RESOURCES_PATH @@ -30,9 +30,7 @@ def test_export(monkeypatch, tmp_path, path, caplog): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1}, - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) export_model(model, path) @@ -96,9 +94,7 @@ def test_reexport(monkeypatch, tmp_path): dataset_info = DatasetInfo( length_unit="angstrom", atomic_types={1, 6, 7, 8}, - targets={ - "energy": TargetInfo(quantity="energy", unit="eV", layout=energy_layout) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) diff --git a/tests/resources/test.yaml b/tests/resources/test.yaml new file mode 100644 index 000000000..a13ce734e --- /dev/null +++ b/tests/resources/test.yaml @@ -0,0 +1,21 @@ +seed: 42 + +architecture: + name: experimental.soap_bpnn + training: + batch_size: 2 + num_epochs: 1 + +training_set: + systems: + read_from: ethanol_reduced_100.xyz + length_unit: angstrom + targets: + forces: + quantity: force + key: forces + per_atom: true + num_subtargets: 3 + +test_set: 0.5 +validation_set: 0.1 diff --git a/tests/utils/data/test_combine_dataloaders.py b/tests/utils/data/test_combine_dataloaders.py index beb78fbdc..d0ef63929 100644 --- a/tests/utils/data/test_combine_dataloaders.py +++ b/tests/utils/data/test_combine_dataloaders.py @@ -30,6 +30,9 @@ def test_without_shuffling(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -49,6 +52,9 @@ def test_without_shuffling(): "reader": "ase", "key": "energy", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -88,6 +94,9 @@ def test_with_shuffling(): 
"reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -109,6 +118,9 @@ def test_with_shuffling(): "reader": "ase", "key": "energy", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, diff --git a/tests/utils/data/test_dataset.py b/tests/utils/data/test_dataset.py index 9c76c5a7b..f14ef6b51 100644 --- a/tests/utils/data/test_dataset.py +++ b/tests/utils/data/test_dataset.py @@ -374,6 +374,9 @@ def test_dataset(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -400,6 +403,9 @@ def test_get_atomic_types(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -413,6 +419,9 @@ def test_get_atomic_types(): "reader": "ase", "key": "energy", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -439,6 +448,9 @@ def test_get_all_targets(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -452,6 +464,9 @@ def test_get_all_targets(): "reader": "ase", "key": "energy", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -477,6 +492,9 @@ def test_check_datasets(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -490,6 +508,9 @@ def test_check_datasets(): "reader": "ase", "key": "energy", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, 
"virial": False, @@ -551,6 +572,9 @@ def test_collate_fn(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -578,6 +602,9 @@ def test_get_stats(layout_scalar): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -591,6 +618,9 @@ def test_get_stats(layout_scalar): "reader": "ase", "key": "energy", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, diff --git a/tests/utils/data/test_get_dataset.py b/tests/utils/data/test_get_dataset.py index 765f6a62b..0e8f687be 100644 --- a/tests/utils/data/test_get_dataset.py +++ b/tests/utils/data/test_get_dataset.py @@ -22,6 +22,9 @@ def test_get_dataset(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, diff --git a/tests/utils/data/test_readers.py b/tests/utils/data/test_readers.py index 5876d8e70..3999110db 100644 --- a/tests/utils/data/test_readers.py +++ b/tests/utils/data/test_readers.py @@ -11,16 +11,7 @@ from omegaconf import OmegaConf from test_targets_ase import ase_system, ase_systems -from metatrain.utils.data.dataset import TargetInfo -from metatrain.utils.data.readers import ( - read_energy, - read_forces, - read_stress, - read_systems, - read_targets, - read_virial, -) -from metatrain.utils.data.readers.readers import _base_reader +from metatrain.utils.data import TargetInfo, read_systems, read_targets @pytest.mark.parametrize("reader", (None, "ase")) @@ -58,84 +49,18 @@ def test_read_unknonw_library(): read_systems("foo.foo", reader="foo") -@pytest.mark.parametrize("reader", (None, "ase")) -def test_read_energies(reader, monkeypatch, tmp_path): - monkeypatch.chdir(tmp_path) - - filename = "systems.xyz" - 
systems = ase_systems() - ase.io.write(filename, systems) - - results = read_energy(filename, reader=reader, target_value="true_energy") - - assert type(results) is list - assert len(results) == len(systems) - for i_system, result in enumerate(results): - assert result.values.dtype is torch.float64 - assert result.samples.names == ["system"] - assert result.samples.values == torch.tensor([[i_system]]) - assert result.properties == Labels("energy", torch.tensor([[0]])) - - -@pytest.mark.parametrize("reader", (None, "ase")) -def test_read_forces(reader, monkeypatch, tmp_path): - monkeypatch.chdir(tmp_path) - - filename = "systems.xyz" - systems = ase_systems() - ase.io.write(filename, systems) - - results = read_forces(filename, reader=reader, target_value="forces") - - assert type(results) is list - assert len(results) == len(systems) - for i_system, result in enumerate(results): - assert result.values.dtype is torch.float64 - assert result.samples.names == ["sample", "system", "atom"] - assert torch.all(result.samples["sample"] == torch.tensor(0)) - assert torch.all(result.samples["system"] == torch.tensor(i_system)) - assert result.components == [Labels(["xyz"], torch.arange(3).reshape(-1, 1))] - assert result.properties == Labels("energy", torch.tensor([[0]])) - - -@pytest.mark.parametrize("reader_func", [read_stress, read_virial]) -@pytest.mark.parametrize("reader", (None, "ase")) -def test_read_stress_virial(reader_func, reader, monkeypatch, tmp_path): - monkeypatch.chdir(tmp_path) - - filename = "systems.xyz" - systems = ase_systems() - ase.io.write(filename, systems) - - results = reader_func(filename, reader=reader, target_value="stress-3x3") - - assert type(results) is list - assert len(results) == len(systems) - components = [ - Labels(["xyz_1"], torch.arange(3).reshape(-1, 1)), - Labels(["xyz_2"], torch.arange(3).reshape(-1, 1)), - ] - for result in results: - assert result.values.dtype is torch.float64 - assert result.samples.names == ["sample"] - assert 
result.samples.values == torch.tensor([[0]]) - assert result.components == components - assert result.properties == Labels("energy", torch.tensor([[0]])) - - -@pytest.mark.parametrize( - "reader_func", [read_energy, read_forces, read_stress, read_virial] -) -def test_reader_unknown_reader(reader_func): - match = "File extension '.bar' is not linked to a default reader" - with pytest.raises(ValueError, match=match): - reader_func("foo.bar", target_value="baz") - +def test_unsupported_target_name(): + conf = { + "free_energy": { + "quantity": "energy", + } + } -def test_reader_unknown_target(): - match = "Reader library 'ase' can't read 'mytarget'." - with pytest.raises(ValueError, match=match): - _base_reader(target="mytarget", filename="structures.xyz", reader="ase") + with pytest.raises( + ValueError, + match="start with `mtt::`", + ): + read_targets(OmegaConf.create(conf)) STRESS_VIRIAL_DICT = { @@ -162,6 +87,9 @@ def test_read_targets(stress_dict, virial_dict, monkeypatch, tmp_path, caplog): "reader": "ase", "key": "true_energy", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": {"read_from": filename, "reader": "ase", "key": "forces"}, "stress": stress_dict, "virial": virial_dict, @@ -246,6 +174,9 @@ def test_read_targets_warnings(stress_dict, virial_dict, monkeypatch, tmp_path, "reader": "ase", "key": "true_energy", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": {"read_from": filename, "reader": "ase", "key": "forces"}, "stress": stress_dict, "virial": virial_dict, @@ -276,6 +207,9 @@ def test_read_targets_error(monkeypatch, tmp_path): "read_from": filename, "reader": "ase", "key": "true_energy", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": {"read_from": filename, "reader": "ase", "key": "forces"}, "stress": True, "virial": True, @@ -290,29 +224,136 @@ def test_read_targets_error(monkeypatch, tmp_path): read_targets(OmegaConf.create(conf)) -def 
test_unsupported_quantity(): - conf = { - "energy": { - "quantity": "foo", - } +@pytest.mark.parametrize("key", ["stress-3x3", "stress-9"]) +def test_read_targets_generic_1(key, monkeypatch, tmp_path): + """Reads a 3x3 stress as a Cartesian vector with 3 properties.""" + monkeypatch.chdir(tmp_path) + + filename = "systems.xyz" + systems = ase_system() + ase.io.write(filename, systems) + + stress_section = { + "quantity": "stress", + "read_from": filename, + "reader": "ase", + "key": key, + "unit": "GPa", + "type": { + "cartesian": { + "rank": 1, + } + }, + "per_atom": False, + "num_subtargets": 3, } + conf = {"stress": stress_section} + with pytest.warns(UserWarning, match="should not be its own top-level target"): + with pytest.warns(UserWarning, match="resembles to a gradient of energies"): + read_targets(OmegaConf.create(conf)) + # this will trigger a shape error + conf["stress"]["type"]["cartesian"]["rank"] = 2 with pytest.raises( - ValueError, - match="Quantity: 'foo' is not supported. 
Choose 'energy'.", + RuntimeError, + match="shape", ): - read_targets(OmegaConf.create(conf)) + with pytest.warns(UserWarning, match="should not be its own top-level target"): + with pytest.warns(UserWarning, match="resembles to a gradient of energies"): + read_targets(OmegaConf.create(conf)) -def test_unsupported_target_name(): - conf = { - "free_energy": { - "quantity": "energy", - } +@pytest.mark.parametrize("key", ["stress-3x3", "stress-9"]) +def test_read_targets_generic_2(key, monkeypatch, tmp_path): + """Reads a 3x3 stress as a Cartesian rank-2 tensor.""" + monkeypatch.chdir(tmp_path) + + filename = "systems.xyz" + systems = ase_system() + ase.io.write(filename, systems) + + stress_section = { + "quantity": "stress", + "read_from": filename, + "reader": "ase", + "key": key, + "unit": "GPa", + "type": { + "cartesian": { + "rank": 2, + } + }, + "per_atom": False, + "num_subtargets": 1, } + conf = {"stress": stress_section} + with pytest.warns(UserWarning, match="should not be its own top-level target"): + with pytest.warns(UserWarning, match="resembles to a gradient of energies"): + read_targets(OmegaConf.create(conf)) + # this will trigger a shape error + conf["stress"]["type"]["cartesian"]["rank"] = 1 with pytest.raises( - ValueError, - match="start with `mtt::`", + RuntimeError, + match="shape", ): - read_targets(OmegaConf.create(conf)) + with pytest.warns(UserWarning, match="should not be its own top-level target"): + with pytest.warns(UserWarning, match="resembles to a gradient of energies"): + read_targets(OmegaConf.create(conf)) + + +@pytest.mark.parametrize("key", ["stress-3x3", "stress-9"]) +def test_read_targets_generic_3(key, monkeypatch, tmp_path): + """Reads a 3x3 stress as a scalar with 9 properties""" + monkeypatch.chdir(tmp_path) + + filename = "systems.xyz" + systems = ase_system() + ase.io.write(filename, systems) + + stress_section = { + "quantity": "stress", + "read_from": filename, + "reader": "ase", + "key": key, + "unit": "GPa", + 
"type": "scalar", + "per_atom": False, + "num_subtargets": 9, + } + conf = {"stress": stress_section} + with pytest.warns(UserWarning, match="should not be its own top-level target"): + with pytest.warns(UserWarning, match="resembles to a gradient of energies"): + read_targets(OmegaConf.create(conf)) + + +def test_read_targets_generic_errors(monkeypatch, tmp_path): + """Reads a 3x3 stress as a scalar with 9 properties""" + monkeypatch.chdir(tmp_path) + + filename = "systems.xyz" + systems = ase_system() + ase.io.write(filename, systems) + + stress_section = { + "quantity": "stress", + "read_from": filename, + "reader": "ase", + "key": "stress-3x3", + "unit": "GPa", + "type": { + "spherical": { + "irreps": [ + {"o3_lambda": 0, "o3_sigma": 1}, + {"o3_lambda": 2, "o3_sigma": 1}, + ] + } + }, + "per_atom": False, + "num_subtargets": 9, + } + conf = {"stress": stress_section} + with pytest.raises(ValueError, match="use the metatensor reader"): + with pytest.warns(UserWarning, match="should not be its own top-level target"): + with pytest.warns(UserWarning, match="resembles to a gradient of energies"): + read_targets(OmegaConf.create(conf)) diff --git a/tests/utils/data/test_readers_ase.py b/tests/utils/data/test_readers_ase.py new file mode 100644 index 000000000..99c1e5c2f --- /dev/null +++ b/tests/utils/data/test_readers_ase.py @@ -0,0 +1,80 @@ +"""Tests for the ASE readers. The functionality of the top-level functions +`read_systems`, `read_energy`, `read_generic` is already tested through +the reader tests in `test_readers.py`. 
Here we test the specific ASE readers +for energies, forces, stresses, and virials.""" + +import ase +import ase.io +import pytest +import torch +from metatensor.torch import Labels +from test_targets_ase import ase_systems + +from metatrain.utils.data.readers.ase import ( + _read_energy_ase, + _read_forces_ase, + _read_stress_ase, + _read_virial_ase, +) + + +def test_read_energies(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + + filename = "systems.xyz" + systems = ase_systems() + ase.io.write(filename, systems) + + results = _read_energy_ase(filename, key="true_energy") + + assert type(results) is list + assert len(results) == len(systems) + for i_system, result in enumerate(results): + assert result.values.dtype is torch.float64 + assert result.samples.names == ["system"] + assert result.samples.values == torch.tensor([[i_system]]) + assert result.properties == Labels("energy", torch.tensor([[0]])) + + +def test_read_forces(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + + filename = "systems.xyz" + systems = ase_systems() + ase.io.write(filename, systems) + + results = _read_forces_ase(filename, key="forces") + + assert type(results) is list + assert len(results) == len(systems) + for i_system, result in enumerate(results): + assert result.values.dtype is torch.float64 + assert result.samples.names == ["sample", "system", "atom"] + assert torch.all(result.samples["sample"] == torch.tensor(0)) + assert torch.all(result.samples["system"] == torch.tensor(i_system)) + assert result.components == [Labels(["xyz"], torch.arange(3).reshape(-1, 1))] + assert result.properties == Labels("energy", torch.tensor([[0]])) + + +@pytest.mark.parametrize("reader_func", [_read_stress_ase, _read_virial_ase]) +def test_read_stress_virial(reader_func, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + + filename = "systems.xyz" + systems = ase_systems() + ase.io.write(filename, systems) + + results = reader_func(filename, key="stress-3x3") + + assert 
type(results) is list + assert len(results) == len(systems) + components = [ + Labels(["xyz_1"], torch.arange(3).reshape(-1, 1)), + Labels(["xyz_2"], torch.arange(3).reshape(-1, 1)), + ] + for result in results: + assert result.values.dtype is torch.float64 + assert result.samples.names == ["sample"] + assert result.samples.values == torch.tensor([[0]]) + assert result.components == components + assert result.properties == Labels("energy", torch.tensor([[0]])) diff --git a/tests/utils/data/test_readers_metatensor.py b/tests/utils/data/test_readers_metatensor.py new file mode 100644 index 000000000..c48366fc4 --- /dev/null +++ b/tests/utils/data/test_readers_metatensor.py @@ -0,0 +1,290 @@ +import metatensor.torch +import numpy as np +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from omegaconf import OmegaConf + +from metatrain.utils.data.readers.metatensor import ( + read_energy, + read_generic, + read_systems, +) + + +@pytest.fixture +def energy_tensor_map(): + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.rand(2, 1, dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.tensor([[0], [1]], dtype=torch.int32), + ), + components=[], + properties=Labels.range("energy", 1), + ) + ], + ) + + +@pytest.fixture +def scalar_tensor_map(): + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.rand(3, 10, dtype=torch.float64), + samples=Labels( + names=["system", "atom"], + values=torch.tensor([[0, 0], [0, 1], [1, 0]], dtype=torch.int32), + ), + components=[], + properties=Labels.range("properties", 10), + ) + ], + ) + + +@pytest.fixture +def spherical_tensor_map(): + return TensorMap( + keys=Labels( + names=["o3_lambda", "o3_sigma"], + values=torch.tensor([[0, 1], [2, 1]]), + ), + blocks=[ + TensorBlock( + values=torch.rand(2, 1, 1, dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.tensor([[0], [1]], dtype=torch.int32), 
+ ), + components=[ + Labels( + names=["o3_mu"], + values=torch.arange(0, 1, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("properties", 1), + ), + TensorBlock( + values=torch.rand(2, 5, 1, dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.tensor([[0], [1]], dtype=torch.int32), + ), + components=[ + Labels( + names=["o3_mu"], + values=torch.arange(-2, 3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("properties", 1), + ), + ], + ) + + +@pytest.fixture +def cartesian_tensor_map(): + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.rand(2, 3, 3, 1, dtype=torch.float64), + samples=Labels( + names=["system"], + values=torch.tensor([[0], [1]], dtype=torch.int32), + ), + components=[ + Labels( + names=["xyz_1"], + values=torch.arange(0, 3, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + names=["xyz_2"], + values=torch.arange(0, 3, dtype=torch.int32).reshape(-1, 1), + ), + ], + properties=Labels.range("properties", 1), + ), + ], + ) + + +def test_read_systems(): + with pytest.raises(NotImplementedError): + read_systems("foo.npz") + + +def test_read_energy(monkeypatch, tmpdir, energy_tensor_map): + monkeypatch.chdir(tmpdir) + + metatensor.torch.save( + "energy.npz", + energy_tensor_map, + ) + + conf = { + "quantity": "energy", + "read_from": "energy.npz", + "reader": "metatensor", + "key": "true_energy", + "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, + "forces": False, + "stress": False, + "virial": False, + } + + tensor_maps, target_info = read_energy(OmegaConf.create(conf)) + + tensor_map = metatensor.torch.join( + tensor_maps, axis="samples", remove_tensor_name=True + ) + assert metatensor.torch.equal(tensor_map, energy_tensor_map) + + +def test_read_generic_scalar(monkeypatch, tmpdir, scalar_tensor_map): + monkeypatch.chdir(tmpdir) + + metatensor.torch.save( + "generic.npz", + scalar_tensor_map, + ) + + conf = { + "quantity": 
"generic", + "read_from": "generic.npz", + "reader": "metatensor", + "keys": ["scalar"], + "per_atom": True, + "unit": "unit", + "type": "scalar", + "num_subtargets": 10, + } + + tensor_maps, target_info = read_generic(OmegaConf.create(conf)) + + tensor_map = metatensor.torch.join( + tensor_maps, axis="samples", remove_tensor_name=True + ) + assert metatensor.torch.equal(tensor_map, scalar_tensor_map) + + +def test_read_generic_spherical(monkeypatch, tmpdir, spherical_tensor_map): + monkeypatch.chdir(tmpdir) + + metatensor.torch.save( + "generic.npz", + spherical_tensor_map, + ) + + conf = { + "quantity": "generic", + "read_from": "generic.npz", + "reader": "metatensor", + "keys": ["o3_lambda", "o3_sigma"], + "per_atom": False, + "unit": "unit", + "type": { + "spherical": { + "irreps": [ + {"o3_lambda": 0, "o3_sigma": 1}, + {"o3_lambda": 2, "o3_sigma": 1}, + ], + }, + }, + "num_subtargets": 1, + } + + tensor_maps, target_info = read_generic(OmegaConf.create(conf)) + + tensor_map = metatensor.torch.join( + tensor_maps, axis="samples", remove_tensor_name=True + ) + assert metatensor.torch.equal(tensor_map, spherical_tensor_map) + + +def test_read_generic_cartesian(monkeypatch, tmpdir, cartesian_tensor_map): + monkeypatch.chdir(tmpdir) + + metatensor.torch.save( + "generic.npz", + cartesian_tensor_map, + ) + + conf = { + "quantity": "generic", + "read_from": "generic.npz", + "reader": "metatensor", + "keys": ["cartesian"], + "per_atom": False, + "unit": "unit", + "type": { + "cartesian": { + "rank": 2, + }, + }, + "num_subtargets": 1, + } + + tensor_maps, target_info = read_generic(OmegaConf.create(conf)) + + print(tensor_maps) + + tensor_map = metatensor.torch.join( + tensor_maps, axis="samples", remove_tensor_name=True + ) + print(tensor_map) + print(cartesian_tensor_map) + assert metatensor.torch.equal(tensor_map, cartesian_tensor_map) + + +def test_read_errors(monkeypatch, tmpdir, energy_tensor_map, scalar_tensor_map): + monkeypatch.chdir(tmpdir) + + 
metatensor.torch.save( + "energy.npz", + energy_tensor_map, + ) + + conf = { + "quantity": "energy", + "read_from": "energy.npz", + "reader": "metatensor", + "key": "true_energy", + "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, + "forces": False, + "stress": False, + "virial": False, + } + + numpy_array = np.zeros((2, 2)) + np.save("numpy_array.npz", numpy_array) + conf["read_from"] = "numpy_array.npz" + with pytest.raises(ValueError, match="Failed to read"): + read_energy(OmegaConf.create(conf)) + conf["read_from"] = "energy.npz" + + conf["forces"] = True + with pytest.raises(ValueError, match="Unexpected gradients"): + read_energy(OmegaConf.create(conf)) + conf["forces"] = False + + metatensor.torch.save( + "scalar.npz", + scalar_tensor_map, + ) + + conf["read_from"] = "scalar.npz" + with pytest.raises(ValueError, match="Unexpected samples"): + read_generic(OmegaConf.create(conf)) diff --git a/tests/utils/data/test_target_info.py b/tests/utils/data/test_target_info.py new file mode 100644 index 000000000..f0b8d6757 --- /dev/null +++ b/tests/utils/data/test_target_info.py @@ -0,0 +1,119 @@ +import pytest +from omegaconf import DictConfig + +from metatrain.utils.data.target_info import ( + get_energy_target_info, + get_generic_target_info, +) + + +@pytest.fixture +def energy_target_config() -> DictConfig: + return DictConfig( + { + "quantity": "energy", + "unit": "eV", + "per_atom": False, + "num_subtargets": 1, + "type": "scalar", + } + ) + + +@pytest.fixture +def scalar_target_config() -> DictConfig: + return DictConfig( + { + "quantity": "scalar", + "unit": "", + "per_atom": False, + "num_subtargets": 10, + "type": "scalar", + } + ) + + +@pytest.fixture +def cartesian_target_config() -> DictConfig: + return DictConfig( + { + "quantity": "dipole", + "unit": "D", + "per_atom": True, + "num_subtargets": 5, + "type": { + "Cartesian": { + "rank": 1, + } + }, + } + ) + + +@pytest.fixture +def spherical_target_config() -> DictConfig: + 
return DictConfig( + { + "quantity": "spherical", + "unit": "", + "per_atom": False, + "num_subtargets": 1, + "type": { + "spherical": { + "irreps": [ + {"o3_lambda": 0, "o3_sigma": 1}, + {"o3_lambda": 2, "o3_sigma": 1}, + ], + }, + }, + } + ) + + +def test_layout_energy(energy_target_config): + + target_info = get_energy_target_info(energy_target_config) + assert target_info.quantity == "energy" + assert target_info.unit == "eV" + assert target_info.per_atom is False + assert target_info.gradients == [] + + target_info = get_energy_target_info( + energy_target_config, add_position_gradients=True + ) + assert target_info.quantity == "energy" + assert target_info.unit == "eV" + assert target_info.per_atom is False + assert target_info.gradients == ["positions"] + + target_info = get_energy_target_info( + energy_target_config, add_position_gradients=True, add_strain_gradients=True + ) + assert target_info.quantity == "energy" + assert target_info.unit == "eV" + assert target_info.per_atom is False + assert target_info.gradients == ["positions", "strain"] + + +def test_layout_scalar(scalar_target_config): + target_info = get_generic_target_info(scalar_target_config) + assert target_info.quantity == "scalar" + assert target_info.unit == "" + assert target_info.per_atom is False + assert target_info.gradients == [] + + +def test_layout_cartesian(cartesian_target_config): + target_info = get_generic_target_info(cartesian_target_config) + assert target_info.quantity == "dipole" + assert target_info.unit == "D" + assert target_info.per_atom is True + assert target_info.gradients == [] + + +def test_layout_spherical(spherical_target_config): + target_info = get_generic_target_info(spherical_target_config) + assert target_info.quantity == "spherical" + assert target_info.unit == "" + assert target_info.per_atom is False + assert target_info.gradients == [] diff --git a/tests/utils/data/test_target_writers.py b/tests/utils/data/test_target_writers.py index 
525ccd446..7c4e1f355 100644 --- a/tests/utils/data/test_target_writers.py +++ b/tests/utils/data/test_target_writers.py @@ -86,6 +86,93 @@ def test_write_xyz(monkeypatch, tmp_path): assert all(atoms.pbc == 3 * [False]) +def test_write_components_and_properties_xyz(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + + systems, _, _ = systems_capabilities_predictions() + + capabilities = ModelCapabilities( + length_unit="angstrom", + outputs={"energy": ModelOutput(quantity="dos", unit="")}, + interaction_range=1.0, + dtype="float32", + ) + + predictions = { + "dos": TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.rand(2, 3, 100), + samples=Labels(["system"], torch.tensor([[0], [1]])), + components=[ + Labels.range("xyz", 3), + ], + properties=Labels( + ["property"], + torch.arange(100, dtype=torch.int32).reshape(-1, 1), + ), + ) + ], + ) + } + + filename = "test_output.xyz" + + write_xyz(filename, systems, capabilities, predictions) + + # Read the file and verify its contents + frames = ase.io.read(filename, index=":") + assert len(frames) == len(systems) + for atoms in frames: + assert atoms.info["dos"].shape == (3, 100) + + +def test_write_components_and_properties_xyz_per_atom(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + + systems, _, _ = systems_capabilities_predictions() + + capabilities = ModelCapabilities( + length_unit="angstrom", + outputs={"energy": ModelOutput(quantity="dos", unit="", per_atom=True)}, + interaction_range=1.0, + dtype="float32", + ) + + predictions = { + "dos": TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.rand(4, 3, 100), + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]), + ), + components=[ + Labels.range("xyz", 3), + ], + properties=Labels( + ["property"], + torch.arange(100, dtype=torch.int32).reshape(-1, 1), + ), + ) + ], + ) + } + + filename = "test_output.xyz" + + write_xyz(filename, systems, capabilities, predictions) + + 
# Read the file and verify its contents + frames = ase.io.read(filename, index=":") + assert len(frames) == len(systems) + for atoms in frames: + assert atoms.arrays["dos"].shape == (2, 300) + + def test_write_xyz_cell(monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) @@ -100,10 +187,14 @@ def test_write_xyz_cell(monkeypatch, tmp_path): # Read the file and verify its contents frames = ase.io.read(filename, index=":") - for atoms in frames: + for i, atoms in enumerate(frames): cell_actual = torch.tensor(atoms.cell.array, dtype=cell_expected.dtype) torch.testing.assert_close(cell_actual, cell_expected) assert all(atoms.pbc == 3 * [True]) + assert atoms.info["energy"] == float(predictions["energy"].block().values[i, 0]) + assert atoms.arrays["forces"].shape == (2, 3) + assert atoms.info["stress"].shape == (3, 3) + assert atoms.info["virial"].shape == (3, 3) @pytest.mark.parametrize("fileformat", (None, ".xyz")) diff --git a/tests/utils/data/test_targets_ase.py b/tests/utils/data/test_targets_ase.py index 9444216ff..f7da147ba 100644 --- a/tests/utils/data/test_targets_ase.py +++ b/tests/utils/data/test_targets_ase.py @@ -9,11 +9,11 @@ import torch from metatrain.utils.data.readers.ase import ( - read_energy_ase, - read_forces_ase, - read_stress_ase, - read_systems_ase, - read_virial_ase, + _read_energy_ase, + _read_forces_ase, + _read_stress_ase, + _read_virial_ase, + read_systems, ) @@ -40,7 +40,7 @@ def test_read_ase(monkeypatch, tmp_path): systems = ase_system() ase.io.write(filename, systems) - result = read_systems_ase(filename) + result = read_systems(filename) assert isinstance(result, list) assert len(result) == 1 @@ -67,7 +67,7 @@ def test_read_energy_ase(monkeypatch, tmp_path): systems = ase_systems() ase.io.write(filename, systems) - results = read_energy_ase(filename=filename, key="true_energy") + results = _read_energy_ase(filename=filename, key="true_energy") for result, atoms in zip(results, systems): expected = 
torch.tensor([[atoms.info["true_energy"]]], dtype=torch.float64) @@ -77,10 +77,10 @@ def test_read_energy_ase(monkeypatch, tmp_path): @pytest.mark.parametrize( "func, target_name", [ - (read_energy_ase, "energy"), - (read_forces_ase, "forces"), - (read_virial_ase, "virial"), - (read_stress_ase, "stress"), + (_read_energy_ase, "energy"), + (_read_forces_ase, "forces"), + (_read_virial_ase, "virial"), + (_read_stress_ase, "stress"), ], ) def test_ase_key_errors(func, target_name, monkeypatch, tmp_path): @@ -105,7 +105,7 @@ def test_read_forces_ase(monkeypatch, tmp_path): systems = ase_systems() ase.io.write(filename, systems) - results = read_forces_ase(filename=filename, key="forces") + results = _read_forces_ase(filename=filename, key="forces") for result, atoms in zip(results, systems): expected = -torch.tensor(atoms.get_array("forces"), dtype=torch.float64) @@ -121,7 +121,7 @@ def test_read_stress_ase(monkeypatch, tmp_path): systems = ase_systems() ase.io.write(filename, systems) - results = read_stress_ase(filename=filename, key="stress-3x3") + results = _read_stress_ase(filename=filename, key="stress-3x3") for result, atoms in zip(results, systems): expected = atoms.cell.volume * torch.tensor( @@ -143,7 +143,7 @@ def test_no_cell_error(monkeypatch, tmp_path): ase.io.write(filename, systems) with pytest.raises(ValueError, match="system 0 has zero cell vectors."): - read_stress_ase(filename=filename, key="stress-3x3") + _read_stress_ase(filename=filename, key="stress-3x3") def test_read_virial_ase(monkeypatch, tmp_path): @@ -154,7 +154,7 @@ def test_read_virial_ase(monkeypatch, tmp_path): systems = ase_systems() ase.io.write(filename, systems) - results = read_virial_ase(filename=filename, key="stress-3x3") + results = _read_virial_ase(filename=filename, key="stress-3x3") for result, atoms in zip(results, systems): expected = -torch.tensor(atoms.info["stress-3x3"], dtype=torch.float64) @@ -171,7 +171,7 @@ def test_read_virial_warn(monkeypatch, tmp_path): 
ase.io.write(filename, systems) with pytest.warns(match="Found 9-long numerical vector"): - results = read_virial_ase(filename=filename, key="stress-9") + results = _read_virial_ase(filename=filename, key="stress-9") expected = -torch.tensor(systems.info["stress-9"], dtype=torch.float64) expected = expected.reshape(-1, 3, 3, 1) @@ -188,4 +188,4 @@ def test_read_virial_error(monkeypatch, tmp_path): ase.io.write(filename, systems) with pytest.raises(ValueError, match="Stress/virial must be a 3 x 3 matrix"): - read_virial_ase(filename=filename, key="stress-9") + _read_virial_ase(filename=filename, key="stress-9") diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py index b3472d2f3..cdffe31e1 100644 --- a/tests/utils/test_additive.py +++ b/tests/utils/test_additive.py @@ -8,13 +8,16 @@ from omegaconf import OmegaConf from metatrain.utils.additive import ZBL, CompositionModel, remove_additive -from metatrain.utils.data import Dataset, DatasetInfo, TargetInfo +from metatrain.utils.data import Dataset, DatasetInfo from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.data.target_info import ( + get_energy_target_info, + get_generic_target_info, +) from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) -from metatrain.utils.testing import energy_layout RESOURCES_PATH = Path(__file__).parents[1] / "resources" @@ -83,12 +86,7 @@ def test_composition_model_train(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1, 8], - targets={ - "energy": TargetInfo( - quantity="energy", - layout=energy_layout, - ) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ), ) @@ -134,6 +132,9 @@ def test_composition_model_predict(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -210,12 +211,7 @@ def 
test_composition_model_torchscript(tmpdir): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1, 8], - targets={ - "energy": TargetInfo( - quantity="energy", - layout=energy_layout, - ) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ), ) composition_model = torch.jit.script(composition_model) @@ -243,6 +239,9 @@ def test_remove_additive(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -339,12 +338,7 @@ def test_composition_model_missing_types(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1], - targets={ - "energy": TargetInfo( - quantity="energy", - layout=energy_layout, - ) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ), ) with pytest.raises( @@ -358,12 +352,7 @@ def test_composition_model_missing_types(): dataset_info=DatasetInfo( length_unit="angstrom", atomic_types=[1, 8, 100], - targets={ - "energy": TargetInfo( - quantity="energy", - layout=energy_layout, - ) - }, + targets={"energy": get_energy_target_info({"unit": "eV"})}, ), ) with pytest.warns( @@ -388,9 +377,18 @@ def test_composition_model_wrong_target(): length_unit="angstrom", atomic_types=[1], targets={ - "energy": TargetInfo( - quantity="FOO", - layout=energy_layout, + "energy": get_generic_target_info( + { + "quantity": "dipole", + "unit": "D", + "per_atom": True, + "num_subtargets": 5, + "type": { + "Cartesian": { + "rank": 1, + } + }, + } ) }, ), @@ -412,6 +410,9 @@ def test_zbl(): "reader": "ase", "key": "U0", "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, diff --git a/tests/utils/test_evaluate_model.py b/tests/utils/test_evaluate_model.py index c18033172..90fc5c38d 100644 --- a/tests/utils/test_evaluate_model.py +++ b/tests/utils/test_evaluate_model.py @@ -2,13 +2,13 @@ import torch from metatrain.experimental.soap_bpnn 
import __model__ -from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems +from metatrain.utils.data import DatasetInfo, read_systems +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.evaluate_model import evaluate_model from metatrain.utils.neighbor_lists import ( get_requested_neighbor_lists, get_system_with_neighbor_lists, ) -from metatrain.utils.testing import energy_force_stress_layout from . import MODEL_HYPERS, RESOURCES_PATH @@ -25,8 +25,10 @@ def test_evaluate_model(training, exported): ) targets = { - "energy": TargetInfo( - quantity="energy", unit="eV", layout=energy_force_stress_layout + "energy": get_energy_target_info( + {"unit": "eV"}, + add_position_gradients=True, + add_strain_gradients=True, ) } diff --git a/tests/utils/test_external_naming.py b/tests/utils/test_external_naming.py index 0e3db1371..5b9c79de7 100644 --- a/tests/utils/test_external_naming.py +++ b/tests/utils/test_external_naming.py @@ -1,17 +1,20 @@ -from metatrain.utils.data.dataset import TargetInfo +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.external_naming import to_external_name, to_internal_name -from metatrain.utils.testing import energy_layout def test_to_external_name(): """Tests the to_external_name function.""" quantities = { - "energy": TargetInfo(quantity="energy", layout=energy_layout), - "mtt::free_energy": TargetInfo(quantity="energy", layout=energy_layout), - "mtt::foo": TargetInfo(quantity="bar", layout=energy_layout), + "energy": get_energy_target_info({"unit": "eV"}), + "mtt::free_energy": get_energy_target_info({"unit": "eV"}), + "mtt::foo": get_energy_target_info({"unit": "eV"}), } + # hack to test the fact that non-energies should be treated differently + # (i.e., their gradients should not have special names) + quantities["mtt::foo"].quantity = "bar" + assert to_external_name("energy_positions_gradients", quantities) == "forces" assert ( 
to_external_name("mtt::free_energy_positions_gradients", quantities) diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index bb87d266a..280537562 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -35,6 +35,9 @@ def test_llpr(tmpdir): "reader": "ase", "key": "U0", "unit": "kcal/mol", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, @@ -157,6 +160,9 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir): "reader": "ase", "key": "U0", "unit": "kcal/mol", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, "forces": False, "stress": False, "virial": False, diff --git a/tests/utils/test_neighbor_list.py b/tests/utils/test_neighbor_list.py index 768114d6c..d346dca00 100644 --- a/tests/utils/test_neighbor_list.py +++ b/tests/utils/test_neighbor_list.py @@ -2,7 +2,7 @@ from metatensor.torch.atomistic import NeighborListOptions -from metatrain.utils.data.readers.ase import read_systems_ase +from metatrain.utils.data.readers.ase import read_systems from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists @@ -11,7 +11,7 @@ def test_attach_neighbor_lists(): filename = RESOURCES_PATH / "qm9_reduced_100.xyz" - systems = read_systems_ase(filename) + systems = read_systems(filename) requested_neighbor_lists = [ NeighborListOptions(cutoff=4.0, full_list=True, strict=True), diff --git a/tests/utils/test_output_gradient.py b/tests/utils/test_output_gradient.py index 75aaf76e7..ac20bac4d 100644 --- a/tests/utils/test_output_gradient.py +++ b/tests/utils/test_output_gradient.py @@ -4,13 +4,9 @@ from metatensor.torch.atomistic import System from metatrain.experimental.soap_bpnn import __model__ -from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems +from metatrain.utils.data import DatasetInfo, read_systems +from metatrain.utils.data.target_info import get_energy_target_info from metatrain.utils.output_gradient import compute_gradient -from 
metatrain.utils.testing import ( - energy_force_layout, - energy_force_stress_layout, - energy_stress_layout, -) from . import MODEL_HYPERS, RESOURCES_PATH @@ -23,8 +19,8 @@ def test_forces(is_training): length_unit="angstrom", atomic_types={1, 6, 7, 8}, targets={ - "energy": TargetInfo( - quantity="energy", unit="eV", layout=energy_force_layout + "energy": get_energy_target_info( + {"unit": "eV"}, add_position_gradients=True ) }, ) @@ -82,9 +78,7 @@ def test_virial(is_training): length_unit="angstrom", atomic_types={6}, targets={ - "energy": TargetInfo( - quantity="energy", unit="eV", layout=energy_stress_layout - ) + "energy": get_energy_target_info({"unit": "eV"}, add_strain_gradients=True) }, ) model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) @@ -153,10 +147,10 @@ def test_both(is_training): length_unit="angstrom", atomic_types={6}, targets={ - "energy": TargetInfo( - quantity="energy", - unit="eV", - layout=energy_force_stress_layout, + "energy": get_energy_target_info( + {"unit": "eV"}, + add_position_gradients=True, + add_strain_gradients=True, ) }, )