diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 07711a37..f903c05d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -482,8 +482,12 @@ def validation_step(self, batch: dict, batch_idx): logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y) # Log for each timestep the persistence loss for i in range(self.forecast_len): - logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss(persistence[:, i], y[:, i]) - logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss(persistence[:, i], y[:, i]) + logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss( + persistence[:, i], y[:, i] + ) + logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss( + persistence[:, i], y[:, i] + ) # Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2} # for each step in the forecast horizon # This is needed for the custom plot