Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample saving for Site Dataset #290

Open
wants to merge 15 commits into
base: dev-data-sampler
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,20 @@ This is also where you can update the train, val & test periods to cover the dat

### Running the batch creation script

Run the `save_batches.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example):
Run the `save_samples.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example):

```bash
python scripts/save_batches.py
python scripts/save_samples.py
```
PVNet uses
[hydra](https://hydra.cc/) which enables us to pass variables via the command
line that will override the configuration defined in the `./configs` directory, like this:

```bash
python scripts/save_batches.py datamodule=streamed_batches datamodule.batch_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5
python scripts/save_samples.py datamodule=streamed_batches datamodule.sample_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5
```

`scripts/save_batches.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder.
`scripts/save_samples.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder.

If downloading private data from a GCP bucket make sure to authenticate gcloud (the public satellite data does not need authentication):

Expand Down Expand Up @@ -197,7 +197,7 @@ Make sure to update the following config files before training your model:
2. In `configs/model/local_multimodal.yaml`:
- update the list of encoders to reflect the data sources you are using. If you are using different NWP sources, the encoders for these should follow the same structure with two important updates:
- `in_channels`: number of variables your NWP source supplies
- `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `nwp_image_size_pixels_height` and/or `nwp_image_size_pixels_width` in `datamodule/example_configs.yaml`, unless transformations such as coarsening was applied (e. g. as for ECMWF data)
- `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `image_size_pixels_height` and/or `image_size_pixels_width` in `datamodule/configuration/site_example_configuration.yaml` for the NWP, unless a transformation such as coarsening was applied (e.g. as for ECMWF data)
3. In `configs/local_trainer.yaml`:
- set `accelerator: 0` if running on a system without a supported GPU

Expand All @@ -216,7 +216,7 @@ defaults:
- hydra: default.yaml
```

Assuming you ran the `save_batches.py` script to generate some premade train and
Assuming you ran the `save_samples.py` script to generate some premade train and
val data batches, you can now train PVNet by running:

```
Expand Down
14 changes: 12 additions & 2 deletions pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
from lightning.pytorch import LightningDataModule
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset
from ocf_datapipes.batch import (
NumpyBatch,
TensorBatch,
Expand Down Expand Up @@ -93,7 +93,17 @@ def __init__(
)

def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)
if self.configuration.renewable == "uk_pv":
return PVNetUKRegionalDataset(
self.configuration, start_time=start_time, end_time=end_time
)
elif self.configuration.renewable == "site":
return SitesDataset(self.configuration, start_time=start_time, end_time=end_time)
else:
raise ValueError(
f"Unknown renewable: {self.configuration.renewable}, "
"renewable value should either be uk_pv or site"
)

def _get_premade_samples_dataset(self, subdir) -> Dataset:
split_dir = f"{self.sample_dir}/{subdir}"
Expand Down
67 changes: 0 additions & 67 deletions pvnet/data/pv_site_datamodule.py

This file was deleted.

62 changes: 0 additions & 62 deletions pvnet/data/wind_datamodule.py

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dynamic = ["version", "readme"]
license={file="LICENCE"}

dependencies = [
"ocf_data_sampler==0.0.32",
"ocf_data_sampler==0.0.33",
"ocf_datapipes>=3.3.34",
"ocf_ml_metrics>=0.0.11",
"numpy",
Expand Down
6 changes: 3 additions & 3 deletions scripts/save_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import dask
import hydra
import torch
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
from torch.utils.data import DataLoader, Dataset
Expand Down Expand Up @@ -79,7 +79,7 @@ def __call__(self, sample, sample_num: int):
if self.renewable == "pv":
torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt")
elif self.renewable in ["wind", "pv_india", "pv_site"]:
raise NotImplementedError
sample.to_netcdf(f"{self.save_dir}/{sample_num:08}.nc", mode="w", engine="h5netcdf")
else:
raise ValueError(f"Unknown renewable: {self.renewable}")

Expand All @@ -89,7 +89,7 @@ def get_dataset(config_path: str, start_time: str, end_time: str, renewable: str
if renewable == "pv":
dataset_cls = PVNetUKRegionalDataset
elif renewable in ["wind", "pv_india", "pv_site"]:
raise NotImplementedError
dataset_cls = SitesDataset
else:
raise ValueError(f"Unknown renewable: {renewable}")

Expand Down
30 changes: 15 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import pvnet
from pvnet.data.datamodule import DataModule
from pvnet.data.wind_datamodule import WindDataModule

import pvnet.models.multimodal.encoders.encoders3d
import pvnet.models.multimodal.linear_networks.networks
Expand Down Expand Up @@ -159,20 +158,21 @@ def sample_pv_batch():
return torch.load("tests/test_data/presaved_batches/train/000000.pt")


@pytest.fixture()
def sample_wind_batch():
dm = WindDataModule(
configuration=None,
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir="tests/test_data/sample_wind_batches",
)
batch = next(iter(dm.train_dataloader()))
return batch
# TODO: update this fixture once we add the loading logic for the Site dataset
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't work out whether we should merge before doing this; I'm not sure. Relates to openclimatefix/ocf-data-sampler#99

# @pytest.fixture()
# def sample_wind_batch():
# dm = WindDataModule(
# configuration=None,
# batch_size=2,
# num_workers=0,
# prefetch_factor=None,
# train_period=[None, None],
# val_period=[None, None],
# test_period=[None, None],
# batch_dir="tests/test_data/sample_wind_batches",
# )
# batch = next(iter(dm.train_dataloader()))
# return batch


@pytest.fixture()
Expand Down
59 changes: 3 additions & 56 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
import pytest
from pvnet.data.datamodule import DataModule
from pvnet.data.wind_datamodule import WindDataModule
from pvnet.data.pv_site_datamodule import PVSiteDataModule
import os
from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey


def test_init():
Expand All @@ -18,57 +13,6 @@ def test_init():
)


@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
def test_wind_init():
dm = WindDataModule(
configuration=None,
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir="tests/data/sample_batches",
)


@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
def test_wind_init_with_nwp_filter():
dm = WindDataModule(
configuration=None,
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir="tests/test_data/sample_wind_batches",
nwp_channels={"ecmwf": ["t2m", "v200"]},
)
dataloader = iter(dm.train_dataloader())

batch = next(dataloader)
batch_channels = batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp_channel_names]
print(batch_channels)
for v in ["t2m", "v200"]:
assert v in batch_channels
assert batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp].shape[2] == 2


@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
def test_pv_site_init():
dm = PVSiteDataModule(
configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_batches/data_configuration.yaml",
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir=None,
)


def test_iter():
dm = DataModule(
configuration=None,
Expand Down Expand Up @@ -104,3 +48,6 @@ def test_iter_multiprocessing():

# Make sure we've served 2 batches
assert served_batches == 2


# TODO: add test cases that use premade netCDF samples
15 changes: 8 additions & 7 deletions tests/models/multimodal/site_encoders/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ def test_singleattentionnetwork_forward(sample_pv_batch, site_encoder_model_kwar
)


def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
_test_model_forward(
sample_wind_batch,
SingleAttentionNetwork,
site_encoder_sensor_model_kwargs,
batch_size=2,
)
# TODO: re-enable this test once the sample batches for sites have been updated
# def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
# _test_model_forward(
# sample_wind_batch,
# SingleAttentionNetwork,
# site_encoder_sensor_model_kwargs,
# batch_size=2,
# )


# Test model backward on all models
Expand Down
Loading