diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 45cbd247..53126792 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -575,7 +575,13 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): ) if self.trainer.is_global_zero and not self.trainer.sanity_checking: + + current_epoch = self.trainer.current_epoch + for key, figure in log_dict.items(): + if not isinstance(self.logger, pl.loggers.WandbLogger): + key = f"{key}-{current_epoch}" + self.logger.log_image(key=key, images=[figure]) plt.close("all") # Close all figs