save validation batch results to wandb #252

Merged: 21 commits, Sep 12, 2024
Changes from 9 commits

Commits:
430a1f7
save validation batch results to wandb
peterdudfield Sep 5, 2024
a3f661b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
f25947b
fix validation df
peterdudfield Sep 5, 2024
202385a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
a17ad48
tidy up
peterdudfield Sep 5, 2024
e227b65
at print statment
peterdudfield Sep 5, 2024
42ed3e7
try and except around odd error
peterdudfield Sep 5, 2024
0444b70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2024
53024cd
fix
peterdudfield Sep 5, 2024
ffe9b15
PR comments
peterdudfield Sep 6, 2024
79de1c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
d501107
fix and add comments
peterdudfield Sep 6, 2024
dd2f469
Merge commit '79de1c4c3da17068f9d64a360f84edb35adffb93' into issue.csv
peterdudfield Sep 6, 2024
156fdfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
89e4e75
update for quantile loss
peterdudfield Sep 12, 2024
8276a07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
058a881
save all quantile results
peterdudfield Sep 12, 2024
6f8b946
Merge commit '8276a07f1b1e2013d3b713fc954a5a87b2c57d89' into issue.csv
peterdudfield Sep 12, 2024
a9032a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
74f4f6d
PR comment
peterdudfield Sep 12, 2024
91a7e94
Merge commit 'a9032a046e436681b677b22cf762814836c9ae09' into issue.csv
peterdudfield Sep 12, 2024
57 changes: 55 additions & 2 deletions pvnet/models/base_model.py
@@ -2,6 +2,7 @@
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Dict, Optional, Union

@@ -410,6 +411,9 @@ def __init__(
        else:
            self.num_output_features = self.forecast_len

        # save all validation results to a list, so we can log them to Weights & Biases
        self.validation_epoch_results = []

    def _quantiles_to_prediction(self, y_quantiles):
        """
        Convert network prediction into a point prediction.
@@ -609,12 +613,48 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p
            print(e)
        plt.close(fig)

    def _log_validation_results(self, batch, y_hat, accum_batch_num):
        """Append validation results to self.validation_epoch_results"""

        y = batch[self._target_key][:, -self.forecast_len :, 0]
        batch_size = y.shape[0]

        for i in range(batch_size):
            y_i = y[i].detach().cpu().numpy()
            y_hat_i = y_hat[i].detach().cpu().numpy()

            time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"]
            time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy()

            id_key = BatchKey[f"{self._target_key_name}_id"]
            target_id = batch[id_key][i].detach().cpu().numpy()
            if target_id.ndim > 0:
                target_id = target_id[0]

            results_df = pd.DataFrame(
                {
                    "y": y_i,
                    "y_hat": y_hat_i,

Review thread on the "y_hat" line:

Member: Have you done a training run with this? When predicting quantiles, I think y_i and y_hat_i will be different shapes.

Contributor Author: I haven't done a training run, but the end-to-end test does go through here. Because I take y = batch[self._target_key][:, -self.forecast_len :, 0], it ends up the same length as y_hat. Perhaps there's a better way to standardise that; something for the future.

Member: Yes, but I think in this case y_i is a vector with shape (horizon_step,), whereas y_hat_i can be either a vector with shape (horizon_step,) or (horizon_step, quantile), depending on whether we are training to predict quantiles or only a central value. The end-to-end test only covers non-quantile training.

Contributor Author: Oh, good point. I'll have a think about the quantile handling. Good catch.

Contributor Author: I should have it now.

(A sketch of how the quantile case might be handled is included after the diff below.)

"time_utc": time_utc,
}
)
results_df["id"] = target_id
results_df["batch_idx"] = accum_batch_num
results_df["example_idx"] = i

self.validation_epoch_results.append(results_df)

    def validation_step(self, batch: dict, batch_idx):
        """Run validation step"""

        accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

        y_hat = self(batch)
        # Sensor seems to be in batch, station, time order
        y = batch[self._target_key][:, -self.forecast_len :, 0]

        self._log_validation_results(batch, y_hat, accum_batch_num)

        # Expand persistence to be the same shape as y
        losses = self._calculate_common_losses(y, y_hat)
        losses.update(self._calculate_val_losses(y, y_hat))
@@ -632,8 +672,6 @@ def validation_step(self, batch: dict, batch_idx):
            on_epoch=True,
        )

        accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

        # Make plots only if using wandb logger
        if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
            # Store these temporarily under self
@@ -675,6 +713,21 @@ def validation_step(self, batch: dict, batch_idx):
    def on_validation_epoch_end(self):
        """Run on epoch end"""

        try:
            # join together validation results, and save to wandb
            validation_results_df = pd.concat(self.validation_epoch_results)
            with tempfile.TemporaryDirectory() as tempdir:
                filename = os.path.join(tempdir, f"validation_results_{self.current_epoch}.csv")
                validation_results_df.to_csv(filename, index=False)

                validation_artifact = wandb.Artifact(
                    f"validation_results_epoch={self.current_epoch}", type="dataset"
                )
                # attach the csv file before logging, otherwise the artifact is empty
                validation_artifact.add_file(filename)
                wandb.log_artifact(validation_artifact)
        except Exception as e:
            print("Failed to log validation results to wandb")
            print(e)

        horizon_maes_dict = self._horizon_maes.flush()

        # Create the horizon accuracy curve
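On the shape question raised in the review thread above: with quantile regression, y_hat_i has shape (forecast_len, n_quantiles) while y_i stays (forecast_len,), so the two cannot share a single "y_hat" column. Below is a minimal sketch of how the per-example DataFrame could be built for both cases; the helper name and the quantile column naming are illustrative, not necessarily what the later commits ("update for quantile loss", "save all quantile results") implement.

import numpy as np
import pandas as pd

def build_example_results_df(y_i, y_hat_i, time_utc, quantiles=None):
    """Build a per-example results DataFrame for both point and quantile predictions.

    y_i:       (forecast_len,) ground truth
    y_hat_i:   (forecast_len,) point prediction, or (forecast_len, n_quantiles) quantile prediction
    time_utc:  (forecast_len,) forecast valid times
    quantiles: quantile levels such as [0.1, 0.5, 0.9], or None for point predictions
    """
    data = {"y": y_i, "time_utc": time_utc}
    if y_hat_i.ndim == 1:
        # point prediction: single column
        data["y_hat"] = y_hat_i
    else:
        # quantile prediction: one column per quantile level, so all quantile results are saved
        for q_idx, q in enumerate(quantiles):
            data[f"y_quantile_{q}"] = y_hat_i[:, q_idx]
    return pd.DataFrame(data)

# example: 4 forecast steps, 3 quantiles
df = build_example_results_df(
    y_i=np.arange(4.0),
    y_hat_i=np.random.rand(4, 3),
    time_utc=pd.date_range("2024-09-12", periods=4, freq="30min").values,
    quantiles=[0.1, 0.5, 0.9],
)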
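For completeness, a rough sketch of how the validation-results artifacts logged in on_validation_epoch_end might be pulled back out of W&B for offline analysis. The project name "pvnet" and the epoch number are placeholders; the artifact name must match what the training run actually logged.

import glob

import pandas as pd
import wandb

# download the validation-results artifact logged at the end of a validation epoch
run = wandb.init(project="pvnet", job_type="analysis")
artifact = run.use_artifact("validation_results_epoch=0:latest", type="dataset")
artifact_dir = artifact.download()

# the artifact contains a csv of per-sample predictions and truths
csv_path = glob.glob(f"{artifact_dir}/*.csv")[0]
validation_results_df = pd.read_csv(csv_path)
print(validation_results_df.head())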