From 4fbbf15a6a66db1e2132f91ac1f37994c627b45d Mon Sep 17 00:00:00 2001 From: Sukhil Patel Date: Tue, 17 Dec 2024 16:17:18 +0000 Subject: [PATCH 01/13] Add sample saving for Site Dataset --- README.md | 12 +- .../configuration/example_configuration.yaml | 172 ------------------ .../site_example_configuration.yaml | 70 +++++++ .../datamodule/premade_batches.yaml | 2 +- .../datamodule/streamed_batches.yaml | 10 +- pvnet/data/datamodule.py | 9 +- pvnet/data/pv_site_datamodule.py | 67 ------- pvnet/data/wind_datamodule.py | 62 ------- pyproject.toml | 2 +- scripts/save_samples.py | 6 +- 10 files changed, 91 insertions(+), 321 deletions(-) delete mode 100644 configs.example/datamodule/configuration/example_configuration.yaml create mode 100644 configs.example/datamodule/configuration/site_example_configuration.yaml delete mode 100644 pvnet/data/pv_site_datamodule.py delete mode 100644 pvnet/data/wind_datamodule.py diff --git a/README.md b/README.md index 051e9a2e..69f54c64 100644 --- a/README.md +++ b/README.md @@ -145,20 +145,20 @@ This is also where you can update the train, val & test periods to cover the dat ### Running the batch creation script -Run the `save_batches.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example): +Run the `save_samples.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example): ```bash -python scripts/save_batches.py +python scripts/save_samples.py ``` PVNet uses [hydra](https://hydra.cc/) which enables us to pass variables via the command line that will override the configuration defined in the `./configs` directory, like this: ```bash -python scripts/save_batches.py datamodule=streamed_batches datamodule.batch_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5 +python scripts/save_samples.py datamodule=streamed_batches datamodule.sample_output_dir="./output" 
datamodule.num_train_samples=10 datamodule.num_val_samples=5 ``` -`scripts/save_batches.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder. +`scripts/save_samples.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder. If downloading private data from a GCP bucket make sure to authenticate gcloud (the public satellite data does not need authentication): @@ -197,7 +197,7 @@ Make sure to update the following config files before training your model: 2. In `configs/model/local_multimodal.yaml`: - update the list of encoders to reflect the data sources you are using. If you are using different NWP sources, the encoders for these should follow the same structure with two important updates: - `in_channels`: number of variables your NWP source supplies - - `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `nwp_image_size_pixels_height` and/or `nwp_image_size_pixels_width` in `datamodule/example_configs.yaml`, unless transformations such as coarsening was applied (e. g. as for ECMWF data) + - `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `image_size_pixels_height` and/or `image_size_pixels_width` in `datamodule/configuration/site_example_configuration.yaml` for the NWP, unless transformations such as coarsening were applied (e. g. as for ECMWF data) 3. 
In `configs/local_trainer.yaml`: - set `accelerator: 0` if running on a system without a supported GPU @@ -216,7 +216,7 @@ defaults: - hydra: default.yaml ``` -Assuming you ran the `save_batches.py` script to generate some premade train and +Assuming you ran the `save_samples.py` script to generate some premade train and val data batches, you can now train PVNet by running: ``` diff --git a/configs.example/datamodule/configuration/example_configuration.yaml b/configs.example/datamodule/configuration/example_configuration.yaml deleted file mode 100644 index 827b0e9d..00000000 --- a/configs.example/datamodule/configuration/example_configuration.yaml +++ /dev/null @@ -1,172 +0,0 @@ -general: - description: Example data config for creating PVNet batches - name: example_pvnet - -input_data: - default_history_minutes: 120 - default_forecast_minutes: 480 - - gsp: - # Path to the GSP data. This should be a zarr file - # e.g. gs://solar-pv-nowcasting-data/PV/GSP/v7/pv_gsp.zarr - gsp_zarr_path: PLACEHOLDER.zarr - history_minutes: 120 - forecast_minutes: 480 - time_resolution_minutes: 30 - # A random value from the list below will be chosen as the delay when dropout is used - # If set to null no dropout is applied. Only values before t0 are dropped out for GSP. - # Values after t0 are assumed as targets and cannot be dropped. - dropout_timedeltas_minutes: null - dropout_fraction: 0 # Fraction of samples with dropout - - pv: - pv_files_groups: - - label: solar_sheffield_passiv - # Path to the site-level PV data. This should be a netcdf - # e.g gs://solar-pv-nowcasting-data/PV/Passive/ocf_formatted/v0/passiv.netcdf - pv_filename: PLACEHOLDER.netcdf - # Path to the site-level PV metadata. 
This choudl be a csv - # e.g gs://solar-pv-nowcasting-data/PV/Passive/ocf_formatted/v0/system_metadata.csv - pv_metadata_filename: PLACEHOLDER.csv - # This is the list of pv_ml_ids to be sliced from the PV site level data - # The IDs below are 349 of the PV systems which have very little NaN data in the historic data - # and which are still reporting live (as of Oct 2023) - pv_ml_ids: - [ - 154, 155, 156, 158, 159, 160, 162, 164, 165, 166, 167, 168, 169, 171, 173, 177, 178, 179, - 181, 182, 185, 186, 187, 188, 189, 190, 191, 192, 193, 197, 198, 199, 200, 202, 204, 205, - 206, 208, 209, 211, 214, 215, 216, 217, 218, 219, 220, 221, 225, 229, 230, 232, 233, 234, - 236, 242, 243, 245, 252, 254, 255, 256, 257, 258, 260, 261, 262, 265, 267, 268, 272, 273, - 275, 276, 277, 280, 281, 282, 283, 287, 289, 291, 292, 293, 294, 295, 296, 297, 298, 301, - 302, 303, 304, 306, 307, 309, 310, 311, 317, 318, 319, 320, 321, 322, 323, 325, 326, 329, - 332, 333, 335, 336, 338, 340, 342, 344, 345, 346, 348, 349, 352, 354, 355, 356, 357, 360, - 362, 363, 368, 369, 370, 371, 372, 374, 375, 376, 378, 380, 382, 384, 385, 388, 390, 391, - 393, 396, 397, 398, 399, 400, 401, 403, 404, 405, 406, 407, 409, 411, 412, 413, 414, 415, - 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 429, 431, 435, 437, 438, 440, - 441, 444, 447, 450, 451, 453, 456, 457, 458, 459, 464, 465, 466, 467, 468, 470, 471, 473, - 474, 476, 477, 479, 480, 481, 482, 485, 486, 488, 490, 491, 492, 493, 496, 498, 501, 503, - 506, 507, 508, 509, 510, 511, 512, 513, 515, 516, 517, 519, 520, 521, 522, 524, 526, 527, - 528, 531, 532, 536, 537, 538, 540, 541, 542, 543, 544, 545, 549, 550, 551, 552, 553, 554, - 556, 557, 560, 561, 563, 566, 568, 571, 572, 575, 576, 577, 579, 580, 581, 582, 584, 585, - 588, 590, 594, 595, 597, 600, 602, 603, 604, 606, 611, 613, 614, 616, 618, 620, 622, 623, - 624, 625, 626, 628, 629, 630, 631, 636, 637, 638, 640, 641, 642, 644, 645, 646, 650, 651, - 652, 653, 654, 655, 657, 660, 661, 662, 
663, 666, 667, 668, 670, 675, 676, 679, 681, 683, - 684, 685, 687, 696, 698, 701, 702, 703, 704, 706, 710, 722, 723, 724, 725, 727, 728, 729, - 730, 732, 733, 734, 735, 736, 737 - ] - history_minutes: 180 - forecast_minutes: 0 - time_resolution_minutes: 5 - # A random value from the list below will be chosen as the delay when dropout is used. - # If set to null no dropout is applied. All PV systems are dropped together with this setting. - dropout_timedeltas_minutes: null - dropout_fraction: 0 # Fraction of samples with dropout - # A random value from the list below will be chosen as the delay when system dropout is used. - # If set to null no dropout is applied. All PV systems are indpendently with this setting. - system_dropout_timedeltas_minutes: null - # For ech sample a differnt dropout probability is used which is uniformly sampled from the min - # and max below - system_dropout_fraction_min: 0 - system_dropout_fraction_max: 0 - - nwp: - ukv: - nwp_provider: ukv - nwp_zarr_path: - # Path(s) to UKV NWP data in zarr format - # e.g. gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_intermediate_version_7.zarr - - PLACEHOLDER.zarr - history_minutes: 120 - forecast_minutes: 480 - time_resolution_minutes: 60 - nwp_channels: - # These variables exist in the CEDA training set and in the live MetOffice live service - - t # 2-metre temperature - - dswrf # downwards short-wave radiation flux - - dlwrf # downwards long-wave radiation flux - - hcc # high cloud cover - - mcc # medium cloud cover - - lcc # low cloud cover - - sde # snow depth water equivalent - - r # relative humidty - - vis # visibility - - si10 # 10-metre wind speed - - wdir10 # 10-metre wind direction - - prate # precipitation rate - # These variables exist in CEDA training data but not in the live MetOffice live service - - hcct # height of convective cloud top, meters above surface. 
NaN if no clouds - - cdcb # height of lowest cloud base > 3 oktas - - dpt # dew point temperature - - prmsl # mean sea level pressure - - h # geometrical? (maybe geopotential?) height - nwp_image_size_pixels_height: 24 - nwp_image_size_pixels_width: 24 - # A random value from the list below will be chosen as the delay when dropout is used - # If set to null no dropout is applied. Values must be negative. - dropout_timedeltas_minutes: [-180] - # Dropout applied with this probability - dropout_fraction: 1.0 - # How long after the NWP init-time are we still willing to use this forecast - # If null we use each init-time for all steps it covers - max_staleness_minutes: null - - ecmwf: - nwp_provider: ecmwf - # Path to ECMWF NWP data in zarr format - # n.b. It is not necessary to use multiple or any NWP data. These entries can be removed - nwp_zarr_path: PLACEHOLDER.zarr - history_minutes: 120 - forecast_minutes: 480 - time_resolution_minutes: 60 - nwp_channels: - - t2m # 2-metre temperature - - dswrf # downwards short-wave radiation flux - - dlwrf # downwards long-wave radiation flux - - hcc # high cloud cover - - mcc # medium cloud cover - - lcc # low cloud cover - - tcc # total cloud cover - - sde # snow depth water equivalent - - sr # direct solar radiation - - duvrs # downwards UV radiation at surface - - prate # precipitation rate - - u10 # 10-metre U component of wind speed - - u100 # 100-metre U component of wind speed - - u200 # 200-metre U component of wind speed - - v10 # 10-metre V component of wind speed - - v100 # 100-metre V component of wind speed - - v200 # 200-metre V component of wind speed - nwp_image_size_pixels_height: 12 # roughly equivalent to UKV 24 pixels - nwp_image_size_pixels_width: 12 - dropout_timedeltas_minutes: [-180] - dropout_fraction: 1.0 - max_staleness_minutes: null - - satellite: - satellite_zarr_path: - # Path(s) to non-HRV satellite data in zarr format - # e.g. 
gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/v4/2020_nonhrv.zarr - - PLACEHOLDER.zarr - history_minutes: 90 - forecast_minutes: 0 # Deprecated for most use cases - live_delay_minutes: 60 # Only data up to time t0-60minutes is inluced in slice - time_resolution_minutes: 5 - satellite_channels: - # Uses for each channel taken from https://resources.eumetrain.org/data/3/311/bsc_s4.pdf - - IR_016 # Surface, cloud phase - - IR_039 # Surface, clouds, wind fields - - IR_087 # Surface, clouds, atmospheric instability - - IR_097 # Ozone - - IR_108 # Surface, clouds, wind fields, atmospheric instability - - IR_120 # Surface, clouds, atmospheric instability - - IR_134 # Cirrus cloud height, atmospheric instability - - VIS006 # Surface, clouds, wind fields - - VIS008 # Surface, clouds, wind fields - - WV_062 # Water vapor, high level clouds, upper air analysis - - WV_073 # Water vapor, atmospheric instability, upper-level dynamics - satellite_image_size_pixels_height: 24 - satellite_image_size_pixels_width: 24 - # A random value from the list below will be chosen as the delay when dropout is used - # If set to null no dropout is applied. Values must be negative. 
- dropout_timedeltas_minutes: null - dropout_fraction: 0 # Fraction of samples with dropout diff --git a/configs.example/datamodule/configuration/site_example_configuration.yaml b/configs.example/datamodule/configuration/site_example_configuration.yaml new file mode 100644 index 00000000..157e7541 --- /dev/null +++ b/configs.example/datamodule/configuration/site_example_configuration.yaml @@ -0,0 +1,70 @@ +general: + description: Example config for producing PVNet samples for a renewable generation site + name: site_example_config + +input_data: + + site: + time_resolution_minutes: 15 + interval_start_minutes: -60 + interval_end_minutes: 480 + file_path: PLACEHOLDER.nc + metadata_file_path: PLACEHOLDER.csv + dropout_timedeltas_minutes: null + dropout_fraction: 0 # Fraction of samples with dropout + + nwp: + ecmwf: + provider: ecmwf + # Path to ECMWF NWP data in zarr format + # n.b. It is not necessary to use multiple or any NWP data. These entries can be removed + zarr_path: PLACEHOLDER + interval_start_minutes: -60 + interval_end_minutes: 480 + time_resolution_minutes: 60 + channels: + - t2m # 2-metre temperature + - dswrf # downwards short-wave radiation flux + - dlwrf # downwards long-wave radiation flux + - hcc # high cloud cover + - mcc # medium cloud cover + - lcc # low cloud cover + - tcc # total cloud cover + - sde # snow depth water equivalent + - sr # direct solar radiation + - duvrs # downwards UV radiation at surface + - prate # precipitation rate + - u10 # 10-metre U component of wind speed + - u100 # 100-metre U component of wind speed + - u200 # 200-metre U component of wind speed + - v10 # 10-metre V component of wind speed + - v100 # 100-metre V component of wind speed + - v200 # 200-metre V component of wind speed + image_size_pixels_height: 24 + image_size_pixels_width: 24 + dropout_timedeltas_minutes: [-360] + dropout_fraction: 1.0 + max_staleness_minutes: null + + satellite: + zarr_path: PLACEHOLDER.zarr + interval_start_minutes: -30 + 
interval_end_minutes: 0 + time_resolution_minutes: 5 + channels: + # Uses for each channel taken from https://resources.eumetrain.org/data/3/311/bsc_s4.pdf + - IR_016 # Surface, cloud phase + - IR_039 # Surface, clouds, wind fields + - IR_087 # Surface, clouds, atmospheric instability + - IR_097 # Ozone + - IR_108 # Surface, clouds, wind fields, atmospheric instability + - IR_120 # Surface, clouds, atmospheric instability + - IR_134 # Cirrus cloud height, atmospheric instability + - VIS006 # Surface, clouds, wind fields + - VIS008 # Surface, clouds, wind fields + - WV_062 # Water vapor, high level clouds, upper air analysis + - WV_073 # Water vapor, atmospheric instability, upper-level dynamics + image_size_pixels_height: 24 + image_size_pixels_width: 24 + dropout_timedeltas_minutes: null + dropout_fraction: 0. diff --git a/configs.example/datamodule/premade_batches.yaml b/configs.example/datamodule/premade_batches.yaml index f08f5af2..350e7573 100644 --- a/configs.example/datamodule/premade_batches.yaml +++ b/configs.example/datamodule/premade_batches.yaml @@ -2,7 +2,7 @@ _target_: pvnet.data.datamodule.DataModule configuration: null # The batch_dir is the location batches were saved to using the save_batches.py script # The batch_dir should contain train and val subdirectories with batches -batch_dir: "PLACEHOLDER" +sample_dir: "PLACEHOLDER" num_workers: 10 prefetch_factor: 2 batch_size: 8 diff --git a/configs.example/datamodule/streamed_batches.yaml b/configs.example/datamodule/streamed_batches.yaml index 14f42bc5..8c2ef3cb 100644 --- a/configs.example/datamodule/streamed_batches.yaml +++ b/configs.example/datamodule/streamed_batches.yaml @@ -6,10 +6,9 @@ configuration: "PLACEHOLDER.yaml" num_workers: 20 prefetch_factor: 2 batch_size: 8 -batch_output_dir: "PLACEHOLDER" -num_train_batches: 2 -num_val_batches: 1 - +sample_output_dir: "PLACEHOLDER" +num_train_samples: 2 +num_val_samples: 1 train_period: - null @@ -17,6 +16,3 @@ train_period: val_period: - 
"2022-05-08" - "2023-05-08" -test_period: - - "2022-05-08" - - "2023-05-08" diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index b502113a..4df4dd81 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -3,7 +3,7 @@ import torch from lightning.pytorch import LightningDataModule -from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset +from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset from ocf_datapipes.batch import ( NumpyBatch, TensorBatch, @@ -93,7 +93,12 @@ def __init__( ) def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: - return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time) + if self.configuration.renewable == "pv": + return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time) + elif self.configuration.renewable in ["wind", "pv_india", "pv_site"]: + return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) + else: + raise ValueError(f"Unknown renewable: {self.configuration.renewable}") def _get_premade_samples_dataset(self, subdir) -> Dataset: split_dir = f"{self.sample_dir}/{subdir}" diff --git a/pvnet/data/pv_site_datamodule.py b/pvnet/data/pv_site_datamodule.py deleted file mode 100644 index 4c45eaec..00000000 --- a/pvnet/data/pv_site_datamodule.py +++ /dev/null @@ -1,67 +0,0 @@ -""" Data module for pytorch lightning """ -import glob - -from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.pvnet_site import ( - pvnet_site_datapipe, - pvnet_site_netcdf_datapipe, - split_dataset_dict_dp, - uncombine_from_single_dataset, -) - -from pvnet.data.base import BaseDataModule - - -class PVSiteDataModule(BaseDataModule): - """Datamodule for training pvnet site and using pvnet site pipeline in `ocf_datapipes`.""" - - def _get_datapipe(self, start_time, end_time): - data_pipeline = 
pvnet_site_datapipe( - self.configuration, - start_time=start_time, - end_time=end_time, - ) - data_pipeline = data_pipeline.map(uncombine_from_single_dataset).map(split_dataset_dict_dp) - data_pipeline = data_pipeline.pvnet_site_convert_to_numpy_batch() - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - return data_pipeline - - def _get_premade_batches_datapipe(self, subdir, shuffle=False): - filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")) - data_pipeline = pvnet_site_netcdf_datapipe( - keys=["pv", "nwp"], # add other keys e.g. sat if used as input in site model - filenames=filenames, - ) - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - if shuffle: - data_pipeline = ( - data_pipeline.shuffle(buffer_size=100) - .sharding_filter() - # Split the batches and reshuffle them to be combined into new batches - .split_batches(splitting_key=BatchKey.pv) - .shuffle(buffer_size=self.shuffle_factor * self.batch_size) - ) - else: - data_pipeline = ( - data_pipeline.sharding_filter() - # Split the batches so we can use any batch-size - .split_batches(splitting_key=BatchKey.pv) - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - .set_length(int(len(filenames) / self.batch_size)) - ) - - return data_pipeline diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py deleted file mode 100644 index 0c11d31d..00000000 --- a/pvnet/data/wind_datamodule.py +++ /dev/null @@ -1,62 +0,0 @@ -""" Data module for pytorch lightning """ -import glob - -from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.windnet import windnet_netcdf_datapipe - -from pvnet.data.base import BaseDataModule - - -class WindDataModule(BaseDataModule): - """Datamodule for training windnet and using 
windnet pipeline in `ocf_datapipes`.""" - - def _get_datapipe(self, start_time, end_time): - # TODO is this is not right, need to load full windnet pipeline - data_pipeline = windnet_netcdf_datapipe( - self.configuration, - keys=["wind", "nwp", "sensor"], - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - return data_pipeline - - def _get_premade_batches_datapipe(self, subdir, shuffle=False): - filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")) - data_pipeline = windnet_netcdf_datapipe( - keys=["wind", "nwp", "sensor"], - filenames=filenames, - nwp_channels=self.nwp_channels, - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - if shuffle: - data_pipeline = ( - data_pipeline.shuffle(buffer_size=100) - .sharding_filter() - # Split the batches and reshuffle them to be combined into new batches - .split_batches(splitting_key=BatchKey.wind) - .shuffle(buffer_size=self.shuffle_factor * self.batch_size) - ) - else: - data_pipeline = ( - data_pipeline.sharding_filter() - # Split the batches so we can use any batch-size - .split_batches(splitting_key=BatchKey.wind) - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - .set_length(int(len(filenames) / self.batch_size)) - ) - - return data_pipeline diff --git a/pyproject.toml b/pyproject.toml index cc3ea1cf..b931d605 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dynamic = ["version", "readme"] license={file="LICENCE"} dependencies = [ - "ocf_data_sampler==0.0.26", + "ocf_data_sampler==0.0.32", "ocf_datapipes>=3.3.34", "ocf_ml_metrics>=0.0.11", "numpy", diff --git a/scripts/save_samples.py b/scripts/save_samples.py index d38a45f9..2fdac0ee 100644 --- a/scripts/save_samples.py +++ b/scripts/save_samples.py @@ -44,7 +44,7 @@ import dask import hydra import torch -from 
ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset +from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset from omegaconf import DictConfig, OmegaConf from sqlalchemy import exc as sa_exc from torch.utils.data import DataLoader, Dataset @@ -79,7 +79,7 @@ def __call__(self, sample, sample_num: int): if self.renewable == "pv": torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt") elif self.renewable in ["wind", "pv_india", "pv_site"]: - raise NotImplementedError + sample.to_netcdf(f"{self.save_dir}/{sample_num:08}.nc", mode="w", engine="h5netcdf") else: raise ValueError(f"Unknown renewable: {self.renewable}") @@ -89,7 +89,7 @@ def get_dataset(config_path: str, start_time: str, end_time: str, renewable: str if renewable == "pv": dataset_cls = PVNetUKRegionalDataset elif renewable in ["wind", "pv_india", "pv_site"]: - raise NotImplementedError + dataset_cls = SitesDataset else: raise ValueError(f"Unknown renewable: {renewable}") From 23a749dd349a9bd0081d4935e339c6a9ec6e9c73 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Dec 2024 16:22:53 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../datamodule/configuration/site_example_configuration.yaml | 2 +- pvnet/data/datamodule.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/configs.example/datamodule/configuration/site_example_configuration.yaml b/configs.example/datamodule/configuration/site_example_configuration.yaml index 157e7541..a8cfadeb 100644 --- a/configs.example/datamodule/configuration/site_example_configuration.yaml +++ b/configs.example/datamodule/configuration/site_example_configuration.yaml @@ -40,7 +40,7 @@ input_data: - v10 # 10-metre V component of wind speed - v100 # 100-metre V component of wind speed - v200 # 200-metre V component of wind speed - 
image_size_pixels_height: 24 + image_size_pixels_height: 24 image_size_pixels_width: 24 dropout_timedeltas_minutes: [-360] dropout_fraction: 1.0 diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index 4df4dd81..8cd72411 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -94,7 +94,9 @@ def __init__( def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: if self.configuration.renewable == "pv": - return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time) + return PVNetUKRegionalDataset( + self.configuration, start_time=start_time, end_time=end_time + ) elif self.configuration.renewable in ["wind", "pv_india", "pv_site"]: return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) else: From b05d8fa8f03be4458dc35945f219c863d4be8db1 Mon Sep 17 00:00:00 2001 From: Sukhil Patel Date: Wed, 18 Dec 2024 17:49:00 +0000 Subject: [PATCH 03/13] Up ocf-data-sampler version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b931d605..6c33954f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dynamic = ["version", "readme"] license={file="LICENCE"} dependencies = [ - "ocf_data_sampler==0.0.32", + "ocf_data_sampler==0.0.33", "ocf_datapipes>=3.3.34", "ocf_ml_metrics>=0.0.11", "numpy", From d69e1b4b9823afdbb4b82a84a038d3c37c06c1c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:59:47 +0000 Subject: [PATCH 04/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- configs.example/datamodule/streamed_batches.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs.example/datamodule/streamed_batches.yaml b/configs.example/datamodule/streamed_batches.yaml index e8573ad3..9e76ad8d 100644 --- a/configs.example/datamodule/streamed_batches.yaml +++ 
b/configs.example/datamodule/streamed_batches.yaml @@ -17,4 +17,3 @@ train_period: val_period: - "2022-05-08" - "2023-05-08" - From 0a0adbfbb9488b1c656c39d7a7cb5b1e3b3c28da Mon Sep 17 00:00:00 2001 From: Sukhil Patel Date: Wed, 18 Dec 2024 18:01:15 +0000 Subject: [PATCH 05/13] Undo space deletion --- configs.example/datamodule/streamed_batches.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs.example/datamodule/streamed_batches.yaml b/configs.example/datamodule/streamed_batches.yaml index e8573ad3..077b7876 100644 --- a/configs.example/datamodule/streamed_batches.yaml +++ b/configs.example/datamodule/streamed_batches.yaml @@ -7,6 +7,7 @@ configuration: "PLACEHOLDER.yaml" num_workers: 20 prefetch_factor: 2 batch_size: 8 + sample_output_dir: "PLACEHOLDER" num_train_samples: 2 num_val_samples: 1 From a75ee76902c4a3fe0b6b26cf6c01ee538a063cc8 Mon Sep 17 00:00:00 2001 From: Sukhil Patel Date: Thu, 19 Dec 2024 16:06:17 +0000 Subject: [PATCH 06/13] Remove datapipe tests --- tests/data/test_datamodule.py | 59 ++--------------------------------- 1 file changed, 2 insertions(+), 57 deletions(-) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 00b1705d..96be91e0 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,9 +1,4 @@ -import pytest from pvnet.data.datamodule import DataModule -from pvnet.data.wind_datamodule import WindDataModule -from pvnet.data.pv_site_datamodule import PVSiteDataModule -import os -from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey def test_init(): @@ -17,58 +12,6 @@ def test_init(): val_period=[None, None], ) - -@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet") -def test_wind_init(): - dm = WindDataModule( - configuration=None, - batch_size=2, - num_workers=0, - prefetch_factor=None, - train_period=[None, None], - val_period=[None, None], - test_period=[None, None], - batch_dir="tests/data/sample_batches", - ) - - 
-@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet") -def test_wind_init_with_nwp_filter(): - dm = WindDataModule( - configuration=None, - batch_size=2, - num_workers=0, - prefetch_factor=None, - train_period=[None, None], - val_period=[None, None], - test_period=[None, None], - batch_dir="tests/test_data/sample_wind_batches", - nwp_channels={"ecmwf": ["t2m", "v200"]}, - ) - dataloader = iter(dm.train_dataloader()) - - batch = next(dataloader) - batch_channels = batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp_channel_names] - print(batch_channels) - for v in ["t2m", "v200"]: - assert v in batch_channels - assert batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp].shape[2] == 2 - - -@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet") -def test_pv_site_init(): - dm = PVSiteDataModule( - configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_batches/data_configuration.yaml", - batch_size=2, - num_workers=0, - prefetch_factor=None, - train_period=[None, None], - val_period=[None, None], - test_period=[None, None], - batch_dir=None, - ) - - def test_iter(): dm = DataModule( configuration=None, @@ -104,3 +47,5 @@ def test_iter_multiprocessing(): # Make sure we've served 2 batches assert served_batches == 2 + +# TODO add test cases with some netcdfs premade samples From 8e7e0d0be2b811bf685a4e385154305a86095ce3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:07:12 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/data/test_datamodule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 96be91e0..cf93372c 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -12,6 +12,7 @@ def test_init(): val_period=[None, None], ) + def 
test_iter(): dm = DataModule( configuration=None, @@ -48,4 +49,5 @@ def test_iter_multiprocessing(): # Make sure we've served 2 batches assert served_batches == 2 -# TODO add test cases with some netcdfs premade samples + +# TODO add test cases with some netcdfs premade samples From ff5bbf0bdd34c4e985edf759402b8709c51ca8a0 Mon Sep 17 00:00:00 2001 From: Sukhil Patel Date: Thu, 19 Dec 2024 16:14:45 +0000 Subject: [PATCH 08/13] Remove old datapipe tests --- tests/conftest.py | 31 +++++++++---------- .../multimodal/site_encoders/test_encoders.py | 16 +++++----- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ba657af5..8e589616 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,6 @@ import pvnet from pvnet.data.datamodule import DataModule -from pvnet.data.wind_datamodule import WindDataModule import pvnet.models.multimodal.encoders.encoders3d import pvnet.models.multimodal.linear_networks.networks @@ -158,21 +157,21 @@ def sample_pv_batch(): # old batches. 
For now we use the old batches to test the site encoder models return torch.load("tests/test_data/presaved_batches/train/000000.pt") - -@pytest.fixture() -def sample_wind_batch(): - dm = WindDataModule( - configuration=None, - batch_size=2, - num_workers=0, - prefetch_factor=None, - train_period=[None, None], - val_period=[None, None], - test_period=[None, None], - batch_dir="tests/test_data/sample_wind_batches", - ) - batch = next(iter(dm.train_dataloader())) - return batch +# TODO update this test once we add the loading logic for the Site dataset +# @pytest.fixture() +# def sample_wind_batch(): +# dm = WindDataModule( +# configuration=None, +# batch_size=2, +# num_workers=0, +# prefetch_factor=None, +# train_period=[None, None], +# val_period=[None, None], +# test_period=[None, None], +# batch_dir="tests/test_data/sample_wind_batches", +# ) +# batch = next(iter(dm.train_dataloader())) +# return batch @pytest.fixture() diff --git a/tests/models/multimodal/site_encoders/test_encoders.py b/tests/models/multimodal/site_encoders/test_encoders.py index 41969b22..08d3cac8 100644 --- a/tests/models/multimodal/site_encoders/test_encoders.py +++ b/tests/models/multimodal/site_encoders/test_encoders.py @@ -41,14 +41,14 @@ def test_singleattentionnetwork_forward(sample_pv_batch, site_encoder_model_kwar batch_size=8, ) - -def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs): - _test_model_forward( - sample_wind_batch, - SingleAttentionNetwork, - site_encoder_sensor_model_kwargs, - batch_size=2, - ) +# TODO once we have updated the sample batches for sites include this test +# def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs): +# _test_model_forward( +# sample_wind_batch, +# SingleAttentionNetwork, +# site_encoder_sensor_model_kwargs, +# batch_size=2, +# ) # Test model backward on all models From b7923f1d2327aba0e4a9a2375d3af4201a8df792 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:15:21 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/conftest.py | 1 + tests/models/multimodal/site_encoders/test_encoders.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 8e589616..89dbff75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -157,6 +157,7 @@ def sample_pv_batch(): # old batches. For now we use the old batches to test the site encoder models return torch.load("tests/test_data/presaved_batches/train/000000.pt") + # TODO update this test once we add the loading logic for the Site dataset # @pytest.fixture() # def sample_wind_batch(): diff --git a/tests/models/multimodal/site_encoders/test_encoders.py b/tests/models/multimodal/site_encoders/test_encoders.py index 08d3cac8..48938bc3 100644 --- a/tests/models/multimodal/site_encoders/test_encoders.py +++ b/tests/models/multimodal/site_encoders/test_encoders.py @@ -41,6 +41,7 @@ def test_singleattentionnetwork_forward(sample_pv_batch, site_encoder_model_kwar batch_size=8, ) + # TODO once we have updated the sample batches for sites include this test # def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs): # _test_model_forward( From 8c62a5416b0f3a3a654a6c36993c0d7317fdddf4 Mon Sep 17 00:00:00 2001 From: Sukhil Patel Date: Thu, 19 Dec 2024 16:30:56 +0000 Subject: [PATCH 10/13] Update allowed renewable parameter values --- pvnet/data/datamodule.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index 8cd72411..1572a93f 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -93,14 +93,14 @@ def __init__( ) def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: - if self.configuration.renewable == "pv": + if 
self.configuration.renewable == "uk_pv": return PVNetUKRegionalDataset( self.configuration, start_time=start_time, end_time=end_time ) - elif self.configuration.renewable in ["wind", "pv_india", "pv_site"]: + elif self.configuration.renewable == "site": return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) else: - raise ValueError(f"Unknown renewable: {self.configuration.renewable}") + raise ValueError(f"Unknown renewable: {self.configuration.renewable}, renewable value should either be uk_pv or site") def _get_premade_samples_dataset(self, subdir) -> Dataset: split_dir = f"{self.sample_dir}/{subdir}" From 934e034855879cfca8f5038ceab7ae218ac0de80 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:31:08 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/data/datamodule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index 1572a93f..f2d747e5 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -100,7 +100,9 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: elif self.configuration.renewable == "site": return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) else: - raise ValueError(f"Unknown renewable: {self.configuration.renewable}, renewable value should either be uk_pv or site") + raise ValueError( + f"Unknown renewable: {self.configuration.renewable}, renewable value should either be uk_pv or site" + ) def _get_premade_samples_dataset(self, subdir) -> Dataset: split_dir = f"{self.sample_dir}/{subdir}" From 36181b3b8cafe790a0b2da971f8523ad6b21d31c Mon Sep 17 00:00:00 2001 From: Sukhil Patel Date: Thu, 19 Dec 2024 16:53:04 +0000 Subject: [PATCH 12/13] Linting --- pvnet/data/datamodule.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index f2d747e5..f1128155 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -101,7 +101,8 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) else: raise ValueError( - f"Unknown renewable: {self.configuration.renewable}, renewable value should either be uk_pv or site" + f"Unknown renewable: {self.configuration.renewable}, + renewable value should either be uk_pv or site" ) def _get_premade_samples_dataset(self, subdir) -> Dataset: From 767bedc49f6ce0d085764197f73a2a16a56fab91 Mon Sep 17 00:00:00 2001 From: Sukhil Patel Date: Thu, 19 Dec 2024 16:55:38 +0000 Subject: [PATCH 13/13] More linting --- pvnet/data/datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index f1128155..9708fbdd 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -101,8 +101,8 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) else: raise ValueError( - f"Unknown renewable: {self.configuration.renewable}, - renewable value should either be uk_pv or site" + f"Unknown renewable: {self.configuration.renewable}, " + "renewable value should either be uk_pv or site" ) def _get_premade_samples_dataset(self, subdir) -> Dataset: