Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 10, 2024
1 parent 279ca72 commit 41dd431
Showing 1 changed file with 32 additions and 36 deletions.
68 changes: 32 additions & 36 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,24 +384,24 @@ def _calculate_common_losses(self, y, y_hat):
)

return losses

def _step_mae_and_mse(self, y, y_hat, dict_key_root):
"""Calculate the MSE and MAE at each forecast step"""
losses = {}

mse_each_step = torch.mean((y_hat - y) ** 2, dim=0)
mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)

losses.update({f"MSE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mse_each_step)})
losses.update({f"MAE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mae_each_step)})

return losses

def _calculate_val_losses(self, y, y_hat):
"""Calculate additional validation losses"""

losses = {}

if self.use_quantile_regression:
# Add fraction below each quantile for calibration
for i, quantile in enumerate(self.output_quantiles):
Expand All @@ -412,15 +412,15 @@ def _calculate_val_losses(self, y, y_hat):

# Take median value for remaining metric calculations
y_hat = self._quantiles_to_prediction(y_hat)
# Log the loss at each time horizon

# Log the loss at each time horizon
losses.update(self._step_mae_and_mse(y, y_hat, dict_key_root="horizon"))

# Log the persistence losses
y_persist = y[:, -1].unsqueeze(1).expand(-1, self.forecast_len)
losses["MAE_persistence/val"] = F.l1_loss(y_persist, y)
losses["MSE_persistence/val"] = F.mse_loss(y_persist, y)

# Log persistence loss at each time horizon
losses.update(self._step_mae_and_mse(y, y_persist, dict_key_root="persistence"))
return losses
Expand Down Expand Up @@ -488,7 +488,7 @@ def training_step(self, batch, batch_idx):
else:
opt_target = losses["MAE/train"]
return opt_target

def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, plot_suffix):
"""Log forecast plot to wandb"""
fig = plot_batch_forecasts(
Expand All @@ -497,27 +497,27 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p
quantiles=self.output_quantiles,
key_to_plot=self._target_key_name,
)

plot_name = f"val_forecast_samples/batch_idx_{accum_batch_num}_{plot_suffix}"

self.logger.experiment.log({plot_name: wandb.Image(fig)})
plt.close(fig)

def validation_step(self, batch: dict, batch_idx):
"""Run validation step"""
y_hat = self(batch)
# Sensor seems to be in batch, station, time order
y = batch[self._target_key][:, -self.forecast_len :, 0]

# Expand persistence to be the same shape as y
losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))

# Store these to make horizon accuracy plot
self._horizon_maes.append(
{i:losses[f"MAE_horizon/step_{i:03}"] for i in range(self.forecast_len)}
{i: losses[f"MAE_horizon/step_{i:03}"] for i in range(self.forecast_len)}
)

logged_losses = {f"{k}/val": v for k, v in losses.items()}

self.log_dict(
Expand All @@ -537,46 +537,42 @@ def validation_step(self, batch: dict, batch_idx):

self._val_y_hats.append(y_hat)
self._val_batches.append(batch)

# if batch has accumulated
if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
y_hat = self._val_y_hats.flush()
batch = self._val_batches.flush()

self._log_forecast_plot(
batch,
y_hat,
accum_batch_num,
timesteps_to_plot=None,
batch,
y_hat,
accum_batch_num,
timesteps_to_plot=None,
plot_suffix="all",
)

if self.time_step_intervals_to_plot is not None:
for interval in self.time_step_intervals_to_plot:

self._log_forecast_plot(
batch,
y_hat,
accum_batch_num,
timesteps_to_plot=interval,
plot_suffix=f"timestep_{interval}"
batch,
y_hat,
accum_batch_num,
timesteps_to_plot=interval,
plot_suffix=f"timestep_{interval}",
)

del self._val_y_hats
del self._val_batches



return logged_losses

def on_validation_epoch_end(self):
"""Run on epoch end"""

horizon_maes_dict = self._horizon_maes.flush()

# Create the horizon accuracy curve
if isinstance(self.logger, pl.loggers.WandbLogger):

per_step_losses = [[i, horizon_maes_dict[i]] for i in range(self.forecast_len)]

table = wandb.Table(data=per_step_losses, columns=["horizon_step", "MAE"])
Expand All @@ -587,7 +583,7 @@ def on_validation_epoch_end(self):
)
},
)

def test_step(self, batch, batch_idx):
"""Run test step"""
y_hat = self(batch)
Expand Down

0 comments on commit 41dd431

Please sign in to comment.