From 1b67010080cd4026253d5d2b81ed9e5279155f7f Mon Sep 17 00:00:00 2001
From: Stefan Doerr
Date: Mon, 18 Mar 2024 14:12:46 +0200
Subject: [PATCH 01/11] Added support for ensemble models

---
 torchmdnet/models/model.py | 39 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index a2a80f90..ecc9c44d 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -151,6 +151,10 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
     Returns:
         nn.Module: An instance of the TorchMD_Net model.
     """
+    if isinstance(filepath, (list, tuple)):
+        return Ensemble(
+            [load_model(f, args=args, device=device, **kwargs) for f in filepath]
+        )
     ckpt = torch.load(filepath, map_location="cpu")
 
     if args is None:
@@ -426,3 +430,38 @@ def forward(
         # Returning an empty tensor allows to decorate this method as always returning two tensors.
         # This is required to overcome a TorchScript limitation, xref https://github.com/openmm/openmm-torch/issues/135
         return y, torch.empty(0)
+
+
+class Ensemble(torch.nn.ModuleList):
+    """Average predictions over an ensemble of TorchMD-Net models"""
+
+    def __init__(self, modules):
+        super().__init__(modules)
+
+    def forward(
+        self,
+        z: Tensor,
+        pos: Tensor,
+        batch: Optional[Tensor] = None,
+        box: Optional[Tensor] = None,
+        q: Optional[Tensor] = None,
+        s: Optional[Tensor] = None,
+        extra_args: Optional[Dict[str, Tensor]] = None,
+    ):
+        y = []
+        neg_dy = []
+        for model in self:
+            res = model(
+                z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args
+            )
+            y.append(res[0])
+            neg_dy.append(res[1])
+
+        y = torch.stack(y)
+        print(y, neg_dy)
+        neg_dy = torch.stack(neg_dy)
+        y_mean = torch.mean(y, axis=0)
+        neg_dy_mean = torch.mean(neg_dy, axis=0)
+        y_std = torch.std(y, axis=0)
+        neg_dy_std = torch.std(neg_dy, axis=0)
+        return y_mean, neg_dy_mean, y_std, neg_dy_std

From a5c0a3a842f0a27ef133c7051718523266de633d Mon Sep 17 00:00:00 2001
From: Stefan Doerr
Date: Mon, 18 Mar 2024 14:15:18 +0200
Subject: [PATCH 02/11] remove debug print

---
 torchmdnet/models/model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index ecc9c44d..c30d894f 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -458,7 +458,6 @@ def forward(
             neg_dy.append(res[1])
 
         y = torch.stack(y)
-        print(y, neg_dy)
         neg_dy = torch.stack(neg_dy)
         y_mean = torch.mean(y, axis=0)
         neg_dy_mean = torch.mean(neg_dy, axis=0)

From 2b99e77e7e4003658d7efc9a23ba799263a766b2 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 18 Mar 2024 13:40:04 +0100
Subject: [PATCH 03/11] Update load_model docstring

---
 torchmdnet/models/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index c30d894f..6a0ecf7f 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -142,8 +142,9 @@ def create_model(args, prior_model=None, mean=None, std=None):
 def load_model(filepath, args=None, device="cpu", **kwargs):
     """Load a model from a checkpoint file.
 
+    If a list of paths is given, an :py:mod:`Ensemble` model is returned.
     Args:
-        filepath (str): Path to the checkpoint file.
+        filepath (str or list): Path to the checkpoint file or a list of paths.
         args (dict, optional): Arguments for the model. Defaults to None.
         device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
         **kwargs: Extra keyword arguments for the model.
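The three patches above make :code:`load_model` ensemble-aware: passing a list (or tuple) of checkpoint paths returns an :code:`Ensemble` that averages the member predictions. A minimal usage sketch of the API as it stands at this point in the series — the checkpoint paths are placeholders, and at this stage the forward pass always returns both the means and the standard deviations:

.. code:: python

   import torch
   from torchmdnet.models.model import load_model

   # Placeholder checkpoint paths; any trained TorchMD-Net checkpoints would do.
   ckpts = ["model0.ckpt", "model1.ckpt", "model2.ckpt"]
   ensemble = load_model(ckpts, derivative=True)  # a list of paths triggers the Ensemble branch

   # An arbitrary set of inputs, mirroring the inference example in the docs.
   n_atoms = 10
   zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
   z = zs[torch.randint(0, len(zs), (n_atoms,))]
   pos = torch.randn(n_atoms, 3)
   batch = torch.zeros(n_atoms, dtype=torch.long)

   y_mean, neg_dy_mean, y_std, neg_dy_std = ensemble(z, pos, batch=batch)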
From d071cd38324a96a87801098e6a179509efaf2db5 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 18 Mar 2024 13:43:31 +0100
Subject: [PATCH 04/11] Update Ensemble docstring

---
 torchmdnet/models/model.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index 6a0ecf7f..618a024a 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -434,9 +434,17 @@ def forward(
 
 
 class Ensemble(torch.nn.ModuleList):
-    """Average predictions over an ensemble of TorchMD-Net models"""
+    """Average predictions over an ensemble of TorchMD-Net models.
 
-    def __init__(self, modules):
+    This module behaves like a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with.
+
+    Args:
+        modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
+    """
+
+    def __init__(self, modules: List[nn.Module]):
+        for module in modules:
+            assert isinstance(module, TorchMD_Net)
         super().__init__(modules)
 
     def forward(

From 94f94d90208b43ca083ce7978ca6e7ddadd9085b Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 18 Mar 2024 14:07:14 +0100
Subject: [PATCH 05/11] Add test for ensemble

---
 tests/test_model.py | 99 ++++++++++++++++++++++++++++++---------------
 1 file changed, 67 insertions(+), 32 deletions(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index b792595b..31a09a21 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -9,7 +9,7 @@
 import torch
 import lightning as pl
 from torchmdnet import models
-from torchmdnet.models.model import create_model
+from torchmdnet.models.model import create_model, load_model
 from torchmdnet.models import output_modules
 from torchmdnet.models.utils import dtype_mapping
@@ -23,7 +23,9 @@ def test_forward(model_name, use_batch, explicit_q_s, precision):
     z, pos, batch = create_example_batch()
     pos = pos.to(dtype=dtype_mapping[precision])
-    model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
+    model = create_model(
+        load_example_args(model_name, prior_model=None, precision=precision)
+    )
     batch = batch if use_batch else None
     if explicit_q_s:
         model(z, pos, batch=batch, q=None, s=None)
@@ -33,10 +35,12 @@
 @mark.parametrize("model_name", models.__all_models__)
 @mark.parametrize("output_model", output_modules.__all__)
-@mark.parametrize("precision", [32,64])
+@mark.parametrize("precision", [32, 64])
 def test_forward_output_modules(model_name, output_model, precision):
     z, pos, batch = create_example_batch()
-    args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
+    args = load_example_args(
+        model_name, remove_prior=True, output_model=output_model, precision=precision
+    )
     model = create_model(args)
     model(z, pos, batch=batch)
@@ -61,18 +65,25 @@ def test_torchscript(model_name, device):
         grad_outputs=grad_outputs,
     )[0]
 
+
 def test_torchscript_output_modification():
-    model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
+    model = create_model(
+        load_example_args("tensornet", remove_prior=True, derivative=True)
+    )
+
     class MyModel(torch.nn.Module):
         def __init__(self):
             super(MyModel, self).__init__()
             self.model = model
+
         def forward(self, z, pos, batch):
             y, neg_dy = self.model(z, pos, batch=batch)
             # A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
-            return y, 2*neg_dy
+            return y, 2 * neg_dy
+
     torch.jit.script(MyModel())
+
+
 @mark.parametrize("model_name", models.__all_models__)
 @mark.parametrize("device", ["cpu", "cuda"])
 def test_torchscript_dynamic_shapes(model_name, device):
@@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device):
     model = torch.jit.script(
         create_model(load_example_args(model_name, remove_prior=True, derivative=True))
     ).to(device=device)
-    #Repeat the input to make it dynamic
+    # Repeat the input to make it dynamic
     for rep in range(0, 5):
         print(rep)
-        zi = z.repeat_interleave(rep+1, dim=0).to(device=device)
-        posi = pos.repeat_interleave(rep+1, dim=0).to(device=device)
+        zi = z.repeat_interleave(rep + 1, dim=0).to(device=device)
+        posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device)
         batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
         y, neg_dy = model(zi, posi, batch=batchi)
         grad_outputs = [torch.ones_like(neg_dy)]
@@ -98,32 +109,35 @@ def test_torchscript_dynamic_shapes(model_name, device):
             grad_outputs=grad_outputs,
         )[0]
 
-#Currently only tensornet is CUDA graph compatible
+
+# Currently only tensornet is CUDA graph compatible
 @mark.parametrize("model_name", ["tensornet"])
 def test_cuda_graph_compatible(model_name):
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
     z, pos, batch = create_example_batch()
-    args = {"model": model_name,
-            "embedding_dimension": 128,
-            "num_layers": 2,
-            "num_rbf": 32,
-            "rbf_type": "expnorm",
-            "trainable_rbf": False,
-            "activation": "silu",
-            "cutoff_lower": 0.0,
-            "cutoff_upper": 5.0,
-            "max_z": 100,
-            "max_num_neighbors": 128,
-            "equivariance_invariance_group": "O(3)",
-            "prior_model": None,
-            "atom_filter": -1,
-            "derivative": True,
-            "check_error": False,
-            "static_shapes": True,
-            "output_model": "Scalar",
-            "reduce_op": "sum",
-            "precision": 32 }
+    args = {
+        "model": model_name,
+        "embedding_dimension": 128,
+        "num_layers": 2,
+        "num_rbf": 32,
+        "rbf_type": "expnorm",
+        "trainable_rbf": False,
+        "activation": "silu",
+        "cutoff_lower": 0.0,
+        "cutoff_upper": 5.0,
+        "max_z": 100,
+        "max_num_neighbors": 128,
+        "equivariance_invariance_group": "O(3)",
+        "prior_model": None,
+        "atom_filter": -1,
+        "derivative": True,
+        "check_error": False,
+        "static_shapes": True,
+        "output_model": "Scalar",
+        "reduce_op": "sum",
+        "precision": 32,
+    }
     model = create_model(args).to(device="cuda")
     model.eval()
     z = z.to("cuda")
@@ -142,6 +156,7 @@ def test_cuda_graph_compatible(model_name):
     assert torch.allclose(y, y2)
     assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5)
 
+
 @mark.parametrize("model_name", models.__all_models__)
 def test_seed(model_name):
     args = load_example_args(model_name, remove_prior=True)
@@ -153,6 +168,7 @@ def test_seed(model_name):
     for p1, p2 in zip(m1.parameters(), m2.parameters()):
         assert (p1 == p2).all(), "Parameters don't match although using the same seed."
 
+
 @mark.parametrize("model_name", models.__all_models__)
 @mark.parametrize(
     "output_model",
@@ -199,7 +215,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
         ), f"Set new reference outputs for {model_name} with output model {output_model}."
 
     # compare actual ouput with reference
-    torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5)
+    torch.testing.assert_close(
+        pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5
+    )
     if derivative:
         torch.testing.assert_close(
             deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5
@@ -218,7 +236,7 @@ def test_gradients(model_name):
         remove_prior=True,
         output_model=output_model,
         derivative=derivative,
-        precision=precision
+        precision=precision,
     )
     model = create_model(args)
     z, pos, batch = create_example_batch(n_atoms=5)
@@ -227,3 +245,20 @@ def test_gradients(model_name):
     torch.autograd.gradcheck(
         model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
     )
+
+
+def test_ensemble():
+    ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
+    model = load_model(ckpts[0])
+    ensemble_model = load_model(ckpts)
+    z, pos, batch = create_example_batch(n_atoms=5)
+
+    pred, deriv = model(z, pos, batch)
+    pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)
+
+    torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
+    torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
+    assert y_std.shape == pred.shape
+    assert neg_dy_std.shape == deriv.shape
+    assert (y_std == 0).all()
+    assert (neg_dy_std == 0).all()

From 5ba3db96e0b8670808f3ddc5f75ff324d192660b Mon Sep 17 00:00:00 2001
From: Stefan Doerr
Date: Tue, 19 Mar 2024 10:13:38 +0200
Subject: [PATCH 06/11] returning the standard deviation is now optional

---
 torchmdnet/models/model.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index 618a024a..ad94ad01 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -139,7 +139,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
     return model
 
 
-def load_model(filepath, args=None, device="cpu", **kwargs):
+def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
     """Load a model from a checkpoint file.
 
     If a list of paths is given, an :py:mod:`Ensemble` model is returned.
@@ -147,6 +147,7 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
         filepath (str or list): Path to the checkpoint file or a list of paths.
         args (dict, optional): Arguments for the model. Defaults to None.
         device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
+        return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
         **kwargs: Extra keyword arguments for the model.
 
     Returns:
@@ -154,7 +155,8 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
     """
     if isinstance(filepath, (list, tuple)):
         return Ensemble(
-            [load_model(f, args=args, device=device, **kwargs) for f in filepath]
+            [load_model(f, args=args, device=device, **kwargs) for f in filepath],
+            return_std=return_std,
         )
     ckpt = torch.load(filepath, map_location="cpu")
@@ -440,12 +442,14 @@ class Ensemble(torch.nn.ModuleList):
 
     Args:
         modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
+        return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy).
""" - def __init__(self, modules: List[nn.Module]): + def __init__(self, modules: List[nn.Module], return_std: bool = False): for module in modules: assert isinstance(module, TorchMD_Net) super().__init__(modules) + self.return_std = return_std def forward( self, @@ -472,4 +476,8 @@ def forward( neg_dy_mean = torch.mean(neg_dy, axis=0) y_std = torch.std(y, axis=0) neg_dy_std = torch.std(neg_dy, axis=0) - return y_mean, neg_dy_mean, y_std, neg_dy_std + + if self.return_std: + return y_mean, neg_dy_mean, y_std, neg_dy_std + else: + return y_mean, neg_dy_mean From 417955d61617d2338a2c8f23f742162d862f670d Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Tue, 19 Mar 2024 10:42:30 +0200 Subject: [PATCH 07/11] fix test --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 31a09a21..88ac9a1c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -250,7 +250,7 @@ def test_gradients(model_name): def test_ensemble(): ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3 model = load_model(ckpts[0]) - ensemble_model = load_model(ckpts) + ensemble_model = load_model(ckpts, return_std=True) z, pos, batch = create_example_batch(n_atoms=5) pred, deriv = model(z, pos, batch) From 5a60b40ae193dc6a1d7a0f4e8f57d727193be455 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 19 Mar 2024 15:06:29 +0100 Subject: [PATCH 08/11] Make Ensemble variadic --- torchmdnet/models/model.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index ad94ad01..0422d6de 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -453,23 +453,18 @@ def __init__(self, modules: List[nn.Module], return_std: bool = False): def forward( self, - z: Tensor, - pos: Tensor, - batch: Optional[Tensor] = None, - box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, - extra_args: Optional[Dict[str, Tensor]] = None, + *args, + **kwargs, ): + """Average predictions over all models in the ensemble. + The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble. + """ y = [] neg_dy = [] for model in self: - res = model( - z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args - ) + res = model(*args, **kwargs) y.append(res[0]) neg_dy.append(res[1]) - y = torch.stack(y) neg_dy = torch.stack(neg_dy) y_mean = torch.mean(y, axis=0) From 013a691da1b213080b6f2921717d4a2e78f2258d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 19 Mar 2024 15:29:07 +0100 Subject: [PATCH 09/11] Update documentation --- docs/source/models.rst | 83 ++++++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 244c1692..5ec760e3 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -83,38 +83,9 @@ This is a minimal example of a custom training loop: optimizer.step() - - -Loading a model for inference ------------------------------ - -Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example. - -.. 
code:: python - - import torch - from torchmdnet.models.model import load_model - checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt" - model = load_model(checkpoint, derivative=True) - # An arbitrary set of inputs for the model - n_atoms = 10 - zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long) - z = zs[torch.randint(0, len(zs), (n_atoms,))] - pos = torch.randn(len(z), 3) - batch = torch.zeros(len(z), dtype=torch.long) - - y, neg_dy = model(z, pos, batch) - -.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference. - -.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case. - -.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case. - - .. _delta-learning: Training on relative energies ------------------------------ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ It might be useful to train the model on relative energies but then make the model produce total energies when running inference. TorchMD-Net supports delta training via the :code:`remove_ref_energy` option. Passing this option when training (either via the :ref:`configuration-file` or using the :ref:`torchmd-train` command line interface) will subtract the reference energy from each atom in a sample before passing it to the model. @@ -126,7 +97,7 @@ If :code:`remove_ref_energy` is turned on, the reference energy is stored in the .. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`. Example -~~~~~~~ +******** First we train a model with the :code:`remove_ref_energy` option turned on: @@ -151,6 +122,56 @@ Then we load the model for inference: batch = torch.zeros(len(z), dtype=torch.long) y, neg_dy = model(z, pos, batch) + + +Loading a model for inference +----------------------------- + +Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example. + +.. code:: python + + import torch + from torchmdnet.models.model import load_model + checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt" + model = load_model(checkpoint, derivative=True) + # An arbitrary set of inputs for the model + n_atoms = 10 + zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long) + z = zs[torch.randint(0, len(zs), (n_atoms,))] + pos = torch.randn(len(z), 3) + batch = torch.zeros(len(z), dtype=torch.long) + + y, neg_dy = model(z, pos, batch) + +.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference. + +.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case. + +.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. 
Check the documentation of the class you are using to see if this is the case. + + +Model Ensembles +--------------- +It is possible to create an ensemble of models by loading multiple checkpoints and averaging their predictions. The following example shows how to do this: + +.. code:: python + + import torch + from torchmdnet.models.model import load_model + checkpoints = ["/path/to/checkpoint/my_checkpoint1.ckpt", "/path/to/checkpoint/my_checkpoint2.ckpt"] + model_ensemble = load_model(checkpoints, return_std=True) + y_ensemble, neg_dy_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch) + + +.. note:: :py:mod:`torchmdnet.models.model.load_model` will return an instance of :py:mod:`torchmdnet.models.model.Ensemble` if a list of checkpoints is passed. The :code:`return_std` option can be used to return the standard deviation of the predictions. + + + +.. autoclass:: torchmdnet.models.model.Ensemble + :noindex: + + From 365edea1e121c4bd7e89aec9fef7819a699bd506 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 19 Mar 2024 15:29:14 +0100 Subject: [PATCH 10/11] Update docstrings --- torchmdnet/models/model.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 0422d6de..d7ec6d7d 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -194,29 +194,32 @@ def create_prior_models(args, dataset=None): 1. A single prior model name and its arguments as a dictionary: - ```python - args = { - "prior_model": "Atomref", - "prior_args": {"max_z": 100} - } - ``` + .. code:: python + + args = { + "prior_model": "Atomref", + "prior_args": {"max_z": 100} + } + + 2. A list of prior model names and their arguments as a list of dictionaries: - ```python + .. code:: python + + args = { + "prior_model": ["Atomref", "D2"], + "prior_args": [{"max_z": 100}, {"max_z": 100}] + } - args = { - "prior_model": ["Atomref", "D2"], - "prior_args": [{"max_z": 100}, {"max_z": 100}] - } - ``` 3. A list of prior model names and their arguments as a dictionary: - ```python - args = { - "prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}] - } - ``` + .. code:: python + + args = { + "prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}] + } + Args: args (dict): Arguments for the model. From f21f3de3b5e12d0781c642579081f3306d28b683 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 19 Mar 2024 15:31:12 +0100 Subject: [PATCH 11/11] Update docstring --- torchmdnet/models/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index d7ec6d7d..c090f90f 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -461,6 +461,12 @@ def forward( ): """Average predictions over all models in the ensemble. The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble. + Args: + *args: Positional arguments to forward to the models. + **kwargs: Keyword arguments to forward to the models. + Returns: + Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The average and standard deviation of the predictions over all models in the ensemble. If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy). + """ y = [] neg_dy = []
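Taken together, the series leaves two ways to build an ensemble. A short sketch of both, assuming a handful of already trained checkpoints (the paths are placeholders):

.. code:: python

   from torchmdnet.models.model import Ensemble, load_model

   ckpts = ["model0.ckpt", "model1.ckpt", "model2.ckpt"]  # placeholder paths

   # 1. Let load_model build the ensemble. With the default return_std=False it
   #    behaves like a single model and returns only the averaged (y, neg_dy) pair.
   ensemble = load_model(ckpts)

   # 2. Or construct it explicitly from loaded TorchMD_Net models and also request
   #    the per-output standard deviations across the members.
   members = [load_model(p) for p in ckpts]
   ensemble_std = Ensemble(members, return_std=True)

   # y, neg_dy = ensemble(z, pos, batch)
   # y, neg_dy, y_std, neg_dy_std = ensemble_std(z, pos, batch)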