Skip to content

Commit

Permalink
save validation batch results to wandb
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Sep 5, 2024
1 parent 83e7b48 commit 430a1f7
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base model for all PVNet submodels"""
import json
import logging
import tempfile
import os
from pathlib import Path
from typing import Dict, Optional, Union
Expand Down Expand Up @@ -410,6 +411,9 @@ def __init__(
else:
self.num_output_features = self.forecast_len

# save all validation results to array, so we can save these to weights n biases
self.validation_epoch_results = []

def _quantiles_to_prediction(self, y_quantiles):
"""
Convert network prediction into a point prediction.
Expand Down Expand Up @@ -609,12 +613,41 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p
print(e)
plt.close(fig)

def _log_validation_results(self, batch, y_hat, accum_batch_num):
""" Append validation results to self.validation_epoch_results """

y = batch[self._target_key][:, -self.forecast_len:, 0]
batch_size = y.shape[0]

for i in range(batch_size):
y_i = y[i].detach().cpu().numpy()
y_hat_i = y_hat[i].detach().cpu().numpy()

time_utc_key = BatchKey[f"{self._target_key}_time_utc"]
time_utc = batch[time_utc_key][i, -self.forecast_len:].detach().cpu().numpy()

id_key = BatchKey[f"{self._target_key}_id"]
ids = batch[id_key][i].detach().cpu().numpy()

self.validation_epoch_results.append({"y": y_i,
"y_hat": y_hat_i,
"time_utc": time_utc,
"id": ids,
"batch_idx": accum_batch_num,
"example_idx": i,
})

def validation_step(self, batch: dict, batch_idx):
"""Run validation step"""

accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

y_hat = self(batch)
# Sensor seems to be in batch, station, time order
y = batch[self._target_key][:, -self.forecast_len :, 0]

self._log_validation_results(batch, y_hat, accum_batch_num)

# Expand persistence to be the same shape as y
losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))
Expand All @@ -632,8 +665,6 @@ def validation_step(self, batch: dict, batch_idx):
on_epoch=True,
)

accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

# Make plots only if using wandb logger
if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
# Store these temporarily under self
Expand Down Expand Up @@ -675,6 +706,19 @@ def validation_step(self, batch: dict, batch_idx):
def on_validation_epoch_end(self):
"""Run on epoch end"""

try:
# join together validation results, and save to wandb
validation_results_df = pd.DataFrame(self.validation_epoch_results)
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, f"validation_results.csv_{self.current_epoch}")
validation_results_df.to_csv(filename, index=False)

validation_artifact = wandb.Artifact(f"validation_results_epoch={self.current_epoch}", type="dataset")
wandb.log_artifact(validation_artifact)
except Exception as e:
print("Failed to log validation results to wandb")
print(e)

horizon_maes_dict = self._horizon_maes.flush()

# Create the horizon accuracy curve
Expand Down

0 comments on commit 430a1f7

Please sign in to comment.