Skip to content

Commit

Permalink
Add persistence calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Feb 23, 2024
1 parent 5ce274a commit 5236759
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,19 @@ def validation_step(self, batch: dict, batch_idx):
y_hat = self(batch)
# Sensor seems to be in batch, station, time order
y = batch[self._target_key][:, -self.forecast_len :, 0]

persistence = batch[self._target_key][:, -self.forecast_len - 1, 0]
# Expand persistence to be the same shape as y
persistence = persistence.unsqueeze(1).expand(-1, self.forecast_len)
losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))

logged_losses = {f"{k}/val": v for k, v in losses.items()}
logged_losses["MAE_persistence/val"] = F.l1_loss(persistence, y)
logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y)
# Log for each timestep the persistence loss
for i in range(self.forecast_len):
logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss(persistence[:, i], y[:, i])
logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss(persistence[:, i], y[:, i])
# Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2}
# for each step in the forecast horizon
# This is needed for the custom plot
Expand Down

0 comments on commit 5236759

Please sign in to comment.