From 430a1f76ee4d3d97cea620c5da1cc7ab9f6bbe57 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 5 Sep 2024 16:17:07 +0100 Subject: [PATCH 01/18] save validation batch results to wandb --- pvnet/models/base_model.py | 48 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 9faff47d..26b59702 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -1,6 +1,7 @@ """Base model for all PVNet submodels""" import json import logging +import tempfile import os from pathlib import Path from typing import Dict, Optional, Union @@ -410,6 +411,9 @@ def __init__( else: self.num_output_features = self.forecast_len + # save all validation results to array, so we can save these to weights n biases + self.validation_epoch_results = [] + def _quantiles_to_prediction(self, y_quantiles): """ Convert network prediction into a point prediction. @@ -609,12 +613,41 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p print(e) plt.close(fig) + def _log_validation_results(self, batch, y_hat, accum_batch_num): + """ Append validation results to self.validation_epoch_results """ + + y = batch[self._target_key][:, -self.forecast_len:, 0] + batch_size = y.shape[0] + + for i in range(batch_size): + y_i = y[i].detach().cpu().numpy() + y_hat_i = y_hat[i].detach().cpu().numpy() + + time_utc_key = BatchKey[f"{self._target_key}_time_utc"] + time_utc = batch[time_utc_key][i, -self.forecast_len:].detach().cpu().numpy() + + id_key = BatchKey[f"{self._target_key}_id"] + ids = batch[id_key][i].detach().cpu().numpy() + + self.validation_epoch_results.append({"y": y_i, + "y_hat": y_hat_i, + "time_utc": time_utc, + "id": ids, + "batch_idx": accum_batch_num, + "example_idx": i, + }) + def validation_step(self, batch: dict, batch_idx): """Run validation step""" + + accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches + y_hat = self(batch) # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] + self._log_validation_results(batch, y_hat, accum_batch_num) + # Expand persistence to be the same shape as y losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) @@ -632,8 +665,6 @@ def validation_step(self, batch: dict, batch_idx): on_epoch=True, ) - accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches - # Make plots only if using wandb logger if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]: # Store these temporarily under self @@ -675,6 +706,19 @@ def validation_step(self, batch: dict, batch_idx): def on_validation_epoch_end(self): """Run on epoch end""" + try: + # join together validation results, and save to wandb + validation_results_df = pd.DataFrame(self.validation_epoch_results) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, f"validation_results.csv_{self.current_epoch}") + validation_results_df.to_csv(filename, index=False) + + validation_artifact = wandb.Artifact(f"validation_results_epoch={self.current_epoch}", type="dataset") + wandb.log_artifact(validation_artifact) + except Exception as e: + print("Failed to log validation results to wandb") + print(e) + horizon_maes_dict = self._horizon_maes.flush() # Create the horizon accuracy curve From a3f661be1acabaf03c44f09fd5efad7b53ca168d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 15:23:49 +0000 Subject: [PATCH 02/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 26b59702..855dc625 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -1,8 +1,8 @@ """Base model for all PVNet submodels""" import json import logging -import tempfile import os +import tempfile from pathlib import Path from typing import Dict, Optional, Union @@ -614,9 +614,9 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p plt.close(fig) def _log_validation_results(self, batch, y_hat, accum_batch_num): - """ Append validation results to self.validation_epoch_results """ + """Append validation results to self.validation_epoch_results""" - y = batch[self._target_key][:, -self.forecast_len:, 0] + y = batch[self._target_key][:, -self.forecast_len :, 0] batch_size = y.shape[0] for i in range(batch_size): @@ -624,18 +624,21 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y_hat_i = y_hat[i].detach().cpu().numpy() time_utc_key = BatchKey[f"{self._target_key}_time_utc"] - time_utc = batch[time_utc_key][i, -self.forecast_len:].detach().cpu().numpy() + time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy() id_key = BatchKey[f"{self._target_key}_id"] ids = batch[id_key][i].detach().cpu().numpy() - self.validation_epoch_results.append({"y": y_i, - "y_hat": y_hat_i, - "time_utc": time_utc, - "id": ids, - "batch_idx": accum_batch_num, - "example_idx": i, - }) + self.validation_epoch_results.append( + { + "y": y_i, + "y_hat": y_hat_i, + "time_utc": time_utc, + "id": ids, + "batch_idx": accum_batch_num, + "example_idx": i, + } + ) def validation_step(self, batch: dict, batch_idx): """Run validation step""" @@ -713,7 +716,9 @@ def on_validation_epoch_end(self): filename = os.path.join(tempdir, f"validation_results.csv_{self.current_epoch}") validation_results_df.to_csv(filename, index=False) - validation_artifact = wandb.Artifact(f"validation_results_epoch={self.current_epoch}", type="dataset") + validation_artifact = wandb.Artifact( + f"validation_results_epoch={self.current_epoch}", type="dataset" + ) wandb.log_artifact(validation_artifact) except Exception as e: print("Failed to log validation results to wandb") From f25947b6317d7217c1a8999a3434a2c6ea5554ae Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 5 Sep 2024 16:42:35 +0100 Subject: [PATCH 03/18] fix validation df --- pvnet/models/base_model.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 855dc625..5b9999c6 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -623,22 +623,24 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y_i = y[i].detach().cpu().numpy() y_hat_i = y_hat[i].detach().cpu().numpy() - time_utc_key = BatchKey[f"{self._target_key}_time_utc"] + time_utc_key = getattr(BatchKey, f"{self._target_key}_time_utc") time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy() - id_key = BatchKey[f"{self._target_key}_id"] - ids = batch[id_key][i].detach().cpu().numpy() + id_key = getattr(BatchKey, f"{self._target_key}_id") + target_id = batch[id_key][i].detach().cpu().numpy() - self.validation_epoch_results.append( + results_df = pd.DataFrame( { "y": y_i, "y_hat": y_hat_i, "time_utc": time_utc, - "id": ids, - "batch_idx": accum_batch_num, - "example_idx": i, } ) + results_df['id'] = target_id + results_df['batch_idx'] = accum_batch_num + results_df['example_idx'] = i + + self.validation_epoch_results.append(results_df) def validation_step(self, batch: dict, batch_idx): """Run validation step""" @@ -711,7 +713,7 @@ def on_validation_epoch_end(self): try: # join together validation results, and save to wandb - validation_results_df = pd.DataFrame(self.validation_epoch_results) + validation_results_df = pd.concat(self.validation_epoch_results) with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, f"validation_results.csv_{self.current_epoch}") validation_results_df.to_csv(filename, index=False) From 202385a218a82c77d71167f38b5c6aa4332e6ba6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 15:43:32 +0000 Subject: [PATCH 04/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 5b9999c6..ed082bf6 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -636,9 +636,9 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): "time_utc": time_utc, } ) - results_df['id'] = target_id - results_df['batch_idx'] = accum_batch_num - results_df['example_idx'] = i + results_df["id"] = target_id + results_df["batch_idx"] = accum_batch_num + results_df["example_idx"] = i self.validation_epoch_results.append(results_df) From a17ad48fcc97ac25175e2509f37238847a1d92b9 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 5 Sep 2024 17:00:42 +0100 Subject: [PATCH 05/18] tidy up --- pvnet/models/base_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index ed082bf6..dc86a879 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -623,10 +623,10 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y_i = y[i].detach().cpu().numpy() y_hat_i = y_hat[i].detach().cpu().numpy() - time_utc_key = getattr(BatchKey, f"{self._target_key}_time_utc") + time_utc_key = BatchKey[f"{self._target_key}_time_utc"] time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy() - id_key = getattr(BatchKey, f"{self._target_key}_id") + id_key = BatchKey[f"{self._target_key}_id"] target_id = batch[id_key][i].detach().cpu().numpy() results_df = pd.DataFrame( From e227b655564f96a35f12f2811489908d143364a1 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 5 Sep 2024 17:03:04 +0100 Subject: [PATCH 06/18] at print statment --- pvnet/models/base_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index dc86a879..c2354a2b 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -623,6 +623,8 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y_i = y[i].detach().cpu().numpy() y_hat_i = y_hat[i].detach().cpu().numpy() + print(BatchKey._member_map_) + time_utc_key = BatchKey[f"{self._target_key}_time_utc"] time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy() From 42ed3e74b2d1ba1a4c8dc5a1abf6f272638498d6 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 5 Sep 2024 17:10:07 +0100 Subject: [PATCH 07/18] try and except around odd error --- pvnet/models/base_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index c2354a2b..f0b6e813 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -624,8 +624,11 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y_hat_i = y_hat[i].detach().cpu().numpy() print(BatchKey._member_map_) + try: + time_utc_key = BatchKey[f"{self._target_key}_time_utc"] + except Exception as e: + raise Exception(f"Failed to find time_utc key for {self._target_key}, {BatchKey._member_map_}, {e}") - time_utc_key = BatchKey[f"{self._target_key}_time_utc"] time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy() id_key = BatchKey[f"{self._target_key}_id"] From 0444b700fa5126ecc223d63c441742e87c1977dd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:11:52 +0000 Subject: [PATCH 08/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index f0b6e813..7f9f54f0 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -627,7 +627,9 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): try: time_utc_key = BatchKey[f"{self._target_key}_time_utc"] except Exception as e: - raise Exception(f"Failed to find time_utc key for {self._target_key}, {BatchKey._member_map_}, {e}") + raise Exception( + f"Failed to find time_utc key for {self._target_key}, {BatchKey._member_map_}, {e}" + ) time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy() From 53024cdbdc21e7765c199f1280061c2bb364a380 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 5 Sep 2024 17:15:22 +0100 Subject: [PATCH 09/18] fix --- pvnet/models/base_model.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 7f9f54f0..d9ebb482 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -623,18 +623,13 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y_i = y[i].detach().cpu().numpy() y_hat_i = y_hat[i].detach().cpu().numpy() - print(BatchKey._member_map_) - try: - time_utc_key = BatchKey[f"{self._target_key}_time_utc"] - except Exception as e: - raise Exception( - f"Failed to find time_utc key for {self._target_key}, {BatchKey._member_map_}, {e}" - ) - + time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy() - id_key = BatchKey[f"{self._target_key}_id"] + id_key = BatchKey[f"{self._target_key_name}_id"] target_id = batch[id_key][i].detach().cpu().numpy() + if target_id.ndim > 0: + target_id = target_id[0] results_df = pd.DataFrame( { From ffe9b1507f5a24d7ec448ddd2794a31b9b320f2a Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 6 Sep 2024 08:34:48 +0100 Subject: [PATCH 10/18] PR comments --- pvnet/models/base_model.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index d9ebb482..d9bcdb3c 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -618,16 +618,18 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y = batch[self._target_key][:, -self.forecast_len :, 0] batch_size = y.shape[0] + y = y.detach().cpu().numpy() + y_hat = y_hat.detach().cpu().numpy() + time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] + time_utc = batch[time_utc_key][:, -self.forecast_len:].detach().cpu().numpy() + id_key = BatchKey[f"{self._target_key_name}_id"] + target_id = batch[id_key].detach().cpu().numpy() for i in range(batch_size): - y_i = y[i].detach().cpu().numpy() - y_hat_i = y_hat[i].detach().cpu().numpy() - - time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] - time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy() - - id_key = BatchKey[f"{self._target_key_name}_id"] - target_id = batch[id_key][i].detach().cpu().numpy() + y_i = y[i] + y_hat_i = y_hat[i] + time_utc = time_utc[i] + target_id = target_id[i] if target_id.ndim > 0: target_id = target_id[0] @@ -717,12 +719,14 @@ def on_validation_epoch_end(self): # join together validation results, and save to wandb validation_results_df = pd.concat(self.validation_epoch_results) with tempfile.TemporaryDirectory() as tempdir: - filename = os.path.join(tempdir, f"validation_results.csv_{self.current_epoch}") + filename = os.path.join(tempdir, f"validation_results_{self.current_epoch}.csv") validation_results_df.to_csv(filename, index=False) + # make and log wand artifact validation_artifact = wandb.Artifact( f"validation_results_epoch={self.current_epoch}", type="dataset" ) + validation_artifact.add_file(filename) wandb.log_artifact(validation_artifact) except Exception as e: print("Failed to log validation results to wandb") From 79de1c4c3da17068f9d64a360f84edb35adffb93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 07:36:02 +0000 Subject: [PATCH 11/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index d9bcdb3c..a9fca650 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -621,7 +621,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y = y.detach().cpu().numpy() y_hat = y_hat.detach().cpu().numpy() time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] - time_utc = batch[time_utc_key][:, -self.forecast_len:].detach().cpu().numpy() + time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy() id_key = BatchKey[f"{self._target_key_name}_id"] target_id = batch[id_key].detach().cpu().numpy() From d501107e10ecf6a339cd739162f673c9788a88e6 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 6 Sep 2024 08:45:39 +0100 Subject: [PATCH 12/18] fix and add comments --- pvnet/models/base_model.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index d9bcdb3c..5cece4fc 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -616,31 +616,37 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p def _log_validation_results(self, batch, y_hat, accum_batch_num): """Append validation results to self.validation_epoch_results""" - y = batch[self._target_key][:, -self.forecast_len :, 0] - batch_size = y.shape[0] + # get truth values, shape (b, forecast_len) + y = batch[self._target_key][:, -self.forecast_len:, 0] y = y.detach().cpu().numpy() + batch_size = y.shape[0] + + # get truth values, shape (b, forecast_len) y_hat = y_hat.detach().cpu().numpy() + + # get time_utc, shape (b, forecast_len) time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] time_utc = batch[time_utc_key][:, -self.forecast_len:].detach().cpu().numpy() + + # get target id and change from (b,1) to (b,) id_key = BatchKey[f"{self._target_key_name}_id"] target_id = batch[id_key].detach().cpu().numpy() + target_id = target_id.squeeze() for i in range(batch_size): y_i = y[i] y_hat_i = y_hat[i] - time_utc = time_utc[i] - target_id = target_id[i] - if target_id.ndim > 0: - target_id = target_id[0] + time_utc_i = time_utc[i] + target_id_i = target_id[i] results_df = pd.DataFrame( { "y": y_i, "y_hat": y_hat_i, - "time_utc": time_utc, + "time_utc": time_utc_i, } ) - results_df["id"] = target_id + results_df["id"] = target_id_i results_df["batch_idx"] = accum_batch_num results_df["example_idx"] = i From 156fdfab301707bff129936acfddcf2f782d6c90 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 07:46:37 +0000 Subject: [PATCH 13/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 5cece4fc..29b909b2 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -617,7 +617,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): """Append validation results to self.validation_epoch_results""" # get truth values, shape (b, forecast_len) - y = batch[self._target_key][:, -self.forecast_len:, 0] + y = batch[self._target_key][:, -self.forecast_len :, 0] y = y.detach().cpu().numpy() batch_size = y.shape[0] @@ -626,7 +626,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): # get time_utc, shape (b, forecast_len) time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] - time_utc = batch[time_utc_key][:, -self.forecast_len:].detach().cpu().numpy() + time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy() # get target id and change from (b,1) to (b,) id_key = BatchKey[f"{self._target_key_name}_id"] From 89e4e75cfe469a06e0ff0f81bbd98ade69c9f232 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 12 Sep 2024 12:29:18 +0100 Subject: [PATCH 14/18] update for quantile loss --- pvnet/models/base_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 29b909b2..5dfb553d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -635,7 +635,11 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): for i in range(batch_size): y_i = y[i] - y_hat_i = y_hat[i] + if self.use_quantile_regression: + idx = self.output_quantiles.index(0.5) + y_hat_i = y_hat[i,idx] + else: + y_hat_i = y_hat[i] time_utc_i = time_utc[i] target_id_i = target_id[i] From 8276a07f1b1e2013d3b713fc954a5a87b2c57d89 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 11:31:12 +0000 Subject: [PATCH 15/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 5dfb553d..cc7aecae 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -637,7 +637,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y_i = y[i] if self.use_quantile_regression: idx = self.output_quantiles.index(0.5) - y_hat_i = y_hat[i,idx] + y_hat_i = y_hat[i, idx] else: y_hat_i = y_hat[i] time_utc_i = time_utc[i] From 058a881773e2998ece40a91849755a4a104a1ed9 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 12 Sep 2024 12:38:14 +0100 Subject: [PATCH 16/18] save all quantile results --- pvnet/models/base_model.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 5dfb553d..4e2f05c5 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -635,20 +635,27 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): for i in range(batch_size): y_i = y[i] - if self.use_quantile_regression: - idx = self.output_quantiles.index(0.5) - y_hat_i = y_hat[i,idx] - else: - y_hat_i = y_hat[i] + y_hat_i = y_hat[i] time_utc_i = time_utc[i] target_id_i = target_id[i] - results_df = pd.DataFrame( + results_dict = \ { "y": y_i, - "y_hat": y_hat_i, "time_utc": time_utc_i, } + if self.use_quantile_regression: + results_dict.update( + { + f"y_quantile_{q}": y_hat_i[:, i] + for i, q in enumerate(self.output_quantiles) + } + ) + else: + results_dict["y_hat"] = y_hat_i + + results_df = pd.DataFrame( + results_dict ) results_df["id"] = target_id_i results_df["batch_idx"] = accum_batch_num From a9032a046e436681b677b22cf762814836c9ae09 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 11:38:54 +0000 Subject: [PATCH 17/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 4e2f05c5..4bd95395 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -639,24 +639,18 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): time_utc_i = time_utc[i] target_id_i = target_id[i] - results_dict = \ - { - "y": y_i, - "time_utc": time_utc_i, - } + results_dict = { + "y": y_i, + "time_utc": time_utc_i, + } if self.use_quantile_regression: results_dict.update( - { - f"y_quantile_{q}": y_hat_i[:, i] - for i, q in enumerate(self.output_quantiles) - } + {f"y_quantile_{q}": y_hat_i[:, i] for i, q in enumerate(self.output_quantiles)} ) else: results_dict["y_hat"] = y_hat_i - results_df = pd.DataFrame( - results_dict - ) + results_df = pd.DataFrame(results_dict) results_df["id"] = target_id_i results_df["batch_idx"] = accum_batch_num results_df["example_idx"] = i From 74f4f6df748ce65a8d181b26b077180573b5ff96 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Thu, 12 Sep 2024 14:28:18 +0100 Subject: [PATCH 18/18] PR comment --- pvnet/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 4e2f05c5..eca2e3ab 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -621,7 +621,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y = y.detach().cpu().numpy() batch_size = y.shape[0] - # get truth values, shape (b, forecast_len) + # get prediction values, shape (b, forecast_len, quantiles?) y_hat = y_hat.detach().cpu().numpy() # get time_utc, shape (b, forecast_len)