Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save validation batch results to wandb #252

Merged
merged 21 commits into from
Sep 12, 2024
Merged
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
430a1f7
save validation batch results to wandb
peterdudfield Sep 5, 2024
a3f661b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
f25947b
fix validation df
peterdudfield Sep 5, 2024
202385a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
a17ad48
tidy up
peterdudfield Sep 5, 2024
e227b65
at print statment
peterdudfield Sep 5, 2024
42ed3e7
try and except around odd error
peterdudfield Sep 5, 2024
0444b70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
53024cd
fix
peterdudfield Sep 5, 2024
ffe9b15
PR comments
peterdudfield Sep 6, 2024
79de1c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
d501107
fix and add comments
peterdudfield Sep 6, 2024
dd2f469
Merge commit '79de1c4c3da17068f9d64a360f84edb35adffb93' into issue.csv
peterdudfield Sep 6, 2024
156fdfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
89e4e75
update for quantile loss
peterdudfield Sep 12, 2024
8276a07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
058a881
save all quantile results
peterdudfield Sep 12, 2024
6f8b946
Merge commit '8276a07f1b1e2013d3b713fc954a5a87b2c57d89' into issue.csv
peterdudfield Sep 12, 2024
a9032a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
74f4f6d
PR comment
peterdudfield Sep 12, 2024
91a7e94
Merge commit 'a9032a046e436681b677b22cf762814836c9ae09' into issue.csv
peterdudfield Sep 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Dict, Optional, Union

Expand Down Expand Up @@ -410,6 +411,9 @@ def __init__(
else:
self.num_output_features = self.forecast_len

# save all validation results to array, so we can save these to weights n biases
self.validation_epoch_results = []

def _quantiles_to_prediction(self, y_quantiles):
"""
Convert network prediction into a point prediction.
Expand Down Expand Up @@ -609,12 +613,61 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p
print(e)
plt.close(fig)

def _log_validation_results(self, batch, y_hat, accum_batch_num):
"""Append validation results to self.validation_epoch_results"""

# get truth values, shape (b, forecast_len)
y = batch[self._target_key][:, -self.forecast_len :, 0]
y = y.detach().cpu().numpy()
batch_size = y.shape[0]

# get truth values, shape (b, forecast_len)
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
y_hat = y_hat.detach().cpu().numpy()

# get time_utc, shape (b, forecast_len)
time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"]
time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy()

# get target id and change from (b,1) to (b,)
id_key = BatchKey[f"{self._target_key_name}_id"]
target_id = batch[id_key].detach().cpu().numpy()
target_id = target_id.squeeze()

for i in range(batch_size):
y_i = y[i]
y_hat_i = y_hat[i]
time_utc_i = time_utc[i]
target_id_i = target_id[i]

results_dict = {
"y": y_i,
"time_utc": time_utc_i,
}
if self.use_quantile_regression:
results_dict.update(
{f"y_quantile_{q}": y_hat_i[:, i] for i, q in enumerate(self.output_quantiles)}
)
else:
results_dict["y_hat"] = y_hat_i

results_df = pd.DataFrame(results_dict)
results_df["id"] = target_id_i
results_df["batch_idx"] = accum_batch_num
results_df["example_idx"] = i

self.validation_epoch_results.append(results_df)

def validation_step(self, batch: dict, batch_idx):
"""Run validation step"""

accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

y_hat = self(batch)
# Sensor seems to be in batch, station, time order
y = batch[self._target_key][:, -self.forecast_len :, 0]

self._log_validation_results(batch, y_hat, accum_batch_num)

# Expand persistence to be the same shape as y
losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))
Expand All @@ -632,8 +685,6 @@ def validation_step(self, batch: dict, batch_idx):
on_epoch=True,
)

accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

# Make plots only if using wandb logger
if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
# Store these temporarily under self
Expand Down Expand Up @@ -675,6 +726,23 @@ def validation_step(self, batch: dict, batch_idx):
def on_validation_epoch_end(self):
"""Run on epoch end"""

try:
# join together validation results, and save to wandb
validation_results_df = pd.concat(self.validation_epoch_results)
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, f"validation_results_{self.current_epoch}.csv")
validation_results_df.to_csv(filename, index=False)

# make and log wand artifact
validation_artifact = wandb.Artifact(
f"validation_results_epoch={self.current_epoch}", type="dataset"
)
validation_artifact.add_file(filename)
wandb.log_artifact(validation_artifact)
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
except Exception as e:
print("Failed to log validation results to wandb")
print(e)

horizon_maes_dict = self._horizon_maes.flush()

# Create the horizon accuracy curve
Expand Down
Loading