From 523675952c3e3e46512d14f25e38b379f3518d7e Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Fri, 23 Feb 2024 10:59:13 +0000 Subject: [PATCH 1/2] Add persistence calculation --- pvnet/models/base_model.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index db56946f..07711a37 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -471,11 +471,19 @@ def validation_step(self, batch: dict, batch_idx): y_hat = self(batch) # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] - + persistence = batch[self._target_key][:, -self.forecast_len - 1, 0] + # Expand persistence to be the same shape as y + persistence = persistence.unsqueeze(1).expand(-1, self.forecast_len) losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) logged_losses = {f"{k}/val": v for k, v in losses.items()} + logged_losses["MAE_persistence/val"] = F.l1_loss(persistence, y) + logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y) + # Log for each timestep the persistence loss + for i in range(self.forecast_len): + logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss(persistence[:, i], y[:, i]) + logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss(persistence[:, i], y[:, i]) # Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2} # for each step in the forecast horizon # This is needed for the custom plot From 5d97d1f43ae64dd556836ada4d9e172ec5b920bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Feb 2024 11:00:10 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 07711a37..f903c05d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -482,8 +482,12 @@ def validation_step(self, batch: dict, batch_idx): logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y) # Log for each timestep the persistence loss for i in range(self.forecast_len): - logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss(persistence[:, i], y[:, i]) - logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss(persistence[:, i], y[:, i]) + logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss( + persistence[:, i], y[:, i] + ) + logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss( + persistence[:, i], y[:, i] + ) # Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2} # for each step in the forecast horizon # This is needed for the custom plot