diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index db56946f..f903c05d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -471,11 +471,23 @@ def validation_step(self, batch: dict, batch_idx): y_hat = self(batch) # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] - + persistence = batch[self._target_key][:, -self.forecast_len - 1, 0] + # Expand persistence to be the same shape as y + persistence = persistence.unsqueeze(1).expand(-1, self.forecast_len) losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) logged_losses = {f"{k}/val": v for k, v in losses.items()} + logged_losses["MAE_persistence/val"] = F.l1_loss(persistence, y) + logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y) + # Log for each timestep the persistence loss + for i in range(self.forecast_len): + logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss( + persistence[:, i], y[:, i] + ) + logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss( + persistence[:, i], y[:, i] + ) # Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2} # for each step in the forecast horizon # This is needed for the custom plot