From 523675952c3e3e46512d14f25e38b379f3518d7e Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Fri, 23 Feb 2024 10:59:13 +0000 Subject: [PATCH] Add persistence calculation --- pvnet/models/base_model.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index db56946f..07711a37 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -471,11 +471,19 @@ def validation_step(self, batch: dict, batch_idx): y_hat = self(batch) # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] - + persistence = batch[self._target_key][:, -self.forecast_len - 1, 0] + # Expand persistence to be the same shape as y + persistence = persistence.unsqueeze(1).expand(-1, self.forecast_len) losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) logged_losses = {f"{k}/val": v for k, v in losses.items()} + logged_losses["MAE_persistence/val"] = F.l1_loss(persistence, y) + logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y) + # Log for each timestep the persistence loss + for i in range(self.forecast_len): + logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss(persistence[:, i], y[:, i]) + logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss(persistence[:, i], y[:, i]) # Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2} # for each step in the forecast horizon # This is needed for the custom plot