Skip to content

Commit

Permalink
Adapt models to handle generic targets (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Nov 25, 2024
1 parent 1ffccb4 commit 09e2550
Show file tree
Hide file tree
Showing 18 changed files with 507 additions and 178 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,52 @@
Preparing generic targets for reading by metatrain
==================================================
Fitting generic targets
=======================

Besides energy-like targets, the library also supports reading (and training on)
more generic targets.

Support for generic targets
---------------------------

Not all architectures can train on all types of target. Here you can find the
capabilities of the architectures in metatrain.

.. list-table:: Sample Table
:header-rows: 1

* - Target type
- Energy and its gradients
- Scalars
- Spherical tensors
- Cartesian tensors
* - SOAP-BPNN
- Energy, forces, stress/virial
- Yes
- Only with ``o3_lambda=1, o3_sigma=1``
- No
* - Alchemical Model
- Energy, forces, stress/virial
- No
- No
- No
* - GAP
- Energy, forces
- No
- No
- No
* - PET
- Energy, forces
- No
- No
- No


Preparing generic targets for reading by metatrain
--------------------------------------------------

Only a few steps are required to fit arbitrary targets in metatrain.

Input file
----------
##########

In order to read a generic target, you will have to specify its layout in the input
file. Suppose you want to learn a target named ``mtt::my_target``, which is
Expand Down Expand Up @@ -69,7 +110,7 @@ where ``o3_lambda`` specifies the L value of the spherical tensor and ``o3_sigma
parity with respect to inversion (1 for proper tensors, -1 for pseudo-tensors).

Preparing your targets -- ASE
-----------------------------
#############################

If you are using the ASE readers to read your targets, you will have to save them
either in the ``.info`` (if the target is per structure, i.e. not per atom) or in the
Expand All @@ -84,7 +125,7 @@ Reading targets with more than one spherical tensor is not supported by the ASE
In that case, you should use the metatensor reader.

Preparing your targets -- metatensor
------------------------------------
####################################

If you are using the metatensor readers to read your targets, you will have to save them
as a ``metatensor.torch.TensorMap`` object with ``metatensor.torch.TensorMap.save()``
Expand Down
2 changes: 1 addition & 1 deletion docs/src/advanced-concepts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ such as output naming, auxiliary outputs, and wrapper models.
multi-gpu
auto-restarting
fine-tuning
preparing-generic-targets
fitting-generic-targets
9 changes: 9 additions & 0 deletions docs/src/dev-docs/new-architecture.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,12 @@ passed to the architecture model and trainer as is.
To create such a schema start by using `online tools <https://jsonformatter.org>`_ that
convert the ``default-hypers.yaml`` into a JSON schema. Besides online tools, we also
had success using ChatGPT/LLM for this for conversion.

Documentation
-------------

Each new architecture should be added to ``metatrain``'s documentation. A short page
describing the architecture and its default hyperparameters will be sufficient. You
can take inspiration from existing architectures. The various targets that the
architecture can fit should be added to the table in the "Fitting generic targets"
section.
22 changes: 16 additions & 6 deletions src/metatrain/experimental/alchemical_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,24 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
self.atomic_types = dataset_info.atomic_types

if len(dataset_info.targets) != 1:
raise ValueError("The AlchemicalModel only supports a single target")
raise ValueError("The Alchemical Model only supports a single target")

target_name = next(iter(dataset_info.targets.keys()))
if dataset_info.targets[target_name].quantity != "energy":
raise ValueError("The AlchemicalModel only supports 'energies' as target")

if dataset_info.targets[target_name].per_atom:
raise ValueError("The AlchemicalModel does not support 'per-atom' training")
target = dataset_info.targets[target_name]
if not (
target.is_scalar
and target.quantity == "energy"
and len(target.layout.block(0).properties) == 1
):
raise ValueError(
"The Alchemical Model only supports total-energy-like outputs, "
f"but a {target.quantity} was provided"
)
if target.per_atom:
raise ValueError(
"Alchemical Model only supports per-structure outputs, "
"but a per-atom output was provided"
)

self.outputs = {
key: ModelOutput(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import pytest
import torch
from metatensor.torch.atomistic import ModelEvaluationOptions, System

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.data.target_info import (
get_energy_target_info,
get_generic_target_info,
)
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
Expand Down Expand Up @@ -40,3 +44,31 @@ def test_prediction_subset_elements():

exported = model.export()
exported([system], evaluation_options, check_consistency=True)


@pytest.mark.parametrize("per_atom", [True, False])
def test_vector_output(per_atom):
"""Tests that the model can predict a (spherical) vector output."""

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"forces": get_generic_target_info(
{
"quantity": "forces",
"unit": "",
"type": {
"spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]}
},
"num_subtargets": 100,
"per_atom": per_atom,
}
)
},
)

with pytest.raises(
ValueError, match="The Alchemical Model only supports total-energy-like outputs"
):
AlchemicalModel(MODEL_HYPERS, dataset_info)
8 changes: 6 additions & 2 deletions src/metatrain/experimental/gap/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:

# Check capabilities
for target in dataset_info.targets.values():
if target.quantity != "energy":
if not (
target.is_scalar
and target.quantity == "energy"
and len(target.layout.block(0).properties) == 1
):
raise ValueError(
"GAP only supports energy-like outputs, "
"GAP only supports total-energy-like outputs, "
f"but a {target.quantity} was provided"
)
if target.per_atom:
Expand Down
31 changes: 30 additions & 1 deletion src/metatrain/experimental/gap/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from metatrain.experimental.gap import GAP, Trainer
from metatrain.utils.data import Dataset, DatasetInfo, read_systems, read_targets
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.data.target_info import (
get_energy_target_info,
get_generic_target_info,
)

from . import DATASET_ETHANOL_PATH, DEFAULT_HYPERS

Expand Down Expand Up @@ -85,3 +88,29 @@ def test_ethanol_regression_train_and_invariance():
val_datasets=[dataset],
checkpoint_dir=".",
)


@pytest.mark.parametrize("per_atom", [True, False])
def test_vector_output(per_atom):
"""Tests that the model can predict a (spherical) vector output."""

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"forces": get_generic_target_info(
{
"quantity": "forces",
"unit": "",
"type": {
"spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]}
},
"num_subtargets": 100,
"per_atom": per_atom,
}
)
},
)

with pytest.raises(ValueError, match="GAP only supports total-energy-like outputs"):
GAP(DEFAULT_HYPERS["model"], dataset_info)
17 changes: 15 additions & 2 deletions src/metatrain/experimental/pet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,21 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
if len(dataset_info.targets) != 1:
raise ValueError("PET only supports a single target")
self.target_name = next(iter(dataset_info.targets.keys()))
if dataset_info.targets[self.target_name].quantity != "energy":
raise ValueError("PET only supports energies as target")
target = dataset_info.targets[self.target_name]
if not (
target.is_scalar
and target.quantity == "energy"
and len(target.layout.block(0).properties) == 1
):
raise ValueError(
"PET only supports total-energy-like outputs, "
f"but a {target.quantity} was provided"
)
if target.per_atom:
raise ValueError(
"PET only supports per-structure outputs, "
"but a per-atom output was provided"
)

model_hypers["D_OUTPUT"] = 1
model_hypers["TARGET_TYPE"] = "atomic"
Expand Down
31 changes: 30 additions & 1 deletion src/metatrain/experimental/pet/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from metatrain.experimental.pet import PET as WrappedPET
from metatrain.utils.architectures import get_default_hypers
from metatrain.utils.data import DatasetInfo
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.data.target_info import (
get_energy_target_info,
get_generic_target_info,
)
from metatrain.utils.jsonschema import validate
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
Expand Down Expand Up @@ -222,3 +225,29 @@ def test_selected_atoms_functionality():
evaluation_options,
check_consistency=True,
)


@pytest.mark.parametrize("per_atom", [True, False])
def test_vector_output(per_atom):
"""Tests that the model can predict a (spherical) vector output."""

dataset_info = DatasetInfo(
length_unit="Angstrom",
atomic_types=[1, 6, 7, 8],
targets={
"forces": get_generic_target_info(
{
"quantity": "forces",
"unit": "",
"type": {
"spherical": {"irreps": [{"o3_lambda": 1, "o3_sigma": 1}]}
},
"num_subtargets": 100,
"per_atom": per_atom,
}
)
},
)

with pytest.raises(ValueError, match="PET only supports total-energy-like outputs"):
WrappedPET(DEFAULT_HYPERS["model"], dataset_info)
Loading

0 comments on commit 09e2550

Please sign in to comment.