diff --git a/README.md b/README.md index 051e9a2e..69f54c64 100644 --- a/README.md +++ b/README.md @@ -145,20 +145,20 @@ This is also where you can update the train, val & test periods to cover the dat ### Running the batch creation script -Run the `save_batches.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example): +Run the `save_samples.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example): ```bash -python scripts/save_batches.py +python scripts/save_samples.py ``` PVNet uses [hydra](https://hydra.cc/) which enables us to pass variables via the command line that will override the configuration defined in the `./configs` directory, like this: ```bash -python scripts/save_batches.py datamodule=streamed_batches datamodule.batch_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5 +python scripts/save_samples.py datamodule=streamed_batches datamodule.sample_output_dir="./output" datamodule.num_train_samples=10 datamodule.num_val_samples=5 ``` -`scripts/save_batches.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder. +`scripts/save_samples.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder. If downloading private data from a GCP bucket make sure to authenticate gcloud (the public satellite data does not need authentication): @@ -197,7 +197,7 @@ Make sure to update the following config files before training your model: 2. In `configs/model/local_multimodal.yaml`: - update the list of encoders to reflect the data sources you are using. 
If you are using different NWP sources, the encoders for these should follow the same structure with two important updates: - `in_channels`: number of variables your NWP source supplies - - `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `nwp_image_size_pixels_height` and/or `nwp_image_size_pixels_width` in `datamodule/example_configs.yaml`, unless transformations such as coarsening was applied (e. g. as for ECMWF data) + - `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `image_size_pixels_height` and/or `image_size_pixels_width` in `datamodule/configuration/site_example_configuration.yaml` for the NWP, unless transformations such as coarsening was applied (e. g. as for ECMWF data) 3. In `configs/local_trainer.yaml`: - set `accelerator: 0` if running on a system without a supported GPU @@ -216,7 +216,7 @@ defaults: - hydra: default.yaml ``` -Assuming you ran the `save_batches.py` script to generate some premade train and +Assuming you ran the `save_samples.py` script to generate some premade train and val data batches, you can now train PVNet by running: ``` diff --git a/configs.example/datamodule/configuration/example_configuration.yaml b/configs.example/datamodule/configuration/example_configuration.yaml deleted file mode 100644 index 827b0e9d..00000000 --- a/configs.example/datamodule/configuration/example_configuration.yaml +++ /dev/null @@ -1,172 +0,0 @@ -general: - description: Example data config for creating PVNet batches - name: example_pvnet - -input_data: - default_history_minutes: 120 - default_forecast_minutes: 480 - - gsp: - # Path to the GSP data. This should be a zarr file - # e.g. 
gs://solar-pv-nowcasting-data/PV/GSP/v7/pv_gsp.zarr - gsp_zarr_path: PLACEHOLDER.zarr - history_minutes: 120 - forecast_minutes: 480 - time_resolution_minutes: 30 - # A random value from the list below will be chosen as the delay when dropout is used - # If set to null no dropout is applied. Only values before t0 are dropped out for GSP. - # Values after t0 are assumed as targets and cannot be dropped. - dropout_timedeltas_minutes: null - dropout_fraction: 0 # Fraction of samples with dropout - - pv: - pv_files_groups: - - label: solar_sheffield_passiv - # Path to the site-level PV data. This should be a netcdf - # e.g gs://solar-pv-nowcasting-data/PV/Passive/ocf_formatted/v0/passiv.netcdf - pv_filename: PLACEHOLDER.netcdf - # Path to the site-level PV metadata. This choudl be a csv - # e.g gs://solar-pv-nowcasting-data/PV/Passive/ocf_formatted/v0/system_metadata.csv - pv_metadata_filename: PLACEHOLDER.csv - # This is the list of pv_ml_ids to be sliced from the PV site level data - # The IDs below are 349 of the PV systems which have very little NaN data in the historic data - # and which are still reporting live (as of Oct 2023) - pv_ml_ids: - [ - 154, 155, 156, 158, 159, 160, 162, 164, 165, 166, 167, 168, 169, 171, 173, 177, 178, 179, - 181, 182, 185, 186, 187, 188, 189, 190, 191, 192, 193, 197, 198, 199, 200, 202, 204, 205, - 206, 208, 209, 211, 214, 215, 216, 217, 218, 219, 220, 221, 225, 229, 230, 232, 233, 234, - 236, 242, 243, 245, 252, 254, 255, 256, 257, 258, 260, 261, 262, 265, 267, 268, 272, 273, - 275, 276, 277, 280, 281, 282, 283, 287, 289, 291, 292, 293, 294, 295, 296, 297, 298, 301, - 302, 303, 304, 306, 307, 309, 310, 311, 317, 318, 319, 320, 321, 322, 323, 325, 326, 329, - 332, 333, 335, 336, 338, 340, 342, 344, 345, 346, 348, 349, 352, 354, 355, 356, 357, 360, - 362, 363, 368, 369, 370, 371, 372, 374, 375, 376, 378, 380, 382, 384, 385, 388, 390, 391, - 393, 396, 397, 398, 399, 400, 401, 403, 404, 405, 406, 407, 409, 411, 412, 413, 414, 415, - 416, 
417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 429, 431, 435, 437, 438, 440, - 441, 444, 447, 450, 451, 453, 456, 457, 458, 459, 464, 465, 466, 467, 468, 470, 471, 473, - 474, 476, 477, 479, 480, 481, 482, 485, 486, 488, 490, 491, 492, 493, 496, 498, 501, 503, - 506, 507, 508, 509, 510, 511, 512, 513, 515, 516, 517, 519, 520, 521, 522, 524, 526, 527, - 528, 531, 532, 536, 537, 538, 540, 541, 542, 543, 544, 545, 549, 550, 551, 552, 553, 554, - 556, 557, 560, 561, 563, 566, 568, 571, 572, 575, 576, 577, 579, 580, 581, 582, 584, 585, - 588, 590, 594, 595, 597, 600, 602, 603, 604, 606, 611, 613, 614, 616, 618, 620, 622, 623, - 624, 625, 626, 628, 629, 630, 631, 636, 637, 638, 640, 641, 642, 644, 645, 646, 650, 651, - 652, 653, 654, 655, 657, 660, 661, 662, 663, 666, 667, 668, 670, 675, 676, 679, 681, 683, - 684, 685, 687, 696, 698, 701, 702, 703, 704, 706, 710, 722, 723, 724, 725, 727, 728, 729, - 730, 732, 733, 734, 735, 736, 737 - ] - history_minutes: 180 - forecast_minutes: 0 - time_resolution_minutes: 5 - # A random value from the list below will be chosen as the delay when dropout is used. - # If set to null no dropout is applied. All PV systems are dropped together with this setting. - dropout_timedeltas_minutes: null - dropout_fraction: 0 # Fraction of samples with dropout - # A random value from the list below will be chosen as the delay when system dropout is used. - # If set to null no dropout is applied. All PV systems are indpendently with this setting. - system_dropout_timedeltas_minutes: null - # For ech sample a differnt dropout probability is used which is uniformly sampled from the min - # and max below - system_dropout_fraction_min: 0 - system_dropout_fraction_max: 0 - - nwp: - ukv: - nwp_provider: ukv - nwp_zarr_path: - # Path(s) to UKV NWP data in zarr format - # e.g. 
gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_intermediate_version_7.zarr - - PLACEHOLDER.zarr - history_minutes: 120 - forecast_minutes: 480 - time_resolution_minutes: 60 - nwp_channels: - # These variables exist in the CEDA training set and in the live MetOffice live service - - t # 2-metre temperature - - dswrf # downwards short-wave radiation flux - - dlwrf # downwards long-wave radiation flux - - hcc # high cloud cover - - mcc # medium cloud cover - - lcc # low cloud cover - - sde # snow depth water equivalent - - r # relative humidty - - vis # visibility - - si10 # 10-metre wind speed - - wdir10 # 10-metre wind direction - - prate # precipitation rate - # These variables exist in CEDA training data but not in the live MetOffice live service - - hcct # height of convective cloud top, meters above surface. NaN if no clouds - - cdcb # height of lowest cloud base > 3 oktas - - dpt # dew point temperature - - prmsl # mean sea level pressure - - h # geometrical? (maybe geopotential?) height - nwp_image_size_pixels_height: 24 - nwp_image_size_pixels_width: 24 - # A random value from the list below will be chosen as the delay when dropout is used - # If set to null no dropout is applied. Values must be negative. - dropout_timedeltas_minutes: [-180] - # Dropout applied with this probability - dropout_fraction: 1.0 - # How long after the NWP init-time are we still willing to use this forecast - # If null we use each init-time for all steps it covers - max_staleness_minutes: null - - ecmwf: - nwp_provider: ecmwf - # Path to ECMWF NWP data in zarr format - # n.b. It is not necessary to use multiple or any NWP data. 
These entries can be removed - nwp_zarr_path: PLACEHOLDER.zarr - history_minutes: 120 - forecast_minutes: 480 - time_resolution_minutes: 60 - nwp_channels: - - t2m # 2-metre temperature - - dswrf # downwards short-wave radiation flux - - dlwrf # downwards long-wave radiation flux - - hcc # high cloud cover - - mcc # medium cloud cover - - lcc # low cloud cover - - tcc # total cloud cover - - sde # snow depth water equivalent - - sr # direct solar radiation - - duvrs # downwards UV radiation at surface - - prate # precipitation rate - - u10 # 10-metre U component of wind speed - - u100 # 100-metre U component of wind speed - - u200 # 200-metre U component of wind speed - - v10 # 10-metre V component of wind speed - - v100 # 100-metre V component of wind speed - - v200 # 200-metre V component of wind speed - nwp_image_size_pixels_height: 12 # roughly equivalent to UKV 24 pixels - nwp_image_size_pixels_width: 12 - dropout_timedeltas_minutes: [-180] - dropout_fraction: 1.0 - max_staleness_minutes: null - - satellite: - satellite_zarr_path: - # Path(s) to non-HRV satellite data in zarr format - # e.g. 
gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/v4/2020_nonhrv.zarr - - PLACEHOLDER.zarr - history_minutes: 90 - forecast_minutes: 0 # Deprecated for most use cases - live_delay_minutes: 60 # Only data up to time t0-60minutes is inluced in slice - time_resolution_minutes: 5 - satellite_channels: - # Uses for each channel taken from https://resources.eumetrain.org/data/3/311/bsc_s4.pdf - - IR_016 # Surface, cloud phase - - IR_039 # Surface, clouds, wind fields - - IR_087 # Surface, clouds, atmospheric instability - - IR_097 # Ozone - - IR_108 # Surface, clouds, wind fields, atmospheric instability - - IR_120 # Surface, clouds, atmospheric instability - - IR_134 # Cirrus cloud height, atmospheric instability - - VIS006 # Surface, clouds, wind fields - - VIS008 # Surface, clouds, wind fields - - WV_062 # Water vapor, high level clouds, upper air analysis - - WV_073 # Water vapor, atmospheric instability, upper-level dynamics - satellite_image_size_pixels_height: 24 - satellite_image_size_pixels_width: 24 - # A random value from the list below will be chosen as the delay when dropout is used - # If set to null no dropout is applied. Values must be negative. 
- dropout_timedeltas_minutes: null - dropout_fraction: 0 # Fraction of samples with dropout diff --git a/configs.example/datamodule/configuration/site_example_configuration.yaml b/configs.example/datamodule/configuration/site_example_configuration.yaml new file mode 100644 index 00000000..157e7541 --- /dev/null +++ b/configs.example/datamodule/configuration/site_example_configuration.yaml @@ -0,0 +1,70 @@ +general: + description: Example config for producing PVNet samples for a renewable generation site + name: site_example_config + +input_data: + + site: + time_resolution_minutes: 15 + interval_start_minutes: -60 + interval_end_minutes: 480 + file_path: PLACEHOLDER.nc + metadata_file_path: PLACEHOLDER.csv + dropout_timedeltas_minutes: null + dropout_fraction: 0 # Fraction of samples with dropout + + nwp: + ecmwf: + provider: ecmwf + # Path to ECMWF NWP data in zarr format + # n.b. It is not necessary to use multiple or any NWP data. These entries can be removed + zarr_path: PLACEHOLDER + interval_start_minutes: -60 + interval_end_minutes: 480 + time_resolution_minutes: 60 + channels: + - t2m # 2-metre temperature + - dswrf # downwards short-wave radiation flux + - dlwrf # downwards long-wave radiation flux + - hcc # high cloud cover + - mcc # medium cloud cover + - lcc # low cloud cover + - tcc # total cloud cover + - sde # snow depth water equivalent + - sr # direct solar radiation + - duvrs # downwards UV radiation at surface + - prate # precipitation rate + - u10 # 10-metre U component of wind speed + - u100 # 100-metre U component of wind speed + - u200 # 200-metre U component of wind speed + - v10 # 10-metre V component of wind speed + - v100 # 100-metre V component of wind speed + - v200 # 200-metre V component of wind speed + image_size_pixels_height: 24 + image_size_pixels_width: 24 + dropout_timedeltas_minutes: [-360] + dropout_fraction: 1.0 + max_staleness_minutes: null + + satellite: + zarr_path: PLACEHOLDER.zarr + interval_start_minutes: -30 + 
interval_end_minutes: 0 + time_resolution_minutes: 5 + channels: + # Uses for each channel taken from https://resources.eumetrain.org/data/3/311/bsc_s4.pdf + - IR_016 # Surface, cloud phase + - IR_039 # Surface, clouds, wind fields + - IR_087 # Surface, clouds, atmospheric instability + - IR_097 # Ozone + - IR_108 # Surface, clouds, wind fields, atmospheric instability + - IR_120 # Surface, clouds, atmospheric instability + - IR_134 # Cirrus cloud height, atmospheric instability + - VIS006 # Surface, clouds, wind fields + - VIS008 # Surface, clouds, wind fields + - WV_062 # Water vapor, high level clouds, upper air analysis + - WV_073 # Water vapor, atmospheric instability, upper-level dynamics + image_size_pixels_height: 24 + image_size_pixels_width: 24 + dropout_timedeltas_minutes: null + dropout_fraction: 0. diff --git a/configs.example/datamodule/premade_batches.yaml b/configs.example/datamodule/premade_batches.yaml index f08f5af2..350e7573 100644 --- a/configs.example/datamodule/premade_batches.yaml +++ b/configs.example/datamodule/premade_batches.yaml @@ -2,7 +2,7 @@ _target_: pvnet.data.datamodule.DataModule configuration: null # The batch_dir is the location batches were saved to using the save_batches.py script # The batch_dir should contain train and val subdirectories with batches -batch_dir: "PLACEHOLDER" +sample_dir: "PLACEHOLDER" num_workers: 10 prefetch_factor: 2 batch_size: 8 diff --git a/configs.example/datamodule/streamed_batches.yaml b/configs.example/datamodule/streamed_batches.yaml index 14f42bc5..8c2ef3cb 100644 --- a/configs.example/datamodule/streamed_batches.yaml +++ b/configs.example/datamodule/streamed_batches.yaml @@ -6,10 +6,9 @@ configuration: "PLACEHOLDER.yaml" num_workers: 20 prefetch_factor: 2 batch_size: 8 -batch_output_dir: "PLACEHOLDER" -num_train_batches: 2 -num_val_batches: 1 - +sample_output_dir: "PLACEHOLDER" +num_train_samples: 2 +num_val_samples: 1 train_period: - null @@ -17,6 +16,3 @@ train_period: val_period: - 
"2022-05-08" - "2023-05-08" -test_period: - - "2022-05-08" - - "2023-05-08" diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index b502113a..4df4dd81 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -3,7 +3,7 @@ import torch from lightning.pytorch import LightningDataModule -from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset +from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset from ocf_datapipes.batch import ( NumpyBatch, TensorBatch, @@ -93,7 +93,12 @@ def __init__( ) def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: - return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time) + if self.configuration.renewable == "pv": + return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time) + elif self.configuration.renewable in ["wind", "pv_india", "pv_site"]: + return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) + else: + raise ValueError(f"Unknown renewable: {self.configuration.renewable}") def _get_premade_samples_dataset(self, subdir) -> Dataset: split_dir = f"{self.sample_dir}/{subdir}" diff --git a/pvnet/data/pv_site_datamodule.py b/pvnet/data/pv_site_datamodule.py deleted file mode 100644 index 4c45eaec..00000000 --- a/pvnet/data/pv_site_datamodule.py +++ /dev/null @@ -1,67 +0,0 @@ -""" Data module for pytorch lightning """ -import glob - -from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.pvnet_site import ( - pvnet_site_datapipe, - pvnet_site_netcdf_datapipe, - split_dataset_dict_dp, - uncombine_from_single_dataset, -) - -from pvnet.data.base import BaseDataModule - - -class PVSiteDataModule(BaseDataModule): - """Datamodule for training pvnet site and using pvnet site pipeline in `ocf_datapipes`.""" - - def _get_datapipe(self, start_time, end_time): - data_pipeline = 
pvnet_site_datapipe( - self.configuration, - start_time=start_time, - end_time=end_time, - ) - data_pipeline = data_pipeline.map(uncombine_from_single_dataset).map(split_dataset_dict_dp) - data_pipeline = data_pipeline.pvnet_site_convert_to_numpy_batch() - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - return data_pipeline - - def _get_premade_batches_datapipe(self, subdir, shuffle=False): - filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")) - data_pipeline = pvnet_site_netcdf_datapipe( - keys=["pv", "nwp"], # add other keys e.g. sat if used as input in site model - filenames=filenames, - ) - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - if shuffle: - data_pipeline = ( - data_pipeline.shuffle(buffer_size=100) - .sharding_filter() - # Split the batches and reshuffle them to be combined into new batches - .split_batches(splitting_key=BatchKey.pv) - .shuffle(buffer_size=self.shuffle_factor * self.batch_size) - ) - else: - data_pipeline = ( - data_pipeline.sharding_filter() - # Split the batches so we can use any batch-size - .split_batches(splitting_key=BatchKey.pv) - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - .set_length(int(len(filenames) / self.batch_size)) - ) - - return data_pipeline diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py deleted file mode 100644 index 0c11d31d..00000000 --- a/pvnet/data/wind_datamodule.py +++ /dev/null @@ -1,62 +0,0 @@ -""" Data module for pytorch lightning """ -import glob - -from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.windnet import windnet_netcdf_datapipe - -from pvnet.data.base import BaseDataModule - - -class WindDataModule(BaseDataModule): - """Datamodule for training windnet and using 
windnet pipeline in `ocf_datapipes`.""" - - def _get_datapipe(self, start_time, end_time): - # TODO is this is not right, need to load full windnet pipeline - data_pipeline = windnet_netcdf_datapipe( - self.configuration, - keys=["wind", "nwp", "sensor"], - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - return data_pipeline - - def _get_premade_batches_datapipe(self, subdir, shuffle=False): - filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")) - data_pipeline = windnet_netcdf_datapipe( - keys=["wind", "nwp", "sensor"], - filenames=filenames, - nwp_channels=self.nwp_channels, - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - if shuffle: - data_pipeline = ( - data_pipeline.shuffle(buffer_size=100) - .sharding_filter() - # Split the batches and reshuffle them to be combined into new batches - .split_batches(splitting_key=BatchKey.wind) - .shuffle(buffer_size=self.shuffle_factor * self.batch_size) - ) - else: - data_pipeline = ( - data_pipeline.sharding_filter() - # Split the batches so we can use any batch-size - .split_batches(splitting_key=BatchKey.wind) - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - .set_length(int(len(filenames) / self.batch_size)) - ) - - return data_pipeline diff --git a/pyproject.toml b/pyproject.toml index cc3ea1cf..b931d605 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dynamic = ["version", "readme"] license={file="LICENCE"} dependencies = [ - "ocf_data_sampler==0.0.26", + "ocf_data_sampler==0.0.32", "ocf_datapipes>=3.3.34", "ocf_ml_metrics>=0.0.11", "numpy", diff --git a/scripts/save_samples.py b/scripts/save_samples.py index d38a45f9..2fdac0ee 100644 --- a/scripts/save_samples.py +++ b/scripts/save_samples.py @@ -44,7 +44,7 @@ import dask import hydra import torch -from 
ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset +from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset from omegaconf import DictConfig, OmegaConf from sqlalchemy import exc as sa_exc from torch.utils.data import DataLoader, Dataset @@ -79,7 +79,7 @@ def __call__(self, sample, sample_num: int): if self.renewable == "pv": torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt") elif self.renewable in ["wind", "pv_india", "pv_site"]: - raise NotImplementedError + sample.to_netcdf(f"{self.save_dir}/{sample_num:08}.nc", mode="w", engine="h5netcdf") else: raise ValueError(f"Unknown renewable: {self.renewable}") @@ -89,7 +89,7 @@ def get_dataset(config_path: str, start_time: str, end_time: str, renewable: str if renewable == "pv": dataset_cls = PVNetUKRegionalDataset elif renewable in ["wind", "pv_india", "pv_site"]: - raise NotImplementedError + dataset_cls = SitesDataset else: raise ValueError(f"Unknown renewable: {renewable}")