From 7daa19b987228ae878227a136c869e35c19b85b8 Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Mon, 5 Aug 2024 16:05:42 +0100 Subject: [PATCH 01/27] Fix history bug and add bells in analysis.py --- experiments/analysis.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/experiments/analysis.py b/experiments/analysis.py index 75b00cd9..9d7d35f9 100644 --- a/experiments/analysis.py +++ b/experiments/analysis.py @@ -16,10 +16,11 @@ def main(runs: list[str], run_names: list[str]) -> None: """ api = wandb.Api() dfs = [] + epoch_num = [] for run in runs: - run = api.run(f"openclimatefix/india/{run}") + run = api.run(f"openclimatefix/PROJECT/{run}") - df = run.history() + df = run.history(samples=run.lastHistoryStep+1) # Get the columns that are in the format 'MAE_horizon/step_/val` mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col] # Sort them @@ -40,6 +41,7 @@ def main(runs: list[str], run_names: list[str]) -> None: # Get the step from the column name column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols] dfs.append(df) + epoch_num.append(min_row_idx) # Get the timedelta for each group groupings = [ [0, 0], @@ -89,22 +91,22 @@ def main(runs: list[str], run_names: list[str]) -> None: # Plot the error on per timestep, and all timesteps plt.figure() for idx, df in enumerate(dfs): - plt.plot(column_timesteps, df, label=run_names[idx]) - plt.legend() - plt.xlabel("Timestep (minutes)") - plt.ylabel("MAE %") - plt.title("MAE % for each timestep") + plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}") + plt.legend(fontsize=18) + plt.xlabel("Timestep (minutes)", fontsize=18) + plt.ylabel("MAE %", fontsize=18) + plt.title("MAE % for each timestep", fontsize=24) plt.savefig("mae_per_timestep.png") plt.show() # Plot the error on per timestep, and grouped timesteps plt.figure() - for run_name in run_names: - plt.plot(groups_df[run_name], label=run_name) - plt.legend() - plt.xlabel("Timestep (minutes)") - plt.ylabel("MAE %") - plt.title("MAE % for each timestep") + for idx, run_name in enumerate(run_names): + plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}") + plt.legend(fontsize=18) + plt.xlabel("Timestep (minutes)", fontsize=18) + plt.ylabel("MAE %", fontsize=18) + plt.title("MAE % for each timestep", fontsize=24) plt.savefig("mae_per_timestep.png") plt.show() @@ -119,3 +121,4 @@ def main(runs: list[str], run_names: list[str]) -> None: parser.add_argument("--run_names", nargs="+") args = parser.parse_args() main(args.list_of_runs, args.run_names) + From e268a61fdc9c3d487038a96be7cbdeab90c0e1df Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Mon, 5 Aug 2024 16:14:37 +0100 Subject: [PATCH 02/27] Update analysis.py --- experiments/analysis.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/experiments/analysis.py b/experiments/analysis.py index 9d7d35f9..33fc718d 100644 --- a/experiments/analysis.py +++ b/experiments/analysis.py @@ -92,10 +92,10 @@ def main(runs: list[str], run_names: list[str]) -> None: plt.figure() for idx, df in enumerate(dfs): plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}") - plt.legend(fontsize=18) - plt.xlabel("Timestep (minutes)", fontsize=18) - plt.ylabel("MAE %", fontsize=18) - plt.title("MAE % for each timestep", fontsize=24) + 
plt.legend() + plt.xlabel("Timestep (minutes)") + plt.ylabel("MAE %") + plt.title("MAE % for each timestep") plt.savefig("mae_per_timestep.png") plt.show() @@ -103,10 +103,10 @@ def main(runs: list[str], run_names: list[str]) -> None: plt.figure() for idx, run_name in enumerate(run_names): plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}") - plt.legend(fontsize=18) - plt.xlabel("Timestep (minutes)", fontsize=18) - plt.ylabel("MAE %", fontsize=18) - plt.title("MAE % for each timestep", fontsize=24) + plt.legend() + plt.xlabel("Timestep (minutes)") + plt.ylabel("MAE %") + plt.title("MAE % for each timestep") plt.savefig("mae_per_timestep.png") plt.show() From f928e801200329188d9567de23efd9b30ba0fd7f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 15:16:57 +0000 Subject: [PATCH 03/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- experiments/analysis.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/experiments/analysis.py b/experiments/analysis.py index 33fc718d..bb119664 100644 --- a/experiments/analysis.py +++ b/experiments/analysis.py @@ -20,7 +20,7 @@ def main(runs: list[str], run_names: list[str]) -> None: for run in runs: run = api.run(f"openclimatefix/PROJECT/{run}") - df = run.history(samples=run.lastHistoryStep+1) + df = run.history(samples=run.lastHistoryStep + 1) # Get the columns that are in the format 'MAE_horizon/step_/val` mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col] # Sort them @@ -121,4 +121,3 @@ def main(runs: list[str], run_names: list[str]) -> None: parser.add_argument("--run_names", nargs="+") args = parser.parse_args() main(args.list_of_runs, args.run_names) - From bdbe00a435081a700cc6ad2345a11b69179143ca Mon Sep 17 00:00:00 2001 From: AUdaltsova Date: Fri, 9 Aug 2024 11:35:03 +0100 Subject: [PATCH 04/27] remove unused configs --- tests/test_data/sample_wind_batches/data_configuration.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_data/sample_wind_batches/data_configuration.yaml b/tests/test_data/sample_wind_batches/data_configuration.yaml index 91b8a50a..cf603063 100644 --- a/tests/test_data/sample_wind_batches/data_configuration.yaml +++ b/tests/test_data/sample_wind_batches/data_configuration.yaml @@ -5,7 +5,6 @@ general: input_data: default_forecast_minutes: 2880 default_history_minutes: 60 - data_source_which_defines_geospatial_locations: "wind" nwp: ecmwf: # Path to ECMWF NWP data in zarr format @@ -64,7 +63,6 @@ input_data: - label: india wind_filename: /mnt/storage_ssd_4tb/india_wind_data.nc wind_metadata_filename: /mnt/storage_ssd_4tb/india_wind_metadata.csv - get_center: true n_wind_systems_per_example: 1 #start_datetime: "2021-01-01 00:00:00" #end_datetime: "2024-01-01 00:00:00" @@ -72,7 +70,6 @@ input_data: #sensor_files_groups: # - label: meteomatics sensor_filename: "/mnt/storage_b/nwp/meteomatics/nw_india/wind*.zarr.zip" - get_center: false history_minutes: 60 forecast_minutes: 2880 #n_sensor_systems_per_example: 26 From c8381560579647c32e88caba6215be1e2523b54b Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Fri, 9 Aug 2024 13:20:58 +0000 Subject: [PATCH 05/27] =?UTF-8?q?Bump=20version:=203.0.53=20=E2=86=92=203.?= =?UTF-8?q?0.54=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 
insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 67baea56..6671c1aa 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.53 +current_version = 3.0.54 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index 35f41407..dd0cbcb8 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ -1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.53" +__version__ = "3.0.54" From 28ed7d823207f9ca4c069e4f73f9b82e94a4c04f Mon Sep 17 00:00:00 2001 From: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com> Date: Thu, 29 Aug 2024 09:40:26 +0100 Subject: [PATCH 06/27] Add torch module hub mixin back in (#250) --- pvnet/models/base_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 081e9bc5..9faff47d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -13,7 +13,7 @@ import torch.nn.functional as F import wandb import yaml -from huggingface_hub import ModelCard, ModelCardData +from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import HfApi @@ -140,7 +140,7 @@ def minimize_data_config(input_path, output_path, model): yaml.dump(config, outfile, default_flow_style=False) -class PVNetModelHubMixin: +class PVNetModelHubMixin(PyTorchModelHubMixin): """ Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities. """ From 83e7b4857289b4b24924ed7916d6164e73c0a4ca Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Thu, 29 Aug 2024 08:40:58 +0000 Subject: [PATCH 07/27] =?UTF-8?q?Bump=20version:=203.0.54=20=E2=86=92=203.?= =?UTF-8?q?0.55=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 6671c1aa..9779516d 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.54 +current_version = 3.0.55 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index dd0cbcb8..f84c422d 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ -1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.54" +__version__ = "3.0.55" From f358681273a667c59b0ed43f3bdd3e5064f0e968 Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Fri, 6 Sep 2024 23:33:45 +0100 Subject: [PATCH 08/27] added hard to contribute batch and icon note --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 128a1a9e..26e8542f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # PVNet 2.1 -[![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml) + [![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml) 
[![hard to contribute to](https://img.shields.io/badge/hard%20to%20contribute%20to-dd2e44)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories) This project is used for training PVNet and running PVNet on live data. @@ -85,6 +85,8 @@ OCF maintains a Zarr formatted version of the German Weather Service's (DWD) ICON-EU NWP model here: https://huggingface.co/datasets/openclimatefix/dwd-icon-eu which includes the UK +Please note that the current version of [ICON loader]([url](https://github.com/openclimatefix/ocf_datapipes/blob/9ec252eeee44937c12ab52699579bdcace76e72f/ocf_datapipes/load/nwp/providers/icon.py#L9-L30)) supports a different format. If you want to use our ICON-EU dataset or your own NWP source, you can create a loader for it using [the instructions here]([url](https://github.com/openclimatefix/ocf_datapipes/tree/main/ocf_datapipes/load#nwp)). + **PV**\ OCF maintains a dataset of PV generation from 1311 private PV installations here: https://huggingface.co/datasets/openclimatefix/uk_pv From 27204b91b433927aa8a79ba15bd7b08303cfbeaf Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:56:37 +0100 Subject: [PATCH 09/27] Update badge style --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 26e8542f..051e9a2e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # PVNet 2.1 - [![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml) [![hard to contribute to](https://img.shields.io/badge/hard%20to%20contribute%20to-dd2e44)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories) + [![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml) [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories) + This project is used for training PVNet and running PVNet on live data. 
From 280dc2c7e50cdb7165e21903072b16795e28cf12 Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Thu, 12 Sep 2024 18:54:08 +0100 Subject: [PATCH 10/27] save validation batch results to wandb (#252) * save validation batch results to wandb * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix validation df * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tidy up * at print statment * try and except around odd error * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * PR comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix and add comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update for quantile loss * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * save all quantile results * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * PR comment --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pvnet/models/base_model.py | 72 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 9faff47d..e6e0e907 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -2,6 +2,7 @@ import json import logging import os +import tempfile from pathlib import Path from typing import Dict, Optional, Union @@ -410,6 +411,9 @@ def __init__( else: self.num_output_features = self.forecast_len + # save all validation results to array, so we can save these to weights n biases + self.validation_epoch_results = [] + def _quantiles_to_prediction(self, y_quantiles): """ Convert network prediction into a point prediction. @@ -609,12 +613,61 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p print(e) plt.close(fig) + def _log_validation_results(self, batch, y_hat, accum_batch_num): + """Append validation results to self.validation_epoch_results""" + + # get truth values, shape (b, forecast_len) + y = batch[self._target_key][:, -self.forecast_len :, 0] + y = y.detach().cpu().numpy() + batch_size = y.shape[0] + + # get prediction values, shape (b, forecast_len, quantiles?) 
+ y_hat = y_hat.detach().cpu().numpy() + + # get time_utc, shape (b, forecast_len) + time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] + time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy() + + # get target id and change from (b,1) to (b,) + id_key = BatchKey[f"{self._target_key_name}_id"] + target_id = batch[id_key].detach().cpu().numpy() + target_id = target_id.squeeze() + + for i in range(batch_size): + y_i = y[i] + y_hat_i = y_hat[i] + time_utc_i = time_utc[i] + target_id_i = target_id[i] + + results_dict = { + "y": y_i, + "time_utc": time_utc_i, + } + if self.use_quantile_regression: + results_dict.update( + {f"y_quantile_{q}": y_hat_i[:, i] for i, q in enumerate(self.output_quantiles)} + ) + else: + results_dict["y_hat"] = y_hat_i + + results_df = pd.DataFrame(results_dict) + results_df["id"] = target_id_i + results_df["batch_idx"] = accum_batch_num + results_df["example_idx"] = i + + self.validation_epoch_results.append(results_df) + def validation_step(self, batch: dict, batch_idx): """Run validation step""" + + accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches + y_hat = self(batch) # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] + self._log_validation_results(batch, y_hat, accum_batch_num) + # Expand persistence to be the same shape as y losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) @@ -632,8 +685,6 @@ def validation_step(self, batch: dict, batch_idx): on_epoch=True, ) - accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches - # Make plots only if using wandb logger if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]: # Store these temporarily under self @@ -675,6 +726,23 @@ def validation_step(self, batch: dict, batch_idx): def on_validation_epoch_end(self): """Run on epoch end""" + try: + # join together validation results, and save to wandb + validation_results_df = pd.concat(self.validation_epoch_results) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, f"validation_results_{self.current_epoch}.csv") + validation_results_df.to_csv(filename, index=False) + + # make and log wand artifact + validation_artifact = wandb.Artifact( + f"validation_results_epoch={self.current_epoch}", type="dataset" + ) + validation_artifact.add_file(filename) + wandb.log_artifact(validation_artifact) + except Exception as e: + print("Failed to log validation results to wandb") + print(e) + horizon_maes_dict = self._horizon_maes.flush() # Create the horizon accuracy curve From 56a39a709ee640f0cabb49671ff8d189403a7bef Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Thu, 12 Sep 2024 17:54:38 +0000 Subject: [PATCH 11/27] =?UTF-8?q?Bump=20version:=203.0.55=20=E2=86=92=203.?= =?UTF-8?q?0.56=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 9779516d..1d0343e0 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.55 +current_version = 3.0.56 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index f84c422d..e1e6cfa0 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ 
-1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.55" +__version__ = "3.0.56" From 4795a9965b455503d7475bcbc2f2fcba91f14757 Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Fri, 20 Sep 2024 16:36:50 +0100 Subject: [PATCH 12/27] Fix wandb artifact saving in base_model.py --- pvnet/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index e6e0e907..8665bf3c 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -735,7 +735,7 @@ def on_validation_epoch_end(self): # make and log wand artifact validation_artifact = wandb.Artifact( - f"validation_results_epoch={self.current_epoch}", type="dataset" + f"validation_results_epoch_{self.current_epoch}", type="dataset" ) validation_artifact.add_file(filename) wandb.log_artifact(validation_artifact) From b5fa2d95acb0f0655d72cd5a753bc18f6cc8333c Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Mon, 23 Sep 2024 13:51:04 +0000 Subject: [PATCH 13/27] =?UTF-8?q?Bump=20version:=203.0.56=20=E2=86=92=203.?= =?UTF-8?q?0.57=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 1d0343e0..19cd1fc3 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.56 +current_version = 3.0.57 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index e1e6cfa0..34aca23b 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ -1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.56" +__version__ = "3.0.57" From c93996cdf26d892b7186c6a17ae5493f1a5bc59c Mon Sep 17 00:00:00 2001 From: James Fulton <41546094+dfulu@users.noreply.github.com> Date: Mon, 7 Oct 2024 11:33:34 +0100 Subject: [PATCH 14/27] Use concurrent batch pipeline for ~30x speed up (#236) * bug fix * use concurrent datapipe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update comment * update comment * save as tensor --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- scripts/backtest_uk_gsp.py | 73 ++++++++---------------------- scripts/save_concurrent_batches.py | 71 ++++------------------------- 2 files changed, 28 insertions(+), 116 deletions(-) diff --git a/scripts/backtest_uk_gsp.py b/scripts/backtest_uk_gsp.py index 33d7eb62..e002af3f 100644 --- a/scripts/backtest_uk_gsp.py +++ b/scripts/backtest_uk_gsp.py @@ -38,15 +38,11 @@ NumpyBatch, batch_to_tensor, copy_batch_to_device, - stack_np_examples_into_batch, ) from ocf_datapipes.config.load import load_yaml_configuration from ocf_datapipes.load import OpenGSP -from ocf_datapipes.training.common import create_t0_and_loc_datapipes -from ocf_datapipes.training.pvnet import ( - _get_datapipes_dict, - construct_sliced_data_pipeline, -) +from ocf_datapipes.training.common import _get_datapipes_dict +from ocf_datapipes.training.pvnet_all_gsp import construct_sliced_data_pipeline, create_t0_datapipe from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD from omegaconf import DictConfig @@ -58,20 
+54,19 @@ from tqdm import tqdm from pvnet.load_model import get_model_from_checkpoints -from pvnet.utils import GSPLocationLookup # ------------------------------------------------------------------ # USER CONFIGURED VARIABLES -output_dir = "/mnt/disks/backtest/test_backtest" +output_dir = "/mnt/disks/extra_batches/test_backtest" # Local directory to load the PVNet checkpoint from. By default this should pull the best performing # checkpoint on the val set -model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/kqaknmuc" +model_chckpoint_dir = "/home/jamesfulton/repos/PVNet/checkpoints/q911tei5" # Local directory to load the summation model checkpoint from. By default this should pull the best # performing checkpoint on the val set. If set to None a simple sum is used instead summation_chckpoint_dir = ( - "/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/nw673nw2" + "/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/73oa4w9t" ) # Forecasts will be made for all available init times between these @@ -144,7 +139,7 @@ def get_available_t0_times(start_datetime, end_datetime, config_path): # Pop out the config file config = datapipes_dict.pop("config") - # We are going to abuse the `create_t0_and_loc_datapipes()` function to find the init-times in + # We are going to abuse the `create_t0_datapipe()` function to find the init-times in # potential_init_times which we have input data for. To do this, we will feed in some fake GSP # data which has the potential_init_times as timestamps. This is a bit hacky but works for now @@ -172,18 +167,15 @@ def get_available_t0_times(start_datetime, end_datetime, config_path): # Overwrite the GSP data which is already in the datapipes dict datapipes_dict["gsp"] = IterableWrapper([ds_fake_gsp]) - # Use create_t0_and_loc_datapipes to get datapipe of init-times - location_pipe, t0_datapipe = create_t0_and_loc_datapipes( + # Use create_t0_datapipe to get datapipe of init-times + t0_datapipe = create_t0_datapipe( datapipes_dict, configuration=config, - key_for_t0="gsp", shuffle=False, ) - # Create a full list of available init-times. 
Note that we need to loop over the t0s AND - # locations to avoid the torch datapipes buffer overflow but we don't actually use the location - available_init_times = [t0 for _, t0 in zip(location_pipe, t0_datapipe)] - available_init_times = pd.to_datetime(available_init_times) + # Create a full list of available init-times + available_init_times = pd.to_datetime([t0 for t0 in t0_datapipe]) logger.info( f"{len(available_init_times)} out of {len(potential_init_times)} " @@ -193,22 +185,16 @@ def get_available_t0_times(start_datetime, end_datetime, config_path): return available_init_times -def get_loctimes_datapipes(config_path): - """Create location and init-time datapipes +def get_times_datapipe(config_path): + """Create init-time datapipe Args: config_path: Path to data config file Returns: - tuple: A tuple of datapipes - - Datapipe yielding locations - - Datapipe yielding init-times + Datapipe: A Datapipe yielding init-times """ - # Set up ID location query object - ds_gsp = get_gsp_ds(config_path) - gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb) - # Filter the init-times to times we have all input data for available_target_times = get_available_t0_times( start_datetime, @@ -222,25 +208,13 @@ def get_loctimes_datapipes(config_path): # the backtest will end up producing available_target_times.to_frame().to_csv(f"{output_dir}/t0_times.csv") - # Cycle the GSP locations - location_pipe = IterableWrapper([[gsp_id_to_loc(gsp_id) for gsp_id in ALL_GSP_IDS]]).repeat( - num_t0s - ) - - # Shard and then unbatch the locations so that each worker will generate all samples for all - # GSPs and for a single init-time - location_pipe = location_pipe.sharding_filter() - location_pipe = location_pipe.unbatch(unbatch_level=1) - # Create times datapipe so each worker receives 317 copies of the same datetime for its batch - t0_datapipe = IterableWrapper([[t0 for gsp_id in ALL_GSP_IDS] for t0 in available_target_times]) + t0_datapipe = IterableWrapper(available_target_times) t0_datapipe = t0_datapipe.sharding_filter() - t0_datapipe = t0_datapipe.unbatch(unbatch_level=1) - t0_datapipe = t0_datapipe.set_length(num_t0s * len(ALL_GSP_IDS)) - location_pipe = location_pipe.set_length(num_t0s * len(ALL_GSP_IDS)) + t0_datapipe = t0_datapipe.set_length(num_t0s) - return location_pipe, t0_datapipe + return t0_datapipe class ModelPipe: @@ -375,25 +349,16 @@ def get_datapipe(config_path: str) -> NumpyBatch: """ # Construct location and init-time datapipes - location_pipe, t0_datapipe = get_loctimes_datapipes(config_path) - - # Get the number of init-times - num_batches = len(t0_datapipe) // len(ALL_GSP_IDS) + t0_datapipe = get_times_datapipe(config_path) # Construct sample datapipes data_pipeline = construct_sliced_data_pipeline( config_path, - location_pipe, t0_datapipe, ) - # Batch so that each worker returns a batch of all locations for a single init-time - # Also convert to tensor for model - data_pipeline = ( - data_pipeline.batch(len(ALL_GSP_IDS)).map(stack_np_examples_into_batch).map(batch_to_tensor) - ) - - data_pipeline = data_pipeline.set_length(num_batches) + # Convert to tensor for model + data_pipeline = data_pipeline.map(batch_to_tensor).set_length(len(t0_datapipe)) return data_pipeline diff --git a/scripts/save_concurrent_batches.py b/scripts/save_concurrent_batches.py index f421887e..37833b9e 100644 --- a/scripts/save_concurrent_batches.py +++ b/scripts/save_concurrent_batches.py @@ -32,19 +32,17 @@ import hydra import numpy as np import torch -from ocf_datapipes.batch import BatchKey, 
batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.common import ( - open_and_return_datapipes, +from ocf_datapipes.batch import BatchKey, batch_to_tensor +from ocf_datapipes.training.pvnet_all_gsp import ( + construct_sliced_data_pipeline, + construct_time_pipeline, ) -from ocf_datapipes.training.pvnet import construct_loctime_pipelines, construct_sliced_data_pipeline from omegaconf import DictConfig, OmegaConf from sqlalchemy import exc as sa_exc from torch.utils.data import DataLoader from torch.utils.data.datapipes.iter import IterableWrapper from tqdm import tqdm -from pvnet.utils import GSPLocationLookup - warnings.filterwarnings("ignore", category=sa_exc.SAWarning) logger = logging.getLogger(__name__) @@ -61,73 +59,22 @@ def __call__(self, input): torch.save(batch, f"{self.batch_dir}/{i:06}.pt") -def select_first(x): - """Select zeroth element from indexable object""" - return x[0] - - -def _get_loctimes_datapipes(config_path, start_time, end_time, n_batches): - # Set up ID location query object - ds_gsp = next( - iter( - open_and_return_datapipes( - config_path, - use_gsp=True, - use_nwp=False, - use_pv=False, - use_sat=False, - use_hrv=False, - use_topo=False, - )["gsp"] - ) - ) - gsp_id_to_loc = GSPLocationLookup(ds_gsp.x_osgb, ds_gsp.y_osgb) - - # Cycle the GSP locations - location_pipe = IterableWrapper([[gsp_id_to_loc(gsp_id) for gsp_id in range(1, 318)]]).repeat( - n_batches - ) - - # Shard and unbatch so each worker goes through GSP 1-317 for each batch - location_pipe = location_pipe.sharding_filter() - location_pipe = location_pipe.unbatch(unbatch_level=1) - - # These two datapipes come from an earlier fork and must be iterated through together - # despite the fact that we don't want these random locations here - random_location_datapipe, t0_datapipe = construct_loctime_pipelines( +def _get_datapipe(config_path, start_time, end_time, n_batches): + t0_datapipe = construct_time_pipeline( config_path, start_time, end_time, ) - # Iterate through both but select only time - t0_datapipe = t0_datapipe.zip(random_location_datapipe).map(select_first) - - # Create times datapipe so we'll get the same time over each batch t0_datapipe = t0_datapipe.header(n_batches) - t0_datapipe = IterableWrapper([[t0 for gsp_id in range(1, 318)] for t0 in t0_datapipe]) t0_datapipe = t0_datapipe.sharding_filter() - t0_datapipe = t0_datapipe.unbatch(unbatch_level=1) - return location_pipe, t0_datapipe - - -def _get_datapipe(config_path, start_time, end_time, n_batches): - # Open datasets from the config and filter to useable location-time pairs - - location_pipe, t0_datapipe = _get_loctimes_datapipes( - config_path, start_time, end_time, n_batches - ) - - data_pipeline = construct_sliced_data_pipeline( + datapipe = construct_sliced_data_pipeline( config_path, - location_pipe, t0_datapipe, - ) - - data_pipeline = data_pipeline.batch(317).map(stack_np_examples_into_batch).map(batch_to_tensor) + ).map(batch_to_tensor) - return data_pipeline + return datapipe def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs): From dfce50ac954af1e758edb0dfbdd67c50083ab100 Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Mon, 7 Oct 2024 10:34:07 +0000 Subject: [PATCH 15/27] =?UTF-8?q?Bump=20version:=203.0.57=20=E2=86=92=203.?= =?UTF-8?q?0.58=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff 
--git a/.bumpversion.cfg b/.bumpversion.cfg index 19cd1fc3..33417b49 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.57 +current_version = 3.0.58 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index 34aca23b..798ba127 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ -1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.57" +__version__ = "3.0.58" From 4b5bf480ceaeb0f159deaeee5283ac5f41020a56 Mon Sep 17 00:00:00 2001 From: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com> Date: Mon, 7 Oct 2024 18:16:49 +0100 Subject: [PATCH 16/27] Clean up site backtest script (#258) --- scripts/backtest_sites.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/scripts/backtest_sites.py b/scripts/backtest_sites.py index 11f5b898..e764abf8 100644 --- a/scripts/backtest_sites.py +++ b/scripts/backtest_sites.py @@ -58,19 +58,15 @@ from pvnet.utils import SiteLocationLookup # ------------------------------------------------------------------ -# USER CONFIGURED VARIABLES +# USER CONFIGURED VARIABLES TO RUN THE SCRIPT + +# Directory path to save results output_dir = "PLACEHOLDER" # Local directory to load the PVNet checkpoint from. By default this should pull the best performing # checkpoint on the val set model_chckpoint_dir = "PLACEHOLDER" -# Local directory to load the summation model checkpoint from. By default this should pull the best -# performing checkpoint on the val set. If set to None a simple sum is used instead -# summation_chckpoint_dir = ( -# "/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/nw673nw2" -# ) - # Forecasts will be made for all available init times between these start_datetime = "2022-05-08 00:00" end_datetime = "2022-05-08 00:30" @@ -96,8 +92,10 @@ # When sun as elevation below this, the forecast is set to zero MIN_DAY_ELEVATION = 0 -# All pv system ids to produce forecasts for +# Add all pv site ids here that you wish to produce forecasts for ALL_SITE_IDS = [] +# Need to be in ascending order +ALL_SITE_IDS.sort() # ------------------------------------------------------------------ # FUNCTIONS @@ -255,7 +253,8 @@ def get_loctimes_datapipes(config_path): unbatch_level=1 ) # might not need this part since the site datapipe is creating examples - # Create times datapipe so each worker receives 317 copies of the same datetime for its batch + # Create times datapipe so each worker receives + # len(ALL_SITE_IDS) copies of the same datetime for its batch t0_datapipe = IterableWrapper( [[t0 for site_id in ALL_SITE_IDS] for t0 in available_target_times] ) @@ -305,7 +304,7 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset: ) # Get effective capacities for this forecast - # site_capacities = ds_site.nominal_capacity_wp.values + site_capacities = self.ds_site.nominal_capacity_wp.values # Get the solar elevations. 
We need to un-normalise these from the values in the batch elevation = batch[BatchKey.pv_solar_elevation] * ELEVATION_STD + ELEVATION_MEAN # We only need elevation mask for forecasted values, not history @@ -327,18 +326,17 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset: y_normed_site = model(device_batch).detach().cpu().numpy() da_normed_site = preds_to_dataarray(y_normed_site, model, valid_times, ALL_SITE_IDS) - # TODO fix this step: Multiply normalised forecasts by capacities and clip negatives - # For now output normalised by capacity outputs and unnormalise in post processing - # da_abs_site = da_normed_site.clip(0, None) * site_capacities[:, None, None] - da_normed_site = da_normed_site.clip(0, None) + # Multiply normalised forecasts by capacities and clip negatives + da_abs_site = da_normed_site.clip(0, None) * site_capacities[:, None, None] + # Apply sundown mask - da_normed_site = da_normed_site.where(~da_sundown_mask).fillna(0.0) + da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0) - da_normed_site = da_normed_site.expand_dims(dim="init_time_utc", axis=0).assign_coords( + da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords( init_time_utc=[t0] ) - return da_normed_site + return da_abs_site def get_datapipe(config_path: str) -> NumpyBatch: From aa76d4a748fc94f429daa2c5cf0608c82c447038 Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Mon, 7 Oct 2024 17:17:21 +0000 Subject: [PATCH 17/27] =?UTF-8?q?Bump=20version:=203.0.58=20=E2=86=92=203.?= =?UTF-8?q?0.59=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 33417b49..bc58fbc9 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.58 +current_version = 3.0.59 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index 798ba127..f9b4efb2 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ -1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.58" +__version__ = "3.0.59" From eb8b445614015d33ff055db165625f6ada829a34 Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:12:17 +0100 Subject: [PATCH 18/27] add pv padding and hf model handling to backtest_sites.py (#262) * added pv padding and hf model handling to backtest_sites.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com> * Update scripts/backtest_sites.py * Update pyproject.toml * undo Update pyproject.toml * linting * docstring scripts/backtest_sites.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docstring scripts/backtest_sites.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com> --- scripts/backtest_sites.py | 89 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 6 deletions(-) diff --git 
a/scripts/backtest_sites.py b/scripts/backtest_sites.py index e764abf8..3572daa3 100644 --- a/scripts/backtest_sites.py +++ b/scripts/backtest_sites.py @@ -23,6 +23,7 @@ except RuntimeError: pass +import json import logging import os import sys @@ -32,6 +33,8 @@ import pandas as pd import torch import xarray as xr +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from ocf_datapipes.batch import ( BatchKey, NumpyBatch, @@ -50,7 +53,7 @@ ) from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD from omegaconf import DictConfig -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe from torch.utils.data.datapipes.iter import IterableWrapper from tqdm import tqdm @@ -67,6 +70,10 @@ # checkpoint on the val set model_chckpoint_dir = "PLACEHOLDER" +hf_revision = None +hf_token = None +hf_model_id = None + # Forecasts will be made for all available init times between these start_datetime = "2022-05-08 00:00" end_datetime = "2022-05-08 00:30" @@ -101,11 +108,70 @@ # FUNCTIONS +@functional_datapipe("pad_forward_pv") +class PadForwardPVIterDataPipe(IterDataPipe): + """ + Pads forecast pv. + + Sun position is calculated based off of pv time index + and for t0's close to end of pv data can have wrong shape as pv starts + to run out of data to slice for the forecast part. + """ + + def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64): + """Init""" + + super().__init__() + self.pv_dp = pv_dp + self.forecast_duration = forecast_duration + + def __iter__(self): + """Iter""" + + for xr_data in self.pv_dp: + t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])] + pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"]) + t_end = t0 + self.forecast_duration + pv_step + time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step) + yield xr_data.reindex(time_utc=time_idx, fill_value=-1) + + +def load_model_from_hf(model_id: str, revision: str, token: str): + """ + Loads model from HuggingFace + """ + + model_file = hf_hub_download( + repo_id=model_id, + filename=PYTORCH_WEIGHTS_NAME, + revision=revision, + token=token, + ) + + # load config file + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + token=token, + ) + + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + + model = hydra.utils.instantiate(config) + + state_dict = torch.load(model_file, map_location=torch.device("cuda")) + model.load_state_dict(state_dict) # type: ignore + model.eval() # type: ignore + + return model + + def preds_to_dataarray(preds, model, valid_times, site_ids): """Put numpy array of predictions into a dataarray""" if model.use_quantile_regression: - output_labels = model.output_quantiles output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles] output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw" else: @@ -333,7 +399,7 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset: da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0) da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords( - init_time_utc=[t0] + init_time_utc=np.array([t0], dtype="datetime64[ns]") ) return da_abs_site @@ -362,6 +428,11 @@ def get_datapipe(config_path: str) -> NumpyBatch: t0_datapipe, ) + config = load_yaml_configuration(config_path) + data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv( + 
forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m") + ) + data_pipeline = DictDatasetIterDataPipe( {k: v for k, v in data_pipeline.items() if k != "config"}, ).map(split_dataset_dict_dp) @@ -412,7 +483,13 @@ def main(config: DictConfig): # Create a dataloader for the concurrent batches and use multiprocessing dataloader = DataLoader(batch_pipe, **dataloader_kwargs) # Load the PVNet model - model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True) + if model_chckpoint_dir: + model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True) + elif hf_model_id: + model = load_model_from_hf(hf_model_id, hf_revision, hf_token) + else: + raise ValueError("Provide a model checkpoint or a HuggingFace model") + model = model.eval().to(device) # Create object to make predictions for each input batch @@ -426,13 +503,13 @@ def main(config: DictConfig): t0 = ds_abs_all.init_time_utc.values[0] - # Save the predictioons + # Save the predictions filename = f"{output_dir}/{t0}.nc" ds_abs_all.to_netcdf(filename) pbar.update() except Exception as e: - print(f"Exception {e} at {i}") + print(f"Exception {e} at batch {i}") pass # Close down From ce664576b4eb912f176de312e489fac1c15d069d Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Fri, 18 Oct 2024 13:12:45 +0000 Subject: [PATCH 19/27] =?UTF-8?q?Bump=20version:=203.0.59=20=E2=86=92=203.?= =?UTF-8?q?0.60=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index bc58fbc9..1749a8bd 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.59 +current_version = 3.0.60 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index f9b4efb2..370f2ded 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ -1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.59" +__version__ = "3.0.60" From fdd468befa981d15282a183488822215aaae8da8 Mon Sep 17 00:00:00 2001 From: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com> Date: Tue, 22 Oct 2024 12:03:15 +0100 Subject: [PATCH 20/27] Remove run app locally script (#264) --- scripts/run_app_locally.py | 104 ------------------------------------- 1 file changed, 104 deletions(-) delete mode 100644 scripts/run_app_locally.py diff --git a/scripts/run_app_locally.py b/scripts/run_app_locally.py deleted file mode 100644 index a78efe5e..00000000 --- a/scripts/run_app_locally.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Script to run the production app locally""" - -import logging -import os -import time -from datetime import timedelta - -import numpy as np -import pandas as pd -import xarray as xr -from ocf_datapipes.load import OpenGSPFromDatabase - -from pvnet.app import app - -formatter = logging.Formatter(fmt="%(levelname)s:%(name)s:%(message)s") -stream_handler = logging.StreamHandler() -stream_handler.setFormatter(formatter) - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) -logger.addHandler(stream_handler) - - -def sleep_until(wake_time): - """Sleep until the given time""" - now = pd.Timestamp.now() - sleep_duration = (wake_time - now).total_seconds() - if sleep_duration < 0: - logger.warning("Sleep for negative duration requested") - else: - logger.info(f"Sleeping for 
{sleep_duration} seconds") - time.sleep(sleep_duration) - - -if __name__ == "__main__": - # ---------------------------------------------------- - # USER SETTINGS - - # When to start and stop predictions - start_time = pd.Timestamp("2023-05-31 00:00") - end_time = pd.Timestamp("2023-06-05 21:00") - - output_dir = "/mnt/disks/batches/local_production_forecasts" - save_inputs = True - - # ---------------------------------------------------- - # RUN - - # Make output dirs - os.makedirs(f"{output_dir}/predictions", exist_ok=True) - os.makedirs(f"{output_dir}/logs", exist_ok=True) - if save_inputs: - os.makedirs(f"{output_dir}/inputs", exist_ok=True) - - # Wait until start time - if pd.Timestamp.now() < start_time: - sleep_until(start_time) - - while pd.Timestamp.now() < end_time: - # Next prediction time - t0 = pd.Timestamp.now().ceil(timedelta(minutes=30)) - - # Sleep until next prediction time - sleep_until(t0) - - try: - # Make predictions - df = app(write_predictions=False) - - # Save - df.to_csv(f"{output_dir}/predictions/{t0}.csv") - except Exception: - logger.exception(f"Predictions for {t0=} failed") - - try: - # Log delays of data sources - log = dict( - now=t0, - gsp_times=next(iter(OpenGSPFromDatabase())).time_utc.values, - sat_times=xr.open_zarr("latest.zarr.zip").time.values, - nwp_times=xr.open_zarr(os.environ["NWP_ZARR_PATH"]).init_time.values, - ) - np.save(f"{output_dir}/logs/{t0}.npy", log) - except Exception: - logger.exception(f"Logs for {t0=} failed") - - if save_inputs: - try: - # Set up directory to save inputs - input_dir = f"{output_dir}/inputs/{t0}" - os.makedirs(input_dir, exist_ok=True) - - # Save inputs - os.system(f"cp latest.zarr.zip '{input_dir}/sat.zarr.zip'") - - ds = xr.open_zarr(os.environ["NWP_ZARR_PATH"]) - for v in ds.variables: - ds[v].encoding.clear() - ds.to_zarr(f"{input_dir}/nwp.zarr") - - next(iter(OpenGSPFromDatabase())).to_dataset().to_zarr(f"{input_dir}/gsp.zarr") - - except Exception: - logger.exception(f"Saving inputs for {t0=} failed") From e5f34cab11cb8ab9e26dad2ac50933beb19bd8ef Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Tue, 22 Oct 2024 11:03:45 +0000 Subject: [PATCH 21/27] =?UTF-8?q?Bump=20version:=203.0.60=20=E2=86=92=203.?= =?UTF-8?q?0.61=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 1749a8bd..4e523b63 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.60 +current_version = 3.0.61 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index 370f2ded..a97ea0bb 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ -1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.60" +__version__ = "3.0.61" From 05abb8ba9d4bc652f77a271aebadc349241a8c4a Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:58:00 +0100 Subject: [PATCH 22/27] update backtest_sites.py (#266) * Update backtest_sites.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- scripts/backtest_sites.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 
insertions(+), 6 deletions(-) diff --git a/scripts/backtest_sites.py b/scripts/backtest_sites.py index 3572daa3..65d63be2 100644 --- a/scripts/backtest_sites.py +++ b/scripts/backtest_sites.py @@ -118,21 +118,38 @@ class PadForwardPVIterDataPipe(IterDataPipe): to run out of data to slice for the forecast part. """ - def __init__(self, pv_dp: IterDataPipe, forecast_duration: np.timedelta64): + def __init__( + self, + pv_dp: IterDataPipe, + forecast_duration: np.timedelta64, + history_duration: np.timedelta64, + time_resolution_minutes: np.timedelta64, + ): """Init""" super().__init__() self.pv_dp = pv_dp self.forecast_duration = forecast_duration + self.history_duration = history_duration + self.time_resolution_minutes = time_resolution_minutes + + self.min_seq_length = history_duration // time_resolution_minutes def __iter__(self): """Iter""" for xr_data in self.pv_dp: - t0 = xr_data.time_utc.data[int(xr_data.attrs["t0_idx"])] - pv_step = np.timedelta64(xr_data.attrs["sample_period_duration"]) - t_end = t0 + self.forecast_duration + pv_step - time_idx = np.arange(xr_data.time_utc.data[0], t_end, pv_step) + t_end = ( + xr_data.time_utc.data[0] + + self.history_duration + + self.forecast_duration + + self.time_resolution_minutes + ) + time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes) + + if len(xr_data.time_utc.data) < self.min_seq_length: + raise ValueError("Not enough PV data to predict") + yield xr_data.reindex(time_utc=time_idx, fill_value=-1) @@ -430,7 +447,9 @@ def get_datapipe(config_path: str) -> NumpyBatch: config = load_yaml_configuration(config_path) data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv( - forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m") + forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m"), + history_duration=np.timedelta64(config.input_data.pv.history_minutes, "m"), + time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, "m"), ) data_pipeline = DictDatasetIterDataPipe( From c259a1cbb7c2eeb8570914f95d2cac5667024424 Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Wed, 23 Oct 2024 13:58:39 +0000 Subject: [PATCH 23/27] =?UTF-8?q?Bump=20version:=203.0.61=20=E2=86=92=203.?= =?UTF-8?q?0.62=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 4e523b63..0fc3026e 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.61 +current_version = 3.0.62 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index a97ea0bb..1d157cca 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ -1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.61" +__version__ = "3.0.62" From 5360394e3c46bea4e9cacc5d1c5e34d63db6408e Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Fri, 25 Oct 2024 13:50:21 +0100 Subject: [PATCH 24/27] add flushing of val epoch resluts (#256) * add flushing of val epoch resluts * add docstring back base_model.py * log to csv only when end of accumbatch --- pvnet/models/base_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 
8665bf3c..83e67e7b 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -666,7 +666,8 @@ def validation_step(self, batch: dict, batch_idx): # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] - self._log_validation_results(batch, y_hat, accum_batch_num) + if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: + self._log_validation_results(batch, y_hat, accum_batch_num) # Expand persistence to be the same shape as y losses = self._calculate_common_losses(y, y_hat) @@ -743,6 +744,7 @@ def on_validation_epoch_end(self): print("Failed to log validation results to wandb") print(e) + self.validation_epoch_results = [] horizon_maes_dict = self._horizon_maes.flush() # Create the horizon accuracy curve From 9497430a902045b4ad52db0fee5499fb31f0cf77 Mon Sep 17 00:00:00 2001 From: BumpVersion Action Date: Fri, 25 Oct 2024 12:50:53 +0000 Subject: [PATCH 25/27] =?UTF-8?q?Bump=20version:=203.0.62=20=E2=86=92=203.?= =?UTF-8?q?0.63=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pvnet/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 0fc3026e..5d3bbe27 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,7 +1,7 @@ [bumpversion] commit = True tag = True -current_version = 3.0.62 +current_version = 3.0.63 message = Bump version: {current_version} → {new_version} [skip ci] [bumpversion:file:pvnet/__init__.py] diff --git a/pvnet/__init__.py b/pvnet/__init__.py index 1d157cca..3b41ad55 100644 --- a/pvnet/__init__.py +++ b/pvnet/__init__.py @@ -1,2 +1,2 @@ """PVNet""" -__version__ = "3.0.62" +__version__ = "3.0.63" From acb6a36736be9b1094c51ddc3aa7cbc4844ab9e9 Mon Sep 17 00:00:00 2001 From: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com> Date: Fri, 8 Nov 2024 11:58:33 +0000 Subject: [PATCH 26/27] Update MAE analysis script (#274) Update script --- experiments/{analysis.py => mae_analysis.py} | 42 +++++++++++++------- 1 file changed, 28 insertions(+), 14 deletions(-) rename experiments/{analysis.py => mae_analysis.py} (74%) diff --git a/experiments/analysis.py b/experiments/mae_analysis.py similarity index 74% rename from experiments/analysis.py rename to experiments/mae_analysis.py index bb119664..ac01aed2 100644 --- a/experiments/analysis.py +++ b/experiments/mae_analysis.py @@ -1,5 +1,8 @@ """ -Script to generate a table comparing two run for MAE values for 48 hour 15 minute forecast +Script to generate analysis of MAE values for multiple model forecasts + +Does this for 48 hour horizon forecasts with 15 minute granularity + """ import argparse @@ -10,15 +13,21 @@ import wandb -def main(runs: list[str], run_names: list[str]) -> None: +def main(project: str, runs: list[str], run_names: list[str]) -> None: """ - Compare two runs for MAE values for 48 hour 15 minute forecast + Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity + + Args: + project: name of W&B project + runs: W&B ids of runs + run_names: user specified names for runs + """ api = wandb.Api() dfs = [] epoch_num = [] for run in runs: - run = api.run(f"openclimatefix/PROJECT/{run}") + run = api.run(f"openclimatefix/{project}/{run}") df = run.history(samples=run.lastHistoryStep + 1) # Get the columns that are in the format 'MAE_horizon/step_/val` @@ -88,10 +97,12 @@ def main(runs: list[str], run_names: list[str]) -> None: for idx, df in enumerate(dfs): 
print(f"{run_names[idx]}: {df.mean()*100:0.3f}") - # Plot the error on per timestep, and all timesteps + # Plot the error per timestep plt.figure() for idx, df in enumerate(dfs): - plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}") + plt.plot( + column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-" + ) plt.legend() plt.xlabel("Timestep (minutes)") plt.ylabel("MAE %") @@ -99,25 +110,28 @@ def main(runs: list[str], run_names: list[str]) -> None: plt.savefig("mae_per_timestep.png") plt.show() - # Plot the error on per timestep, and grouped timesteps + # Plot the error per grouped timestep plt.figure() for idx, run_name in enumerate(run_names): - plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}") + plt.plot( + groups_df[run_name], + label=f"{run_name}, epoch: {epoch_num[idx]}", + marker="o", + linestyle="-", + ) plt.legend() plt.xlabel("Timestep (minutes)") plt.ylabel("MAE %") - plt.title("MAE % for each timestep") - plt.savefig("mae_per_timestep.png") + plt.title("MAE % for each grouped timestep") + plt.savefig("mae_per_grouped_timestep.png") plt.show() if __name__ == "__main__": parser = argparse.ArgumentParser() - "5llq8iw6" - parser.add_argument("--first_run", type=str, default="xdlew7ib") - parser.add_argument("--second_run", type=str, default="v3mja33d") + parser.add_argument("--project", type=str, default="") # Add arguments that is a list of strings parser.add_argument("--list_of_runs", nargs="+") parser.add_argument("--run_names", nargs="+") args = parser.parse_args() - main(args.list_of_runs, args.run_names) + main(args.project, args.list_of_runs, args.run_names) From db81147207e39fb6bfa7c8c0a06f53dacf9a2339 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 10:11:12 +0000 Subject: [PATCH 27/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/data/datamodule.py | 38 +++++++------------ pvnet/models/base_model.py | 20 +++------- pvnet/models/utils.py | 2 - pvnet/utils.py | 4 +- scripts/save_samples.py | 78 +++++++++++++++++++------------------- 5 files changed, 59 insertions(+), 83 deletions(-) diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index b0df80ce..a834efbe 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -1,15 +1,11 @@ """ Data module for pytorch lightning """ -from datetime import datetime from glob import glob -from lightning.pytorch import LightningDataModule -from torch.utils.data import Dataset, DataLoader import torch - -from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch, NumpyBatch -from ocf_data_sampler.torch_datasets.pvnet_uk_regional import ( - PVNetUKRegionalDataset -) +from lightning.pytorch import LightningDataModule +from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset +from ocf_datapipes.batch import NumpyBatch, batch_to_tensor, stack_np_examples_into_batch +from torch.utils.data import DataLoader, Dataset def fill_nans_in_arrays(batch): @@ -29,30 +25,28 @@ def fill_nans_in_arrays(batch): return batch - class NumpybatchPremadeSamplesDataset(Dataset): """Dataset to load NumpyBatch samples""" - + def __init__(self, sample_dir): """Dataset to load NumpyBatch samples - + Args: sample_dir: Path to the directory of pre-saved samples. 
""" self.sample_paths = glob(f"{sample_dir}/*.pt") - - + def __len__(self): return len(self.sample_paths) - + def __getitem__(self, idx): return fill_nans_in_arrays(torch.load(self.sample_paths[idx])) - + def collate_fn(samples: list[NumpyBatch]): """Convert a list of NumpyBatch samples to a tensor batch""" return batch_to_tensor(stack_np_examples_into_batch(samples)) - + class DataModule(LightningDataModule): """Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`.""" @@ -64,9 +58,8 @@ def __init__( batch_size: int = 16, num_workers: int = 0, prefetch_factor: int | None = None, - train_period: list[str|None] = [None, None], - val_period: list[str|None] = [None, None], - + train_period: list[str | None] = [None, None], + val_period: list[str | None] = [None, None], ): """Datamodule for training pvnet architecture. @@ -85,7 +78,6 @@ def __init__( """ super().__init__() - if not ((sample_dir is not None) ^ (configuration is not None)): raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.") @@ -118,7 +110,7 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: def _get_premade_samples_dataset(self, subdir) -> Dataset: split_dir = f"{self.sample_dir}/{subdir}" return NumpybatchPremadeSamplesDataset(split_dir) - + def train_dataloader(self) -> DataLoader: """Construct train dataloader""" if self.sample_dir is not None: @@ -126,7 +118,7 @@ def train_dataloader(self) -> DataLoader: else: dataset = self._get_streamed_samples_dataset(*self.train_period) return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs) - + def val_dataloader(self) -> DataLoader: """Construct val dataloader""" if self.sample_dir is not None: @@ -134,5 +126,3 @@ def val_dataloader(self) -> DataLoader: else: dataset = self._get_streamed_samples_dataset(*self.val_period) return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs) - - \ No newline at end of file diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 4046cf41..57313265 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -18,11 +18,7 @@ from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import HfApi - -from ocf_datapipes.batch import BatchKey -from ocf_datapipes.batch import copy_batch_to_device - -from ocf_ml_metrics.evaluation.evaluation import evaluation +from ocf_datapipes.batch import BatchKey, copy_batch_to_device from pvnet.models.utils import ( BatchAccumulator, @@ -32,8 +28,6 @@ from pvnet.optimizers import AbstractOptimizer from pvnet.utils import plot_batch_forecasts - - DATA_CONFIG_NAME = "data_config.yaml" @@ -239,13 +233,11 @@ def get_data_config( ) return data_config_file - - + def _save_pretrained(self, save_directory: Path) -> None: """Save weights from a Pytorch model to a local directory.""" model_to_save = self.module if hasattr(self, "module") else self # type: ignore torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME) - def save_pretrained( self, @@ -416,14 +408,14 @@ def __init__( self.num_output_features = self.forecast_len * len(self.output_quantiles) else: self.num_output_features = self.forecast_len - + # save all validation results to array, so we can save these to weights n biases self.validation_epoch_results = [] def transfer_batch_to_device(self, batch, device, dataloader_idx): """Method to move custom batches to a given device""" return 
copy_batch_to_device(batch, device) - + def _quantiles_to_prediction(self, y_quantiles): """ Convert network prediction into a point prediction. @@ -465,7 +457,7 @@ def _calculate_quantile_loss(self, y_quantiles, y): errors = y - y_quantiles[..., i] losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1)) losses = 2 * torch.cat(losses, dim=2) - + return losses.mean() def _calculate_common_losses(self, y, y_hat): @@ -659,7 +651,7 @@ def validation_step(self, batch: dict, batch_idx): accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches y_hat = self(batch) - + y = batch[self._target_key][:, -self.forecast_len :] if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index 20b80223..2bbe78f2 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -1,8 +1,6 @@ """Utility functions""" import logging -import math -from typing import Optional import numpy as np import torch diff --git a/pvnet/utils.py b/pvnet/utils.py index d515e819..2713558d 100644 --- a/pvnet/utils.py +++ b/pvnet/utils.py @@ -6,7 +6,6 @@ import lightning.pytorch as pl import matplotlib.pyplot as plt -import numpy as np import pandas as pd import pylab import rich.syntax @@ -16,7 +15,6 @@ from lightning.pytorch.utilities import rank_zero_only from ocf_datapipes.batch import BatchKey from ocf_datapipes.utils import Location -from ocf_datapipes.utils.geospatial import osgb_to_lon_lat from omegaconf import DictConfig, OmegaConf @@ -322,4 +320,4 @@ def _get_numpy(key): plt.suptitle(title) plt.tight_layout() - return fig \ No newline at end of file + return fig diff --git a/scripts/save_samples.py b/scripts/save_samples.py index b047c013..c7f306b7 100644 --- a/scripts/save_samples.py +++ b/scripts/save_samples.py @@ -20,41 +20,38 @@ ``` if wanting to override these values for example """ - + # Ensure this block of code runs only in the main process to avoid issues with worker processes. if __name__ == "__main__": import torch.multiprocessing as mp - - # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be - # compatible with dask's multiprocessing. + + # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be + # compatible with dask's multiprocessing. mp.set_start_method("forkserver") - - # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is - # important because libraries like Zarr may open many files, which can exhaust the file + + # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is + # important because libraries like Zarr may open many files, which can exhaust the file # descriptor limit if too many workers are used. 
- mp.set_sharing_strategy('file_system') + mp.set_sharing_strategy("file_system") +import logging import os -import sys import shutil -import logging +import sys import warnings +import dask import hydra +import torch +from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset from omegaconf import DictConfig, OmegaConf from sqlalchemy import exc as sa_exc +from torch.utils.data import DataLoader, Dataset from tqdm import tqdm -import torch -from torch.utils.data import Dataset, DataLoader - -from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset - from pvnet.utils import print_config -import dask - dask.config.set(scheduler="threads", num_workers=4) @@ -71,6 +68,7 @@ class SaveFuncFactory: """Factory for creating a function to save a sample to disk.""" + def __init__(self, save_dir: str, renewable: str = "pv"): self.save_dir = save_dir self.renewable = renewable @@ -86,22 +84,22 @@ def __call__(self, sample, sample_num: int): def get_dataset(config_path: str, start_time: str, end_time: str, renewable: str = "pv") -> Dataset: """Get the dataset for the given renewable type.""" - if renewable== "pv": - dataset_cls = PVNetUKRegionalDataset + if renewable == "pv": + dataset_cls = PVNetUKRegionalDataset elif renewable in ["wind", "pv_india", "pv_site"]: raise NotImplementedError else: raise ValueError(f"Unknown renewable: {renewable}") - + return dataset_cls(config_path, start_time=start_time, end_time=end_time) def save_samples_with_dataloader( - dataset: Dataset, - save_dir: str, - num_samples: int, - dataloader_kwargs: dict, - renewable: str = "pv" + dataset: Dataset, + save_dir: str, + num_samples: int, + dataloader_kwargs: dict, + renewable: str = "pv", ) -> None: """Save samples from a dataset using a dataloader.""" save_func = SaveFuncFactory(save_dir, renewable=renewable) @@ -124,7 +122,7 @@ def main(config: DictConfig) -> None: # Set up directory os.makedirs(config_dm.sample_output_dir, exist_ok=False) - + # Copy across configs which define the samples into the new sample directory with open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f: f.write(OmegaConf.to_yaml(config_dm)) @@ -141,29 +139,29 @@ def main(config: DictConfig) -> None: batch_sampler=None, num_workers=config_dm.num_workers, collate_fn=None, - pin_memory=False, # Only using CPU to prepare samples so pinning is not beneficial + pin_memory=False, # Only using CPU to prepare samples so pinning is not beneficial drop_last=False, timeout=0, worker_init_fn=None, prefetch_factor=config_dm.prefetch_factor, - persistent_workers=False, # Not needed since we only enter the dataloader loop once + persistent_workers=False, # Not needed since we only enter the dataloader loop once ) if config_dm.num_val_samples > 0: print("----- Saving val samples -----") - + val_output_dir = f"{config_dm.sample_output_dir}/val" - + # Make directory for val samples os.mkdir(val_output_dir) - - # Get the dataset + + # Get the dataset val_dataset = get_dataset( config_dm.configuration, *config_dm.val_period, renewable=config.renewable, ) - + # Save samples save_samples_with_dataloader( dataset=val_dataset, @@ -172,24 +170,24 @@ def main(config: DictConfig) -> None: dataloader_kwargs=dataloader_kwargs, renewable=config.renewable, ) - + del val_dataset if config_dm.num_train_samples > 0: print("----- Saving train samples -----") - + train_output_dir = f"{config_dm.sample_output_dir}/train" - + # Make directory for train samples os.mkdir(train_output_dir) - - # Get the dataset + + # Get 
the dataset train_dataset = get_dataset( config_dm.configuration, *config_dm.train_period, renewable=config.renewable, ) - + # Save samples save_samples_with_dataloader( dataset=train_dataset, @@ -198,7 +196,7 @@ def main(config: DictConfig) -> None: dataloader_kwargs=dataloader_kwargs, renewable=config.renewable, ) - + del train_dataset print("----- Saving complete -----")
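
For reference, a minimal usage sketch of the `main()` signature introduced by the mae_analysis rename in PATCH 26/27, assuming the script is run from the repository root with W&B credentials already configured. The project name, run IDs, and run labels below are illustrative placeholders, not values taken from the patches:

```python
# Hypothetical invocation of experiments/mae_analysis.py (PATCH 26/27).
# Assumes the repo root is on sys.path; all argument values are placeholders.
from experiments.mae_analysis import main

main(
    project="pvnet",                      # W&B project under the openclimatefix entity (placeholder)
    runs=["abc123", "def456"],            # W&B run IDs to compare (placeholders)
    run_names=["baseline", "candidate"],  # labels used in the plot legends (placeholders)
)
```

The same comparison can be launched through the argparse interface added in that patch, e.g. `python experiments/mae_analysis.py --project pvnet --list_of_runs abc123 def456 --run_names baseline candidate`, again with placeholder run IDs.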