diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index d5ea73cf..51a2ade4 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -6,6 +6,7 @@
 import torch
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
+import torch.optim.swa_utils
 from torch.nn.functional import local_response_norm, mse_loss, l1_loss
 from torch import Tensor
 from typing import Optional, Dict, Tuple
@@ -73,9 +74,26 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         else:
             self.model = create_model(self.hparams, prior_model, mean, std)
 
-        # initialize exponential smoothing
-        self.ema = None
-        self._reset_ema_dict()
+        self.ema_model = None
+        if (
+            "ema_parameters_decay" in self.hparams
+            and self.hparams.ema_parameters_decay is not None
+        ):
+            self.ema_model = torch.optim.swa_utils.AveragedModel(
+                self.model,
+                multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(
+                    self.hparams.ema_parameters_decay
+                ),
+            )
+        self.ema_parameters_start = (
+            self.hparams.ema_parameters_start
+            if "ema_parameters_start" in self.hparams
+            else 0
+        )
+
+        # initialize exponential smoothing for the losses
+        self.ema_loss = None
+        self._reset_ema_loss_dict()
 
         # initialize loss collection
         self.losses = None
@@ -177,12 +195,12 @@ def _update_loss_with_ema(self, stage, type, loss_name, loss):
         alpha = getattr(self.hparams, f"ema_alpha_{type}")
         if stage in ["train", "val"] and alpha < 1 and alpha > 0:
             ema = (
-                self.ema[stage][type][loss_name]
-                if loss_name in self.ema[stage][type]
+                self.ema_loss[stage][type][loss_name]
+                if loss_name in self.ema_loss[stage][type]
                 else loss.detach()
             )
             loss = alpha * loss + (1 - alpha) * ema
-            self.ema[stage][type][loss_name] = loss.detach()
+            self.ema_loss[stage][type][loss_name] = loss.detach()
         return loss
 
     def step(self, batch, loss_fn_list, stage):
@@ -250,6 +268,11 @@ def optimizer_step(self, *args, **kwargs):
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.lr
        super().optimizer_step(*args, **kwargs)
+        if (
+            self.trainer.current_epoch >= self.ema_parameters_start
+            and self.ema_model is not None
+        ):
+            self.ema_model.update_parameters(self.model)
         optimizer.zero_grad()
 
     def _get_mean_loss_dict_for_type(self, type):
@@ -304,9 +327,9 @@ def _reset_losses_dict(self):
             for loss_type in ["total", "y", "neg_dy"]:
                 self.losses[stage][loss_type] = defaultdict(list)
 
-    def _reset_ema_dict(self):
-        self.ema = {}
+    def _reset_ema_loss_dict(self):
+        self.ema_loss = {}
         for stage in ["train", "val"]:
-            self.ema[stage] = {}
+            self.ema_loss[stage] = {}
             for loss_type in ["y", "neg_dy"]:
-                self.ema[stage][loss_type] = {}
+                self.ema_loss[stage][loss_type] = {}
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 7f2d8e07..aef1b0ec 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -59,6 +59,8 @@ def get_argparse():
     parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
     parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
     parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')
+    parser.add_argument('--ema-parameters-decay', type=float, default=None, help='Exponential moving average decay for model parameters (defaults to None, meaning disabled). The decay is a value between 0 and 1 that controls how fast the averaged parameters are updated.')
+    parser.add_argument('--ema-parameters-start', type=int, default=0, help='Epoch at which to start averaging the parameters.')
     # dataset specific
     parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
     parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
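For context, the sketch below shows the parameter-EMA mechanism this patch wires into optimizer_step, stripped of the Lightning module. It is a minimal sketch, not the module's code: the toy model, optimizer, and the decay / ema_start_epoch variables are illustrative stand-ins for --ema-parameters-decay and --ema-parameters-start, and it assumes a PyTorch version that provides torch.optim.swa_utils.get_ema_multi_avg_fn (as the patch itself does).

# Minimal, self-contained sketch of the EMA-of-parameters scheme above.
# Assumptions: toy Linear model, AdamW optimizer, and hand-picked decay /
# start-epoch values; PyTorch must provide get_ema_multi_avg_fn.
import torch
import torch.optim.swa_utils

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

decay = 0.999          # analogue of --ema-parameters-decay
ema_start_epoch = 2    # analogue of --ema-parameters-start

# Each update applies: ema_param = decay * ema_param + (1 - decay) * param
ema_model = torch.optim.swa_utils.AveragedModel(
    model,
    multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(decay),
)

for epoch in range(5):
    for _ in range(10):
        x = torch.randn(8, 16)
        loss = (model(x) - torch.randn(8, 1)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Mirrors optimizer_step(): only fold the current weights into the
        # running average once the configured start epoch has been reached.
        if epoch >= ema_start_epoch:
            ema_model.update_parameters(model)

# ema_model holds the smoothed weights; use it for evaluation or checkpointing.
with torch.no_grad():
    smoothed_prediction = ema_model(torch.randn(4, 16))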