Skip to content

Commit

Permalink
Merge pull request #144 from openclimatefix/jacob/persistence
Browse files Browse the repository at this point in the history
Add Persistence Baseline
  • Loading branch information
jacobbieker authored Feb 23, 2024
2 parents 5ce274a + 5d97d1f commit cdb5fcc
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,23 @@ def validation_step(self, batch: dict, batch_idx):
y_hat = self(batch)
# Sensor seems to be in batch, station, time order
y = batch[self._target_key][:, -self.forecast_len :, 0]

persistence = batch[self._target_key][:, -self.forecast_len - 1, 0]
# Expand persistence to be the same shape as y
persistence = persistence.unsqueeze(1).expand(-1, self.forecast_len)
losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))

logged_losses = {f"{k}/val": v for k, v in losses.items()}
logged_losses["MAE_persistence/val"] = F.l1_loss(persistence, y)
logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y)
# Log for each timestep the persistence loss
for i in range(self.forecast_len):
logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss(
persistence[:, i], y[:, i]
)
logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss(
persistence[:, i], y[:, i]
)
# Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2}
# for each step in the forecast horizon
# This is needed for the custom plot
Expand Down

0 comments on commit cdb5fcc

Please sign in to comment.