diff --git a/README.md b/README.md
index 051e9a2e..69f54c64 100644
--- a/README.md
+++ b/README.md
@@ -145,20 +145,20 @@ This is also where you can update the train, val & test periods to cover the dat
 
 ### Running the batch creation script
 
-Run the `save_batches.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example):
+Run the `save_samples.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example):
 
 ```bash
-python scripts/save_batches.py
+python scripts/save_samples.py
 ```
 
 PVNet uses [hydra](https://hydra.cc/) which enables us to pass variables via the command line that will override the configuration defined in the `./configs` directory, like this:
 
 ```bash
-python scripts/save_batches.py datamodule=streamed_batches datamodule.batch_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5
+python scripts/save_samples.py datamodule=streamed_batches datamodule.sample_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5
 ```
 
-`scripts/save_batches.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder.
+`scripts/save_samples.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder.
 
 If downloading private data from a GCP bucket make sure to authenticate gcloud (the public satellite data does not need authentication):
 
@@ -197,7 +197,7 @@ Make sure to update the following config files before training your model:
 2. In `configs/model/local_multimodal.yaml`:
    - update the list of encoders to reflect the data sources you are using. If you are using different NWP sources, the encoders for these should follow the same structure with two important updates:
      - `in_channels`: number of variables your NWP source supplies
-     - `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `nwp_image_size_pixels_height` and/or `nwp_image_size_pixels_width` in `datamodule/example_configs.yaml`, unless transformations such as coarsening was applied (e. g. as for ECMWF data)
+     - `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `image_size_pixels_height` and/or `image_size_pixels_width` in `datamodule/configuration/site_example_configuration.yaml` for the NWP, unless transformations such as coarsening were applied (e.g. as for ECMWF data)
 3. In `configs/local_trainer.yaml`:
    - set `accelerator: 0` if running on a system without a supported GPU
@@ -216,7 +216,7 @@ defaults:
   - hydra: default.yaml
 ```
 
-Assuming you ran the `save_batches.py` script to generate some premade train and
+Assuming you ran the `save_samples.py` script to generate some premade train and
 val data batches, you can now train PVNet by running:
 
 ```
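The hydra override pattern in the README can also be exercised programmatically, which is handy for checking that a datamodule config resolves as expected. A minimal sketch, assuming hydra >= 1.2 and a top-level config named `config.yaml` (both assumptions, not shown in this diff):

```python
# Sketch: resolve the datamodule config with the same overrides as the CLI
# example above. config_name="config" is an assumed top-level config name.
from hydra import compose, initialize

with initialize(config_path="configs", version_base=None):
    cfg = compose(
        config_name="config",
        overrides=[
            "datamodule=streamed_batches",
            "datamodule.sample_output_dir=./output",
            "datamodule.num_train_batches=10",
            "datamodule.num_val_batches=5",
        ],
    )

print(cfg.datamodule.sample_output_dir)  # -> ./output
```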
diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py
index b502113a..9708fbdd 100644
--- a/pvnet/data/datamodule.py
+++ b/pvnet/data/datamodule.py
@@ -3,7 +3,7 @@
 import torch
 from lightning.pytorch import LightningDataModule
-from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset
 from ocf_datapipes.batch import (
     NumpyBatch,
     TensorBatch,
@@ -93,7 +93,17 @@ def __init__(
         )
 
     def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
-        return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
+        if self.configuration.renewable == "uk_pv":
+            return PVNetUKRegionalDataset(
+                self.configuration, start_time=start_time, end_time=end_time
+            )
+        elif self.configuration.renewable == "site":
+            return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
+        else:
+            raise ValueError(
+                f"Unknown renewable: {self.configuration.renewable}, "
+                "renewable value should either be uk_pv or site"
+            )
 
     def _get_premade_samples_dataset(self, subdir) -> Dataset:
         split_dir = f"{self.sample_dir}/{subdir}"
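The streamed-samples path now dispatches on the `renewable` field of the loaded configuration. A standalone sketch of the same branching, outside Lightning (it assumes, as the new import does, that ocf_data_sampler 0.0.33 exposes both dataset classes from `ocf_data_sampler.torch_datasets`):

```python
# Sketch mirroring DataModule._get_streamed_samples_dataset above.
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset


def get_streamed_dataset(configuration, start_time=None, end_time=None):
    """Pick the torch dataset class based on the configuration's renewable field."""
    if configuration.renewable == "uk_pv":
        return PVNetUKRegionalDataset(configuration, start_time=start_time, end_time=end_time)
    if configuration.renewable == "site":
        return SitesDataset(configuration, start_time=start_time, end_time=end_time)
    raise ValueError(f"Unknown renewable: {configuration.renewable}")
```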
diff --git a/pvnet/data/pv_site_datamodule.py b/pvnet/data/pv_site_datamodule.py
deleted file mode 100644
index 4c45eaec..00000000
--- a/pvnet/data/pv_site_datamodule.py
+++ /dev/null
@@ -1,67 +0,0 @@
-""" Data module for pytorch lightning """
-import glob
-
-from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
-from ocf_datapipes.training.pvnet_site import (
-    pvnet_site_datapipe,
-    pvnet_site_netcdf_datapipe,
-    split_dataset_dict_dp,
-    uncombine_from_single_dataset,
-)
-
-from pvnet.data.base import BaseDataModule
-
-
-class PVSiteDataModule(BaseDataModule):
-    """Datamodule for training pvnet site and using pvnet site pipeline in `ocf_datapipes`."""
-
-    def _get_datapipe(self, start_time, end_time):
-        data_pipeline = pvnet_site_datapipe(
-            self.configuration,
-            start_time=start_time,
-            end_time=end_time,
-        )
-        data_pipeline = data_pipeline.map(uncombine_from_single_dataset).map(split_dataset_dict_dp)
-        data_pipeline = data_pipeline.pvnet_site_convert_to_numpy_batch()
-
-        data_pipeline = (
-            data_pipeline.batch(self.batch_size)
-            .map(stack_np_examples_into_batch)
-            .map(batch_to_tensor)
-        )
-        return data_pipeline
-
-    def _get_premade_batches_datapipe(self, subdir, shuffle=False):
-        filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))
-        data_pipeline = pvnet_site_netcdf_datapipe(
-            keys=["pv", "nwp"],  # add other keys e.g. sat if used as input in site model
-            filenames=filenames,
-        )
-        data_pipeline = (
-            data_pipeline.batch(self.batch_size)
-            .map(stack_np_examples_into_batch)
-            .map(batch_to_tensor)
-        )
-        if shuffle:
-            data_pipeline = (
-                data_pipeline.shuffle(buffer_size=100)
-                .sharding_filter()
-                # Split the batches and reshuffle them to be combined into new batches
-                .split_batches(splitting_key=BatchKey.pv)
-                .shuffle(buffer_size=self.shuffle_factor * self.batch_size)
-            )
-        else:
-            data_pipeline = (
-                data_pipeline.sharding_filter()
-                # Split the batches so we can use any batch-size
-                .split_batches(splitting_key=BatchKey.pv)
-            )
-
-        data_pipeline = (
-            data_pipeline.batch(self.batch_size)
-            .map(stack_np_examples_into_batch)
-            .map(batch_to_tensor)
-            .set_length(int(len(filenames) / self.batch_size))
-        )
-
-        return data_pipeline
diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py
deleted file mode 100644
index 0c11d31d..00000000
--- a/pvnet/data/wind_datamodule.py
+++ /dev/null
@@ -1,62 +0,0 @@
-""" Data module for pytorch lightning """
-import glob
-
-from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
-from ocf_datapipes.training.windnet import windnet_netcdf_datapipe
-
-from pvnet.data.base import BaseDataModule
-
-
-class WindDataModule(BaseDataModule):
-    """Datamodule for training windnet and using windnet pipeline in `ocf_datapipes`."""
-
-    def _get_datapipe(self, start_time, end_time):
-        # TODO is this is not right, need to load full windnet pipeline
-        data_pipeline = windnet_netcdf_datapipe(
-            self.configuration,
-            keys=["wind", "nwp", "sensor"],
-        )
-
-        data_pipeline = (
-            data_pipeline.batch(self.batch_size)
-            .map(stack_np_examples_into_batch)
-            .map(batch_to_tensor)
-        )
-        return data_pipeline
-
-    def _get_premade_batches_datapipe(self, subdir, shuffle=False):
-        filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))
-        data_pipeline = windnet_netcdf_datapipe(
-            keys=["wind", "nwp", "sensor"],
-            filenames=filenames,
-            nwp_channels=self.nwp_channels,
-        )
-
-        data_pipeline = (
-            data_pipeline.batch(self.batch_size)
-            .map(stack_np_examples_into_batch)
-            .map(batch_to_tensor)
-        )
-        if shuffle:
-            data_pipeline = (
-                data_pipeline.shuffle(buffer_size=100)
-                .sharding_filter()
-                # Split the batches and reshuffle them to be combined into new batches
-                .split_batches(splitting_key=BatchKey.wind)
-                .shuffle(buffer_size=self.shuffle_factor * self.batch_size)
-            )
-        else:
-            data_pipeline = (
-                data_pipeline.sharding_filter()
-                # Split the batches so we can use any batch-size
-                .split_batches(splitting_key=BatchKey.wind)
-            )
-
-        data_pipeline = (
-            data_pipeline.batch(self.batch_size)
-            .map(stack_np_examples_into_batch)
-            .map(batch_to_tensor)
-            .set_length(int(len(filenames) / self.batch_size))
-        )
-
-        return data_pipeline
diff --git a/pyproject.toml b/pyproject.toml
index b931d605..6c33954f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ dynamic = ["version", "readme"]
 license={file="LICENCE"}
 
 dependencies = [
-    "ocf_data_sampler==0.0.32",
+    "ocf_data_sampler==0.0.33",
     "ocf_datapipes>=3.3.34",
     "ocf_ml_metrics>=0.0.11",
     "numpy",
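With both datapipes-based modules deleted, their use cases are intended to flow through the unified `DataModule` plus `SitesDataset`. A hypothetical replacement for the old `PVSiteDataModule` construction; the keyword arguments mirror `test_init` in `tests/data/test_datamodule.py`, and the configuration path (the site example config referenced in the README) is an assumption here:

```python
# Hypothetical: site-style data through the unified DataModule.
from pvnet.data.datamodule import DataModule

dm = DataModule(
    configuration="configs/datamodule/configuration/site_example_configuration.yaml",
    batch_size=2,
    num_workers=0,
    prefetch_factor=None,
    train_period=[None, None],
    val_period=[None, None],
    test_period=[None, None],
)
```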
diff --git a/scripts/save_samples.py b/scripts/save_samples.py
index d38a45f9..2fdac0ee 100644
--- a/scripts/save_samples.py
+++ b/scripts/save_samples.py
@@ -44,7 +44,7 @@
 import dask
 import hydra
 import torch
-from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset
 from omegaconf import DictConfig, OmegaConf
 from sqlalchemy import exc as sa_exc
 from torch.utils.data import DataLoader, Dataset
@@ -79,7 +79,7 @@ def __call__(self, sample, sample_num: int):
         if self.renewable == "pv":
             torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt")
         elif self.renewable in ["wind", "pv_india", "pv_site"]:
-            raise NotImplementedError
+            sample.to_netcdf(f"{self.save_dir}/{sample_num:08}.nc", mode="w", engine="h5netcdf")
         else:
             raise ValueError(f"Unknown renewable: {self.renewable}")
 
@@ -89,7 +89,7 @@ def get_dataset(config_path: str, start_time: str, end_time: str, renewable: str
     if renewable == "pv":
         dataset_cls = PVNetUKRegionalDataset
     elif renewable in ["wind", "pv_india", "pv_site"]:
-        raise NotImplementedError
+        dataset_cls = SitesDataset
     else:
         raise ValueError(f"Unknown renewable: {renewable}")
 
diff --git a/tests/conftest.py b/tests/conftest.py
index ba657af5..89dbff75 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,7 +14,6 @@
 import pvnet
 from pvnet.data.datamodule import DataModule
-from pvnet.data.wind_datamodule import WindDataModule
 
 import pvnet.models.multimodal.encoders.encoders3d
 import pvnet.models.multimodal.linear_networks.networks
 
@@ -159,20 +158,21 @@ def sample_pv_batch():
     return torch.load("tests/test_data/presaved_batches/train/000000.pt")
 
 
-@pytest.fixture()
-def sample_wind_batch():
-    dm = WindDataModule(
-        configuration=None,
-        batch_size=2,
-        num_workers=0,
-        prefetch_factor=None,
-        train_period=[None, None],
-        val_period=[None, None],
-        test_period=[None, None],
-        batch_dir="tests/test_data/sample_wind_batches",
-    )
-    batch = next(iter(dm.train_dataloader()))
-    return batch
+# TODO: update this fixture once loading logic for the Sites dataset is added
+# @pytest.fixture()
+# def sample_wind_batch():
+#     dm = WindDataModule(
+#         configuration=None,
+#         batch_size=2,
+#         num_workers=0,
+#         prefetch_factor=None,
+#         train_period=[None, None],
+#         val_period=[None, None],
+#         test_period=[None, None],
+#         batch_dir="tests/test_data/sample_wind_batches",
+#     )
+#     batch = next(iter(dm.train_dataloader()))
+#     return batch
 
 
 @pytest.fixture()
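Since `SaveFuncFactory` now writes two on-disk formats (`.pt` via `torch.save` for `"pv"`, `.nc` via `to_netcdf` with the `h5netcdf` engine for the site-style renewables), reading premade samples back has to branch on the file extension. A small sketch, assuming `h5netcdf` is installed and the directory is whatever `save_samples.py` was pointed at:

```python
# Sketch: read back samples written by scripts/save_samples.py.
from glob import glob

import torch
import xarray as xr


def load_premade_sample(path: str):
    """Load one saved sample, branching on the format save_samples.py used."""
    if path.endswith(".pt"):
        return torch.load(path)
    if path.endswith(".nc"):
        return xr.open_dataset(path, engine="h5netcdf")
    raise ValueError(f"Unrecognised sample file: {path}")


samples = [load_premade_sample(p) for p in sorted(glob("./output/train/*"))]
```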
diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py
index 00b1705d..cf93372c 100644
--- a/tests/data/test_datamodule.py
+++ b/tests/data/test_datamodule.py
@@ -1,9 +1,4 @@
-import pytest
 from pvnet.data.datamodule import DataModule
-from pvnet.data.wind_datamodule import WindDataModule
-from pvnet.data.pv_site_datamodule import PVSiteDataModule
-import os
-from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey
 
 
 def test_init():
@@ -18,57 +13,6 @@ def test_init():
     )
 
 
-@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
-def test_wind_init():
-    dm = WindDataModule(
-        configuration=None,
-        batch_size=2,
-        num_workers=0,
-        prefetch_factor=None,
-        train_period=[None, None],
-        val_period=[None, None],
-        test_period=[None, None],
-        batch_dir="tests/data/sample_batches",
-    )
-
-
-@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
-def test_wind_init_with_nwp_filter():
-    dm = WindDataModule(
-        configuration=None,
-        batch_size=2,
-        num_workers=0,
-        prefetch_factor=None,
-        train_period=[None, None],
-        val_period=[None, None],
-        test_period=[None, None],
-        batch_dir="tests/test_data/sample_wind_batches",
-        nwp_channels={"ecmwf": ["t2m", "v200"]},
-    )
-    dataloader = iter(dm.train_dataloader())
-
-    batch = next(dataloader)
-    batch_channels = batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp_channel_names]
-    print(batch_channels)
-    for v in ["t2m", "v200"]:
-        assert v in batch_channels
-    assert batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp].shape[2] == 2
-
-
-@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
-def test_pv_site_init():
-    dm = PVSiteDataModule(
-        configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_batches/data_configuration.yaml",
-        batch_size=2,
-        num_workers=0,
-        prefetch_factor=None,
-        train_period=[None, None],
-        val_period=[None, None],
-        test_period=[None, None],
-        batch_dir=None,
-    )
-
-
 def test_iter():
     dm = DataModule(
         configuration=None,
@@ -104,3 +48,6 @@ def test_iter_multiprocessing():
 
     # Make sure we've served 2 batches
     assert served_batches == 2
+
+
+# TODO: add test cases with premade netcdf samples
diff --git a/tests/models/multimodal/site_encoders/test_encoders.py b/tests/models/multimodal/site_encoders/test_encoders.py
index 41969b22..48938bc3 100644
--- a/tests/models/multimodal/site_encoders/test_encoders.py
+++ b/tests/models/multimodal/site_encoders/test_encoders.py
@@ -42,13 +42,14 @@ def test_singleattentionnetwork_forward(sample_pv_batch, site_encoder_model_kwar
     )
 
 
-def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
-    _test_model_forward(
-        sample_wind_batch,
-        SingleAttentionNetwork,
-        site_encoder_sensor_model_kwargs,
-        batch_size=2,
-    )
+# TODO: include this test once the sample batches for sites have been updated
+# def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
+#     _test_model_forward(
+#         sample_wind_batch,
+#         SingleAttentionNetwork,
+#         site_encoder_sensor_model_kwargs,
+#         batch_size=2,
+#     )
 
 
 # Test model backward on all models
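One possible shape for the TODO'd premade-netcdf test above, sketched with a hypothetical fixture directory (a real test needs actual `.nc` samples checked in or generated first):

```python
# Hypothetical test for premade netcdf samples; the directory is illustrative.
from glob import glob

import xarray as xr


def test_premade_netcdf_samples_load():
    # Assumed location for site-style samples written by scripts/save_samples.py.
    sample_paths = sorted(glob("tests/test_data/presaved_samples/train/*.nc"))
    for path in sample_paths:
        sample = xr.open_dataset(path, engine="h5netcdf")
        assert len(sample.data_vars) > 0
```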