Skip to content

Commit

Permalink
Merge pull request #308 from torchmd/ensemble_model_support
Browse files Browse the repository at this point in the history
Added support for ensemble models
  • Loading branch information
stefdoerr authored Mar 19, 2024
2 parents 9546e88 + f21f3de commit 8a1be71
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 62 deletions.
83 changes: 52 additions & 31 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,38 +83,9 @@ This is a minimal example of a custom training loop:
optimizer.step()
Loading a model for inference
-----------------------------

Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.

.. code:: python
import torch
from torchmdnet.models.model import load_model
checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
model = load_model(checkpoint, derivative=True)
# An arbitrary set of inputs for the model
n_atoms = 10
zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
z = zs[torch.randint(0, len(zs), (n_atoms,))]
pos = torch.randn(len(z), 3)
batch = torch.zeros(len(z), dtype=torch.long)
y, neg_dy = model(z, pos, batch)
.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.

.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.

.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.


.. _delta-learning:
Training on relative energies
-----------------------------
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

It might be useful to train the model on relative energies but then make the model produce total energies when running inference.
TorchMD-Net supports delta training via the :code:`remove_ref_energy` option. Passing this option when training (either via the :ref:`configuration-file` or using the :ref:`torchmd-train` command line interface) will subtract the reference energy from each atom in a sample before passing it to the model.
Expand All @@ -126,7 +97,7 @@ If :code:`remove_ref_energy` is turned on, the reference energy is stored in the
.. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`.

Example
~~~~~~~
********

First we train a model with the :code:`remove_ref_energy` option turned on:

Expand All @@ -151,6 +122,56 @@ Then we load the model for inference:
batch = torch.zeros(len(z), dtype=torch.long)
y, neg_dy = model(z, pos, batch)
Loading a model for inference
-----------------------------

Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.

.. code:: python
import torch
from torchmdnet.models.model import load_model
checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
model = load_model(checkpoint, derivative=True)
# An arbitrary set of inputs for the model
n_atoms = 10
zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
z = zs[torch.randint(0, len(zs), (n_atoms,))]
pos = torch.randn(len(z), 3)
batch = torch.zeros(len(z), dtype=torch.long)
y, neg_dy = model(z, pos, batch)
.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.

.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.

.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.


Model Ensembles
---------------
It is possible to create an ensemble of models by loading multiple checkpoints and averaging their predictions. The following example shows how to do this:

.. code:: python
import torch
from torchmdnet.models.model import load_model
checkpoints = ["/path/to/checkpoint/my_checkpoint1.ckpt", "/path/to/checkpoint/my_checkpoint2.ckpt"]
model_ensemble = load_model(checkpoints, return_std=True)
y_ensemble, neg_dy_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)
.. note:: :py:mod:`torchmdnet.models.model.load_model` will return an instance of :py:mod:`torchmdnet.models.model.Ensemble` if a list of checkpoints is passed. The :code:`return_std` option can be used to return the standard deviation of the predictions.



.. autoclass:: torchmdnet.models.model.Ensemble
:noindex:





Expand Down
57 changes: 45 additions & 12 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import lightning as pl
from torchmdnet import models
from torchmdnet.models.model import create_model
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models import output_modules
from torchmdnet.models.utils import dtype_mapping

Expand All @@ -23,7 +23,9 @@
def test_forward(model_name, use_batch, explicit_q_s, precision):
z, pos, batch = create_example_batch()
pos = pos.to(dtype=dtype_mapping[precision])
model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
model = create_model(
load_example_args(model_name, prior_model=None, precision=precision)
)
batch = batch if use_batch else None
if explicit_q_s:
model(z, pos, batch=batch, q=None, s=None)
Expand All @@ -33,10 +35,12 @@ def test_forward(model_name, use_batch, explicit_q_s, precision):

@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("output_model", output_modules.__all__)
@mark.parametrize("precision", [32,64])
@mark.parametrize("precision", [32, 64])
def test_forward_output_modules(model_name, output_model, precision):
z, pos, batch = create_example_batch()
args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
args = load_example_args(
model_name, remove_prior=True, output_model=output_model, precision=precision
)
model = create_model(args)
model(z, pos, batch=batch)

Expand All @@ -61,18 +65,25 @@ def test_torchscript(model_name, device):
grad_outputs=grad_outputs,
)[0]


def test_torchscript_output_modification():
model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
model = create_model(
load_example_args("tensornet", remove_prior=True, derivative=True)
)

class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.model = model

def forward(self, z, pos, batch):
y, neg_dy = self.model(z, pos, batch=batch)
# A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
return y, 2*neg_dy
return y, 2 * neg_dy

torch.jit.script(MyModel())


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("device", ["cpu", "cuda"])
def test_torchscript_dynamic_shapes(model_name, device):
Expand All @@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device):
model = torch.jit.script(
create_model(load_example_args(model_name, remove_prior=True, derivative=True))
).to(device=device)
#Repeat the input to make it dynamic
# Repeat the input to make it dynamic
for rep in range(0, 5):
print(rep)
zi = z.repeat_interleave(rep+1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep+1, dim=0).to(device=device)
zi = z.repeat_interleave(rep + 1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device)
batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
y, neg_dy = model(zi, posi, batch=batchi)
grad_outputs = [torch.ones_like(neg_dy)]
Expand All @@ -98,7 +109,8 @@ def test_torchscript_dynamic_shapes(model_name, device):
grad_outputs=grad_outputs,
)[0]

#Currently only tensornet is CUDA graph compatible

# Currently only tensornet is CUDA graph compatible
@mark.parametrize("model_name", ["tensornet"])
def test_cuda_graph_compatible(model_name):
if not torch.cuda.is_available():
Expand Down Expand Up @@ -142,6 +154,7 @@ def test_cuda_graph_compatible(model_name):
assert torch.allclose(y, y2)
assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5)


@mark.parametrize("model_name", models.__all_models__)
def test_seed(model_name):
args = load_example_args(model_name, remove_prior=True)
Expand All @@ -153,6 +166,7 @@ def test_seed(model_name):
for p1, p2 in zip(m1.parameters(), m2.parameters()):
assert (p1 == p2).all(), "Parameters don't match although using the same seed."


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize(
"output_model",
Expand Down Expand Up @@ -199,7 +213,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
), f"Set new reference outputs for {model_name} with output model {output_model}."

# compare actual ouput with reference
torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5)
torch.testing.assert_close(
pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5
)
if derivative:
torch.testing.assert_close(
deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5
Expand All @@ -218,7 +234,7 @@ def test_gradients(model_name):
remove_prior=True,
output_model=output_model,
derivative=derivative,
precision=precision
precision=precision,
)
model = create_model(args)
z, pos, batch = create_example_batch(n_atoms=5)
Expand All @@ -227,3 +243,20 @@ def test_gradients(model_name):
torch.autograd.gradcheck(
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)


def test_ensemble():
ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
model = load_model(ckpts[0])
ensemble_model = load_model(ckpts, return_std=True)
z, pos, batch = create_example_batch(n_atoms=5)

pred, deriv = model(z, pos, batch)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
assert y_std.shape == pred.shape
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()
97 changes: 78 additions & 19 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,25 @@ def create_model(args, prior_model=None, mean=None, std=None):
return model


def load_model(filepath, args=None, device="cpu", **kwargs):
def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
"""Load a model from a checkpoint file.
If a list of paths is given, an :py:mod:`Ensemble` model is returned.
Args:
filepath (str): Path to the checkpoint file.
filepath (str or list): Path to the checkpoint file or a list of paths.
args (dict, optional): Arguments for the model. Defaults to None.
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
**kwargs: Extra keyword arguments for the model.
Returns:
nn.Module: An instance of the TorchMD_Net model.
"""
if isinstance(filepath, (list, tuple)):
return Ensemble(
[load_model(f, args=args, device=device, **kwargs) for f in filepath],
return_std=return_std,
)

ckpt = torch.load(filepath, map_location="cpu")
if args is None:
Expand Down Expand Up @@ -187,29 +194,32 @@ def create_prior_models(args, dataset=None):
1. A single prior model name and its arguments as a dictionary:
```python
args = {
"prior_model": "Atomref",
"prior_args": {"max_z": 100}
}
```
.. code:: python
args = {
"prior_model": "Atomref",
"prior_args": {"max_z": 100}
}
2. A list of prior model names and their arguments as a list of dictionaries:
```python
.. code:: python
args = {
"prior_model": ["Atomref", "D2"],
"prior_args": [{"max_z": 100}, {"max_z": 100}]
}
args = {
"prior_model": ["Atomref", "D2"],
"prior_args": [{"max_z": 100}, {"max_z": 100}]
}
```
3. A list of prior model names and their arguments as a dictionary:
```python
args = {
"prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}]
}
```
.. code:: python
args = {
"prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}]
}
Args:
args (dict): Arguments for the model.
Expand Down Expand Up @@ -426,3 +436,52 @@ def forward(
# Returning an empty tensor allows to decorate this method as always returning two tensors.
# This is required to overcome a TorchScript limitation, xref https://github.com/openmm/openmm-torch/issues/135
return y, torch.empty(0)


class Ensemble(torch.nn.ModuleList):
"""Average predictions over an ensemble of TorchMD-Net models.
This module behaves like a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with.
Args:
modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy).
"""

def __init__(self, modules: List[nn.Module], return_std: bool = False):
for module in modules:
assert isinstance(module, TorchMD_Net)
super().__init__(modules)
self.return_std = return_std

def forward(
self,
*args,
**kwargs,
):
"""Average predictions over all models in the ensemble.
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
Args:
*args: Positional arguments to forward to the models.
**kwargs: Keyword arguments to forward to the models.
Returns:
Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The average and standard deviation of the predictions over all models in the ensemble. If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy).
"""
y = []
neg_dy = []
for model in self:
res = model(*args, **kwargs)
y.append(res[0])
neg_dy.append(res[1])
y = torch.stack(y)
neg_dy = torch.stack(neg_dy)
y_mean = torch.mean(y, axis=0)
neg_dy_mean = torch.mean(neg_dy, axis=0)
y_std = torch.std(y, axis=0)
neg_dy_std = torch.std(neg_dy, axis=0)

if self.return_std:
return y_mean, neg_dy_mean, y_std, neg_dy_std
else:
return y_mean, neg_dy_mean

0 comments on commit 8a1be71

Please sign in to comment.