Skip to content

Commit

Permalink
Merge pull request #173 from openclimatefix/horizon_loss_fix
Browse files Browse the repository at this point in the history
Fix horizon accuracy plots
  • Loading branch information
jacobbieker authored Apr 11, 2024
2 parents 3206c75 + 41dd431 commit a23ce6c
Showing 1 changed file with 71 additions and 61 deletions.
132 changes: 71 additions & 61 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def __init__(
self._accumulated_metrics = MetricAccumulator()
self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key_name)
self._accumulated_y_hat = PredAccumulator()
self._horizon_maes = MetricAccumulator()

# Store whether the model should use quantile regression or simply predict the mean
self.use_quantile_regression = self.output_quantiles is not None
Expand Down Expand Up @@ -390,6 +391,18 @@ def _calculate_common_losses(self, y, y_hat):

return losses

def _step_mae_and_mse(self, y, y_hat, dict_key_root):
"""Calculate the MSE and MAE at each forecast step"""
losses = {}

mse_each_step = torch.mean((y_hat - y) ** 2, dim=0)
mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)

losses.update({f"MSE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mse_each_step)})
losses.update({f"MAE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mae_each_step)})

return losses

def _calculate_val_losses(self, y, y_hat):
"""Calculate additional validation losses"""

Expand All @@ -405,12 +418,17 @@ def _calculate_val_losses(self, y, y_hat):

# Take median value for remaining metric calculations
y_hat = self._quantiles_to_prediction(y_hat)
mse_each_step = torch.mean((y_hat - y) ** 2, dim=0)
mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)

losses.update({f"MSE_horizon/step_{i:03}": m for i, m in enumerate(mse_each_step)})
losses.update({f"MAE_horizon/step_{i:03}": m for i, m in enumerate(mae_each_step)})
# Log the loss at each time horizon
losses.update(self._step_mae_and_mse(y, y_hat, dict_key_root="horizon"))

# Log the persistence losses
y_persist = y[:, -1].unsqueeze(1).expand(-1, self.forecast_len)
losses["MAE_persistence/val"] = F.l1_loss(y_persist, y)
losses["MSE_persistence/val"] = F.mse_loss(y_persist, y)

# Log persistence loss at each time horizon
losses.update(self._step_mae_and_mse(y, y_persist, dict_key_root="persistence"))
return losses

def _calculate_test_losses(self, y, y_hat):
Expand Down Expand Up @@ -477,51 +495,36 @@ def training_step(self, batch, batch_idx):
opt_target = losses["MAE/train"]
return opt_target

def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, plot_suffix):
    """Log forecast plot to wandb.

    Args:
        batch: Batch passed to `plot_batch_forecasts`.
        y_hat: Predictions passed to `plot_batch_forecasts`.
        accum_batch_num: Index embedded in the logged plot name.
        timesteps_to_plot: Forwarded to `plot_batch_forecasts` when not None. If None, the
            argument is omitted so the plotting function's own default applies.
        plot_suffix: Trailing part of the logged plot name, e.g. "all".
    """
    plot_kwargs = {
        "quantiles": self.output_quantiles,
        "key_to_plot": self._target_key_name,
    }
    # Forward the requested timesteps; previously this parameter was silently dropped, so
    # every per-interval plot was identical to the "all" plot.
    if timesteps_to_plot is not None:
        plot_kwargs["timesteps_to_plot"] = timesteps_to_plot

    fig = plot_batch_forecasts(batch, y_hat, **plot_kwargs)

    plot_name = f"val_forecast_samples/batch_idx_{accum_batch_num}_{plot_suffix}"

    self.logger.experiment.log({plot_name: wandb.Image(fig)})
    # Close the figure once it has been logged so figures do not accumulate
    plt.close(fig)

def validation_step(self, batch: dict, batch_idx):
"""Run validation step"""
y_hat = self(batch)
# Sensor seems to be in batch, station, time order
y = batch[self._target_key][:, -self.forecast_len :, 0]
persistence = batch[self._target_key][:, -self.forecast_len - 1, 0]

# Expand persistence to be the same shape as y
persistence = persistence.unsqueeze(1).expand(-1, self.forecast_len)
losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))

# Store these to make horizon accuracy plot
self._horizon_maes.append(
{i: losses[f"MAE_horizon/step_{i:03}"] for i in range(self.forecast_len)}
)

logged_losses = {f"{k}/val": v for k, v in losses.items()}
logged_losses["MAE_persistence/val"] = F.l1_loss(persistence, y)
logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y)
# Log for each timestep the persistence loss
for i in range(self.forecast_len):
logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss(
persistence[:, i], y[:, i]
)
logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss(
persistence[:, i], y[:, i]
)
# Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2}
# for each step in the forecast horizon
# This is needed for the custom plot
# And needs to be in order of step
x_values = [
int(k.split("_")[-1].split("/")[0])
for k in logged_losses.keys()
if "MAE_horizon/step" in k
]
y_values = []
for x in x_values:
y_values.append(logged_losses[f"MAE_horizon/step_{x:03}/val"])
per_step_losses = [[x, y] for (x, y) in zip(x_values, y_values)]
# Check if WandBLogger is being used
if isinstance(self.logger, pl.loggers.WandbLogger):
table = wandb.Table(data=per_step_losses, columns=["timestep", "MAE"])
wandb.log(
{
"mae_vs_timestep": wandb.plot.line(
table, "timestep", "MAE", title="MAE vs Timestep"
)
}
)

self.log_dict(
logged_losses,
Expand All @@ -531,55 +534,62 @@ def validation_step(self, batch: dict, batch_idx):

accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

if accum_batch_num in [0, 1]:
# Make plots only if using wandb logger
if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
# Store these temporarily under self
if not hasattr(self, "_val_y_hats"):
self._val_y_hats = PredAccumulator()
self._val_batches = BatchAccumulator(key_to_keep=self._target_key_name)

self._val_y_hats.append(y_hat)
self._val_batches.append(batch)
# if batch had accumulated

# if batch has accumulated
if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
y_hat = self._val_y_hats.flush()
batch = self._val_batches.flush()

fig = plot_batch_forecasts(
self._log_forecast_plot(
batch,
y_hat,
quantiles=self.output_quantiles,
key_to_plot=self._target_key_name,
accum_batch_num,
timesteps_to_plot=None,
plot_suffix="all",
)

self.logger.experiment.log(
{
f"val_forecast_samples/batch_idx_{accum_batch_num}_all": wandb.Image(fig),
}
)
plt.close(fig)

if self.time_step_intervals_to_plot is not None:
for interval in self.time_step_intervals_to_plot:
fig = plot_batch_forecasts(
self._log_forecast_plot(
batch,
y_hat,
quantiles=self.output_quantiles,
key_to_plot=self._target_key_name,
accum_batch_num,
timesteps_to_plot=interval,
plot_suffix=f"timestep_{interval}",
)
self.logger.experiment.log(
{
f"val_forecast_samples/batch_idx_{accum_batch_num}_"
f"timestep_{interval}": wandb.Image(fig),
}
)
plt.close(fig)

del self._val_y_hats
del self._val_batches

return logged_losses

def on_validation_epoch_end(self):
    """Flush the accumulated per-step MAEs and log the horizon loss curve.

    The accumulator is flushed on every call. The curve itself is only built and
    logged when the active logger is a `WandbLogger`.
    """
    # Flush unconditionally, whichever logger is in use
    mae_by_step = self._horizon_maes.flush()

    # The horizon accuracy curve is a wandb-specific plot
    if not isinstance(self.logger, pl.loggers.WandbLogger):
        return

    rows = []
    for step in range(self.forecast_len):
        rows.append([step, mae_by_step[step]])

    curve_table = wandb.Table(data=rows, columns=["horizon_step", "MAE"])
    curve = wandb.plot.line(curve_table, "horizon_step", "MAE", title="Horizon loss curve")
    wandb.log({"horizon_loss_curve": curve})

def test_step(self, batch, batch_idx):
"""Run test step"""
y_hat = self(batch)
Expand Down

0 comments on commit a23ce6c

Please sign in to comment.