From 279ca725a6d0c2f6fd162f18a0771ad690bce2a9 Mon Sep 17 00:00:00 2001
From: James Fulton
Date: Wed, 10 Apr 2024 15:49:37 +0000
Subject: [PATCH] fix horizon loss plot to use all val samples + tidy

---
 pvnet/models/base_model.py | 158 ++++++++++++++++++++-----------------
 1 file changed, 86 insertions(+), 72 deletions(-)

diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py
index 6b8faa36..cc0d27af 100644
--- a/pvnet/models/base_model.py
+++ b/pvnet/models/base_model.py
@@ -300,6 +300,7 @@ def __init__(
         self._accumulated_metrics = MetricAccumulator()
         self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key_name)
         self._accumulated_y_hat = PredAccumulator()
+        self._horizon_maes = MetricAccumulator()

         # Store whether the model should use quantile regression or simply predict the mean
         self.use_quantile_regression = self.output_quantiles is not None
@@ -383,12 +384,24 @@ def _calculate_common_losses(self, y, y_hat):
             )

         return losses
+
+    def _step_mae_and_mse(self, y, y_hat, dict_key_root):
+        """Calculate the MSE and MAE at each forecast step"""
+        losses = {}
+
+        mse_each_step = torch.mean((y_hat - y) ** 2, dim=0)
+        mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)
+
+        losses.update({f"MSE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mse_each_step)})
+        losses.update({f"MAE_{dict_key_root}/step_{i:03}": m for i, m in enumerate(mae_each_step)})
+
+        return losses

     def _calculate_val_losses(self, y, y_hat):
         """Calculate additional validation losses"""

         losses = {}
-
+
         if self.use_quantile_regression:
             # Add fraction below each quantile for calibration
             for i, quantile in enumerate(self.output_quantiles):
@@ -399,12 +412,17 @@ def _calculate_val_losses(self, y, y_hat):
             # Take median value for remaining metric calculations
             y_hat = self._quantiles_to_prediction(y_hat)

-        mse_each_step = torch.mean((y_hat - y) ** 2, dim=0)
-        mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)
-
-        losses.update({f"MSE_horizon/step_{i:03}": m for i, m in enumerate(mse_each_step)})
-        losses.update({f"MAE_horizon/step_{i:03}": m for i, m in enumerate(mae_each_step)})
-
+
+        # Log the loss at each time horizon
+        losses.update(self._step_mae_and_mse(y, y_hat, dict_key_root="horizon"))
+
+        # Log the persistence losses
+        y_persist = y[:, -1].unsqueeze(1).expand(-1, self.forecast_len)
+        losses["MAE_persistence/val"] = F.l1_loss(y_persist, y)
+        losses["MSE_persistence/val"] = F.mse_loss(y_persist, y)
+
+        # Log the persistence loss at each time horizon
+        losses.update(self._step_mae_and_mse(y, y_persist, dict_key_root="persistence"))
         return losses

     def _calculate_test_losses(self, y, y_hat):
@@ -470,52 +488,37 @@ def training_step(self, batch, batch_idx):
         else:
             opt_target = losses["MAE/train"]
         return opt_target
-
+
+    def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, plot_suffix):
+        """Log forecast plot to wandb"""
+        fig = plot_batch_forecasts(
+            batch,
+            y_hat,
+            quantiles=self.output_quantiles,
+            key_to_plot=self._target_key_name,
+            timesteps_to_plot=timesteps_to_plot,
+        )
+
+        plot_name = f"val_forecast_samples/batch_idx_{accum_batch_num}_{plot_suffix}"
+
+        self.logger.experiment.log({plot_name: wandb.Image(fig)})
+        plt.close(fig)
+
     def validation_step(self, batch: dict, batch_idx):
         """Run validation step"""
         y_hat = self(batch)
         # Sensor seems to be in batch, station, time order
         y = batch[self._target_key][:, -self.forecast_len :, 0]
-        persistence = batch[self._target_key][:, -self.forecast_len - 1, 0]
-        # Expand persistence to be the same shape as y
-        persistence = persistence.unsqueeze(1).expand(-1, self.forecast_len)
+
         losses = self._calculate_common_losses(y, y_hat)
         losses.update(self._calculate_val_losses(y, y_hat))
-
+
+        # Store these to make the horizon accuracy plot
+        self._horizon_maes.append(
+            {i: losses[f"MAE_horizon/step_{i:03}"] for i in range(self.forecast_len)}
+        )
+
         logged_losses = {f"{k}/val": v for k, v in losses.items()}
-        logged_losses["MAE_persistence/val"] = F.l1_loss(persistence, y)
-        logged_losses["MSE_persistence/val"] = F.mse_loss(persistence, y)
-        # Log for each timestep the persistence loss
-        for i in range(self.forecast_len):
-            logged_losses[f"MAE_persistence/step_{i:03}/val"] = F.l1_loss(
-                persistence[:, i], y[:, i]
-            )
-            logged_losses[f"MSE_persistence/step_{i:03}/val"] = F.mse_loss(
-                persistence[:, i], y[:, i]
-            )
-        # Get the losses in the format of {VAL>_horizon/step_000: 0.1, VAL>_horizon/step_001: 0.2}
-        # for each step in the forecast horizon
-        # This is needed for the custom plot
-        # And needs to be in order of step
-        x_values = [
-            int(k.split("_")[-1].split("/")[0])
-            for k in logged_losses.keys()
-            if "MAE_horizon/step" in k
-        ]
-        y_values = []
-        for x in x_values:
-            y_values.append(logged_losses[f"MAE_horizon/step_{x:03}/val"])
-        per_step_losses = [[x, y] for (x, y) in zip(x_values, y_values)]
-        # Check if WandBLogger is being used
-        if isinstance(self.logger, pl.loggers.WandbLogger):
-            table = wandb.Table(data=per_step_losses, columns=["timestep", "MAE"])
-            wandb.log(
-                {
-                    "mae_vs_timestep": wandb.plot.line(
-                        table, "timestep", "MAE", title="MAE vs Timestep"
-                    )
-                }
-            )

         self.log_dict(
             logged_losses,
@@ -525,7 +528,8 @@ def validation_step(self, batch: dict, batch_idx):

         accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

-        if accum_batch_num in [0, 1]:
+        # Make plots only if using wandb logger
+        if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
             # Store these temporarily under self
             if not hasattr(self, "_val_y_hats"):
                 self._val_y_hats = PredAccumulator()
@@ -533,47 +537,57 @@ def validation_step(self, batch: dict, batch_idx):

             self._val_y_hats.append(y_hat)
             self._val_batches.append(batch)
-
-            # if batch had accumulated
+
+            # if the batch has accumulated
             if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
                 y_hat = self._val_y_hats.flush()
                 batch = self._val_batches.flush()
-
-                fig = plot_batch_forecasts(
-                    batch,
-                    y_hat,
-                    quantiles=self.output_quantiles,
-                    key_to_plot=self._target_key_name,
-                )
-
-                self.logger.experiment.log(
-                    {
-                        f"val_forecast_samples/batch_idx_{accum_batch_num}_all": wandb.Image(fig),
-                    }
+
+                self._log_forecast_plot(
+                    batch,
+                    y_hat,
+                    accum_batch_num,
+                    timesteps_to_plot=None,
+                    plot_suffix="all",
                 )
-                plt.close(fig)

                 if self.time_step_intervals_to_plot is not None:
                     for interval in self.time_step_intervals_to_plot:
-                        fig = plot_batch_forecasts(
-                            batch,
-                            y_hat,
-                            quantiles=self.output_quantiles,
-                            key_to_plot=self._target_key_name,
-                            timesteps_to_plot=interval,
-                        )
-                        self.logger.experiment.log(
-                            {
-                                f"val_forecast_samples/batch_idx_{accum_batch_num}_"
-                                f"timestep_{interval}": wandb.Image(fig),
-                            }
+
+                        self._log_forecast_plot(
+                            batch,
+                            y_hat,
+                            accum_batch_num,
+                            timesteps_to_plot=interval,
+                            plot_suffix=f"timestep_{interval}",
                         )
-                        plt.close(fig)

                 del self._val_y_hats
                 del self._val_batches
+
+        return logged_losses
+
+    def on_validation_epoch_end(self):
+        """Run on validation epoch end"""
+
+        horizon_maes_dict = self._horizon_maes.flush()
+
+        # Create the horizon accuracy curve
+        if isinstance(self.logger, pl.loggers.WandbLogger):
+
+            per_step_losses = [[i, horizon_maes_dict[i]] for i in range(self.forecast_len)]
+            table = wandb.Table(data=per_step_losses, columns=["horizon_step", "MAE"])
+            wandb.log(
+                {
+                    "horizon_loss_curve": wandb.plot.line(
+                        table, "horizon_step", "MAE", title="Horizon loss curve"
+                    )
+                },
+            )

     def test_step(self, batch, batch_idx):
         """Run test step"""
         y_hat = self(batch)
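
Note on the reworked flow (not part of the patch itself): each validation_step now appends its per-horizon MAE dict to self._horizon_maes, and on_validation_epoch_end averages those dicts over all validation batches before logging a single wandb line plot, which is what makes the horizon curve reflect every validation sample instead of only the last logged batch. Below is a minimal standalone sketch of that accumulate-then-plot pattern; the forecast_len value, batch size, random tensors, plain-list accumulator, and wandb project name are illustrative stand-ins (PVNet uses its MetricAccumulator and the Lightning hooks instead):

    # Illustrative sketch only; assumes torch and wandb are installed.
    import torch
    import wandb

    forecast_len = 16  # assumed horizon length (stand-in value)
    step_maes = []     # plain list standing in for PVNet's MetricAccumulator

    # "validation steps": accumulate the per-horizon MAE of each batch
    for _ in range(10):
        y = torch.rand(32, forecast_len)      # fake targets
        y_hat = torch.rand(32, forecast_len)  # fake predictions
        mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)
        step_maes.append({i: mae_each_step[i].item() for i in range(forecast_len)})

    # "epoch end": average each horizon step over batches, then log one line plot
    mean_maes = {i: sum(d[i] for d in step_maes) / len(step_maes) for i in range(forecast_len)}

    run = wandb.init(project="horizon-curve-demo")  # hypothetical project name
    table = wandb.Table(
        data=[[i, mean_maes[i]] for i in range(forecast_len)],
        columns=["horizon_step", "MAE"],
    )
    run.log({"horizon_loss_curve": wandb.plot.line(table, "horizon_step", "MAE")})
    run.finish()

The design point the patch relies on: per-step metrics are cheap to accumulate on every validation step, while the wandb.Table and line plot are built once per epoch from the averaged values.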