diff --git a/docs/source/models.rst b/docs/source/models.rst index 244c16928..5ec760e33 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -83,38 +83,9 @@ This is a minimal example of a custom training loop: optimizer.step() - - -Loading a model for inference ------------------------------ - -Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example. - -.. code:: python - - import torch - from torchmdnet.models.model import load_model - checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt" - model = load_model(checkpoint, derivative=True) - # An arbitrary set of inputs for the model - n_atoms = 10 - zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long) - z = zs[torch.randint(0, len(zs), (n_atoms,))] - pos = torch.randn(len(z), 3) - batch = torch.zeros(len(z), dtype=torch.long) - - y, neg_dy = model(z, pos, batch) - -.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference. - -.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case. - -.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case. - - .. _delta-learning: Training on relative energies ------------------------------ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ It might be useful to train the model on relative energies but then make the model produce total energies when running inference. TorchMD-Net supports delta training via the :code:`remove_ref_energy` option. 
Passing this option when training (either via the :ref:`configuration-file` or using the :ref:`torchmd-train` command line interface) will subtract the reference energy from each atom in a sample before passing it to the model. @@ -126,7 +97,7 @@ If :code:`remove_ref_energy` is turned on, the reference energy is stored in the .. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`. Example -~~~~~~~ +******** First we train a model with the :code:`remove_ref_energy` option turned on: @@ -151,6 +122,56 @@ Then we load the model for inference: batch = torch.zeros(len(z), dtype=torch.long) y, neg_dy = model(z, pos, batch) + + +Loading a model for inference +----------------------------- + +Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example. + +.. code:: python + + import torch + from torchmdnet.models.model import load_model + checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt" + model = load_model(checkpoint, derivative=True) + # An arbitrary set of inputs for the model + n_atoms = 10 + zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long) + z = zs[torch.randint(0, len(zs), (n_atoms,))] + pos = torch.randn(len(z), 3) + batch = torch.zeros(len(z), dtype=torch.long) + + y, neg_dy = model(z, pos, batch) + +.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference. + +.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case. + +.. 
note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case. + + +Model Ensembles +--------------- +It is possible to create an ensemble of models by loading multiple checkpoints and averaging their predictions. The following example shows how to do this: + +.. code:: python + + import torch + from torchmdnet.models.model import load_model + checkpoints = ["/path/to/checkpoint/my_checkpoint1.ckpt", "/path/to/checkpoint/my_checkpoint2.ckpt"] + model_ensemble = load_model(checkpoints, return_std=True) + y_ensemble, neg_dy_ensemble, y_std, neg_dy_std = model_ensemble(z, pos, batch) + + +.. note:: :py:mod:`torchmdnet.models.model.load_model` will return an instance of :py:mod:`torchmdnet.models.model.Ensemble` if a list of checkpoints is passed. The :code:`return_std` option can be used to return the standard deviation of the predictions. + + + +.. 
autoclass:: torchmdnet.models.model.Ensemble + :noindex: + + diff --git a/tests/test_model.py b/tests/test_model.py index 00010f890..1dd5e3549 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,7 +9,7 @@ import torch import lightning as pl from torchmdnet import models -from torchmdnet.models.model import create_model +from torchmdnet.models.model import create_model, load_model from torchmdnet.models import output_modules from torchmdnet.models.utils import dtype_mapping @@ -23,7 +23,9 @@ def test_forward(model_name, use_batch, explicit_q_s, precision): z, pos, batch = create_example_batch() pos = pos.to(dtype=dtype_mapping[precision]) - model = create_model(load_example_args(model_name, prior_model=None, precision=precision)) + model = create_model( + load_example_args(model_name, prior_model=None, precision=precision) + ) batch = batch if use_batch else None if explicit_q_s: model(z, pos, batch=batch, q=None, s=None) @@ -33,10 +35,12 @@ def test_forward(model_name, use_batch, explicit_q_s, precision): @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("output_model", output_modules.__all__) -@mark.parametrize("precision", [32,64]) +@mark.parametrize("precision", [32, 64]) def test_forward_output_modules(model_name, output_model, precision): z, pos, batch = create_example_batch() - args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision) + args = load_example_args( + model_name, remove_prior=True, output_model=output_model, precision=precision + ) model = create_model(args) model(z, pos, batch=batch) @@ -61,18 +65,25 @@ def test_torchscript(model_name, device): grad_outputs=grad_outputs, )[0] + def test_torchscript_output_modification(): - model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True)) + model = create_model( + load_example_args("tensornet", remove_prior=True, derivative=True) + ) + class MyModel(torch.nn.Module): def __init__(self): 
super(MyModel, self).__init__() self.model = model + def forward(self, z, pos, batch): y, neg_dy = self.model(z, pos, batch=batch) # A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor] - return y, 2*neg_dy + return y, 2 * neg_dy + torch.jit.script(MyModel()) + @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("device", ["cpu", "cuda"]) def test_torchscript_dynamic_shapes(model_name, device): @@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device): model = torch.jit.script( create_model(load_example_args(model_name, remove_prior=True, derivative=True)) ).to(device=device) - #Repeat the input to make it dynamic + # Repeat the input to make it dynamic for rep in range(0, 5): print(rep) - zi = z.repeat_interleave(rep+1, dim=0).to(device=device) - posi = pos.repeat_interleave(rep+1, dim=0).to(device=device) + zi = z.repeat_interleave(rep + 1, dim=0).to(device=device) + posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device) batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device) y, neg_dy = model(zi, posi, batch=batchi) grad_outputs = [torch.ones_like(neg_dy)] @@ -98,7 +109,8 @@ def test_torchscript_dynamic_shapes(model_name, device): grad_outputs=grad_outputs, )[0] -#Currently only tensornet is CUDA graph compatible + +# Currently only tensornet is CUDA graph compatible @mark.parametrize("model_name", ["tensornet"]) def test_cuda_graph_compatible(model_name): if not torch.cuda.is_available(): @@ -142,6 +154,7 @@ def test_cuda_graph_compatible(model_name): assert torch.allclose(y, y2) assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5) + @mark.parametrize("model_name", models.__all_models__) def test_seed(model_name): args = load_example_args(model_name, remove_prior=True) @@ -153,6 +166,7 @@ def test_seed(model_name): for p1, p2 in zip(m1.parameters(), m2.parameters()): assert (p1 == p2).all(), "Parameters don't match although using the same seed." 
+ @mark.parametrize("model_name", models.__all_models__) @mark.parametrize( "output_model", @@ -199,7 +213,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False): ), f"Set new reference outputs for {model_name} with output model {output_model}." # compare actual ouput with reference - torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5) + torch.testing.assert_close( + pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5 + ) if derivative: torch.testing.assert_close( deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5 @@ -218,7 +234,7 @@ def test_gradients(model_name): remove_prior=True, output_model=output_model, derivative=derivative, - precision=precision + precision=precision, ) model = create_model(args) z, pos, batch = create_example_batch(n_atoms=5) @@ -227,3 +243,20 @@ def test_gradients(model_name): torch.autograd.gradcheck( model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3 ) + + +def test_ensemble(): + ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3 + model = load_model(ckpts[0]) + ensemble_model = load_model(ckpts, return_std=True) + z, pos, batch = create_example_batch(n_atoms=5) + + pred, deriv = model(z, pos, batch) + pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch) + + torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5) + assert y_std.shape == pred.shape + assert neg_dy_std.shape == deriv.shape + assert (y_std == 0).all() + assert (neg_dy_std == 0).all() diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index a2a80f901..c090f90f6 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -139,18 +139,25 @@ def create_model(args, prior_model=None, mean=None, std=None): return model -def load_model(filepath, args=None, device="cpu", 
**kwargs): +def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): """Load a model from a checkpoint file. + If a list of paths is given, an :py:mod:`Ensemble` model is returned. Args: - filepath (str): Path to the checkpoint file. + filepath (str or list): Path to the checkpoint file or a list of paths. args (dict, optional): Arguments for the model. Defaults to None. device (str, optional): Device on which the model should be loaded. Defaults to "cpu". + return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False. **kwargs: Extra keyword arguments for the model. Returns: nn.Module: An instance of the TorchMD_Net model. """ + if isinstance(filepath, (list, tuple)): + return Ensemble( + [load_model(f, args=args, device=device, **kwargs) for f in filepath], + return_std=return_std, + ) ckpt = torch.load(filepath, map_location="cpu") if args is None: @@ -187,29 +194,32 @@ def create_prior_models(args, dataset=None): 1. A single prior model name and its arguments as a dictionary: - ```python - args = { - "prior_model": "Atomref", - "prior_args": {"max_z": 100} - } - ``` + .. code:: python + + args = { + "prior_model": "Atomref", + "prior_args": {"max_z": 100} + } + + 2. A list of prior model names and their arguments as a list of dictionaries: - ```python + .. code:: python + + args = { + "prior_model": ["Atomref", "D2"], + "prior_args": [{"max_z": 100}, {"max_z": 100}] + } - args = { - "prior_model": ["Atomref", "D2"], - "prior_args": [{"max_z": 100}, {"max_z": 100}] - } - ``` 3. A list of prior model names and their arguments as a dictionary: - ```python - args = { - "prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}] - } - ``` + .. code:: python + + args = { + "prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}] + } + Args: args (dict): Arguments for the model. 
@@ -426,3 +436,61 @@ def forward(
         # Returning an empty tensor allows to decorate this method as always returning two tensors.
         # This is required to overcome a TorchScript limitation, xref https://github.com/openmm/openmm-torch/issues/135
         return y, torch.empty(0)
+
+
+class Ensemble(torch.nn.ModuleList):
+    """Average predictions over an ensemble of TorchMD-Net models.
+
+    This module behaves like a single TorchMD-Net model, but its forward method returns the average (and, if requested, the standard deviation) of the predictions over all models it was initialized with.
+
+    Args:
+        modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
+        return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy).
+
+    Raises:
+        TypeError: If any element of ``modules`` is not a :py:mod:`TorchMD_Net` instance.
+    """
+
+    def __init__(self, modules: List[nn.Module], return_std: bool = False):
+        for module in modules:
+            # Raise explicitly instead of asserting: asserts are stripped under ``python -O``.
+            if not isinstance(module, TorchMD_Net):
+                raise TypeError(
+                    f"Ensemble expects TorchMD_Net models, got {type(module).__name__}"
+                )
+        super().__init__(modules)
+        self.return_std = return_std
+
+    def forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        """Average predictions over all models in the ensemble.
+
+        The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
+
+        Args:
+            *args: Positional arguments to forward to the models.
+            **kwargs: Keyword arguments to forward to the models.
+
+        Returns:
+            Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy). An ensemble holding a single model reports a standard deviation of zero.
+        """
+        results = [model(*args, **kwargs) for model in self]
+        y = torch.stack([res[0] for res in results])
+        neg_dy = torch.stack([res[1] for res in results])
+        y_mean = torch.mean(y, dim=0)
+        neg_dy_mean = torch.mean(neg_dy, dim=0)
+        if not self.return_std:
+            # The spread is only computed on request.
+            return y_mean, neg_dy_mean
+
+        if len(self) > 1:
+            y_std = torch.std(y, dim=0)
+            neg_dy_std = torch.std(neg_dy, dim=0)
+        else:
+            # torch.std applies Bessel's correction by default, which yields NaN for a single sample; a lone model has no spread, so report zeros.
+            y_std = torch.zeros_like(y_mean)
+            neg_dy_std = torch.zeros_like(neg_dy_mean)
+        return y_mean, neg_dy_mean, y_std, neg_dy_std