diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py
index 9faff47d..e6e0e907 100644
--- a/pvnet/models/base_model.py
+++ b/pvnet/models/base_model.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import os
+import tempfile
 from pathlib import Path
 from typing import Dict, Optional, Union
 
@@ -410,6 +411,9 @@ def __init__(
         else:
             self.num_output_features = self.forecast_len
 
+        # Save all validation results to a list so we can log them to Weights & Biases
+        self.validation_epoch_results = []
+
     def _quantiles_to_prediction(self, y_quantiles):
         """
         Convert network prediction into a point prediction.
@@ -609,12 +613,61 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p
             print(e)
         plt.close(fig)
 
+    def _log_validation_results(self, batch, y_hat, accum_batch_num):
+        """Append validation results to self.validation_epoch_results"""
+
+        # get truth values, shape (b, forecast_len)
+        y = batch[self._target_key][:, -self.forecast_len :, 0]
+        y = y.detach().cpu().numpy()
+        batch_size = y.shape[0]
+
+        # get prediction values, shape (b, forecast_len, quantiles?)
+        y_hat = y_hat.detach().cpu().numpy()
+
+        # get time_utc, shape (b, forecast_len)
+        time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"]
+        time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy()
+
+        # get target id and change from (b, 1) to (b,)
+        id_key = BatchKey[f"{self._target_key_name}_id"]
+        target_id = batch[id_key].detach().cpu().numpy()
+        target_id = target_id.squeeze()
+
+        for i in range(batch_size):
+            y_i = y[i]
+            y_hat_i = y_hat[i]
+            time_utc_i = time_utc[i]
+            target_id_i = target_id[i]
+
+            results_dict = {
+                "y": y_i,
+                "time_utc": time_utc_i,
+            }
+            if self.use_quantile_regression:
+                results_dict.update(
+                    {f"y_quantile_{q}": y_hat_i[:, j] for j, q in enumerate(self.output_quantiles)}
+                )
+            else:
+                results_dict["y_hat"] = y_hat_i
+
+            results_df = pd.DataFrame(results_dict)
+            results_df["id"] = target_id_i
+            results_df["batch_idx"] = accum_batch_num
+            results_df["example_idx"] = i
+
+            self.validation_epoch_results.append(results_df)
+
     def validation_step(self, batch: dict, batch_idx):
         """Run validation step"""
+
+        accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
+
         y_hat = self(batch)
         # Sensor seems to be in batch, station, time order
         y = batch[self._target_key][:, -self.forecast_len :, 0]
 
+        self._log_validation_results(batch, y_hat, accum_batch_num)
+
         # Expand persistence to be the same shape as y
         losses = self._calculate_common_losses(y, y_hat)
         losses.update(self._calculate_val_losses(y, y_hat))
@@ -632,8 +685,6 @@
             on_epoch=True,
         )
 
-        accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
-
         # Make plots only if using wandb logger
         if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
             # Store these temporarily under self
@@ -675,6 +726,23 @@
     def on_validation_epoch_end(self):
        """Run on epoch end"""
 
+        try:
+            # join together validation results and save to wandb
+            validation_results_df = pd.concat(self.validation_epoch_results)
+            with tempfile.TemporaryDirectory() as tempdir:
+                filename = os.path.join(tempdir, f"validation_results_{self.current_epoch}.csv")
+                validation_results_df.to_csv(filename, index=False)
+
+                # make and log wandb artifact
+                validation_artifact = wandb.Artifact(
+                    f"validation_results_epoch={self.current_epoch}", type="dataset"
+                )
+                validation_artifact.add_file(filename)
+                wandb.log_artifact(validation_artifact)
+        except Exception as e:
+            print("Failed to log validation results to wandb")
+            print(e)
+
         horizon_maes_dict = self._horizon_maes.flush()
 
         # Create the horizon accuracy curve
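
For reference, here is a minimal sketch (not part of the diff) of how the logged artifact could be pulled back down for offline analysis using the public `wandb.Api`. The `my-entity/my-project` path and the epoch number are placeholders and would need to match the actual W&B run; the artifact name and CSV filename follow the patterns used in `on_validation_epoch_end` above.

```python
import os

import pandas as pd
import wandb

# Placeholders: replace with the real W&B entity/project and the epoch of interest
ENTITY_PROJECT = "my-entity/my-project"
EPOCH = 0

api = wandb.Api()

# The artifact is logged above with type="dataset" and name f"validation_results_epoch={epoch}"
artifact = api.artifact(f"{ENTITY_PROJECT}/validation_results_epoch={EPOCH}:latest")
artifact_dir = artifact.download()

# The CSV inside the artifact is written as validation_results_{epoch}.csv
results_df = pd.read_csv(os.path.join(artifact_dir, f"validation_results_{EPOCH}.csv"))

# One row per forecast step per example; quantile columns appear when quantile regression is used
print(results_df.head())
```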