diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index cc0d27af..dd4dbb78 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -384,24 +384,24 @@ def _calculate_common_losses(self, y, y_hat): ) return losses - + def _step_mae_and_mse(self, y, y_hat, dict_key_root): """Calculate the MSE and MAE at each forecast step""" losses = {} - + mse_each_step = torch.mean((y_hat - y) ** 2, dim=0) mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0) - + losses.update({f"MSE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mse_each_step)}) losses.update({f"MAE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mae_each_step)}) - + return losses def _calculate_val_losses(self, y, y_hat): """Calculate additional validation losses""" losses = {} - + if self.use_quantile_regression: # Add fraction below each quantile for calibration for i, quantile in enumerate(self.output_quantiles): @@ -412,15 +412,15 @@ def _calculate_val_losses(self, y, y_hat): # Take median value for remaining metric calculations y_hat = self._quantiles_to_prediction(y_hat) - - # Log the loss at each time horizon + + # Log the loss at each time horizon losses.update(self._step_mae_and_mse(y, y_hat, dict_key_root="horizon")) - + # Log the persistance losses y_persist = y[:, -1].unsqueeze(1).expand(-1, self.forecast_len) losses["MAE_persistence/val"] = F.l1_loss(y_persist, y) losses["MSE_persistence/val"] = F.mse_loss(y_persist, y) - + # Log persistance loss at each time horizon losses.update(self._step_mae_and_mse(y, y_persist, dict_key_root="persistence")) return losses @@ -488,7 +488,7 @@ def training_step(self, batch, batch_idx): else: opt_target = losses["MAE/train"] return opt_target - + def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, plot_suffix): """Log forecast plot to wandb""" fig = plot_batch_forecasts( @@ -497,27 +497,27 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p quantiles=self.output_quantiles, key_to_plot=self._target_key_name, ) - + plot_name = f"val_forecast_samples/batch_idx_{accum_batch_num}_{plot_suffix}" - + self.logger.experiment.log({plot_name: wandb.Image(fig)}) plt.close(fig) - + def validation_step(self, batch: dict, batch_idx): """Run validation step""" y_hat = self(batch) # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] - + # Expand persistence to be the same shape as y losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) - + # Store these to make horizon accuracy plot self._horizon_maes.append( - {i:losses[f"MAE_horizon/step_{i:03}"] for i in range(self.forecast_len)} + {i: losses[f"MAE_horizon/step_{i:03}"] for i in range(self.forecast_len)} ) - + logged_losses = {f"{k}/val": v for k, v in losses.items()} self.log_dict( @@ -537,46 +537,42 @@ def validation_step(self, batch: dict, batch_idx): self._val_y_hats.append(y_hat) self._val_batches.append(batch) - + # if batch has accumulated if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: y_hat = self._val_y_hats.flush() batch = self._val_batches.flush() - + self._log_forecast_plot( - batch, - y_hat, - accum_batch_num, - timesteps_to_plot=None, + batch, + y_hat, + accum_batch_num, + timesteps_to_plot=None, plot_suffix="all", ) if self.time_step_intervals_to_plot is not None: for interval in self.time_step_intervals_to_plot: - self._log_forecast_plot( - batch, - y_hat, - accum_batch_num, - timesteps_to_plot=interval, - plot_suffix=f"timestep_{interval}" + batch, + y_hat, + accum_batch_num, + timesteps_to_plot=interval, + plot_suffix=f"timestep_{interval}", ) del self._val_y_hats del self._val_batches - - return logged_losses - + def on_validation_epoch_end(self): """Run on epoch end""" - + horizon_maes_dict = self._horizon_maes.flush() - + # Create the horizon accuracy curve if isinstance(self.logger, pl.loggers.WandbLogger): - per_step_losses = [[i, horizon_maes_dict[i]] for i in range(self.forecast_len)] table = wandb.Table(data=per_step_losses, columns=["horizon_step", "MAE"]) @@ -587,7 +583,7 @@ def on_validation_epoch_end(self): ) }, ) - + def test_step(self, batch, batch_idx): """Run test step""" y_hat = self(batch)