From 5d97d1f43ae64dd556836ada4d9e172ec5b920bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Feb 2024 11:00:10 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 07711a37..f903c05d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -482,8 +482,12 @@ def validation_step(self, batch: dict, batch_idx): logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y) # Log for each timestep the persistence loss for i in range(self.forecast_len): - logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss(persistence[:, i], y[:, i]) - logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss(persistence[:, i], y[:, i]) + logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss( + persistence[:, i], y[:, i] + ) + logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss( + persistence[:, i], y[:, i] + ) # Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2} # for each step in the forecast horizon # This is needed for the custom plot