Skip to content

Commit

Permalink
Merge pull request #14 from openclimatefix/datapipes_v3
Browse files Browse the repository at this point in the history
Upgrade to datapipes v3 #minor
  • Loading branch information
dfulu authored Dec 21, 2023
2 parents fcba032 + 0356c90 commit 3cdd785
Show file tree
Hide file tree
Showing 10 changed files with 476 additions and 103 deletions.
1 change: 1 addition & 0 deletions .github/workflows/workflows.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ jobs:
sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin"
# brew_install: "proj geos librttopo"
os_list: '["ubuntu-latest"]'
python-version: "['3.10', '3.11']"
2 changes: 1 addition & 1 deletion pvnet_summation/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import torch
from lightning.pytorch import LightningDataModule
from ocf_datapipes.batch import BatchKey
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet import normalize_gsp
from ocf_datapipes.utils.consts import BatchKey
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter import FileLister, Zipper
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
nowcasting_utils
ocf_datapipes>=1.2.44
pvnet>=2.4.0
ocf_datapipes>=3.0.0
pvnet>=2.6.2
ocf_ml_metrics
numpy
pandas
matplotlib
xarray
ipykernel
h5netcdf
torch>=2.1.1
torch>=2.0.0
lightning>=2.0.1
torchdata
pytest
Expand Down
61 changes: 45 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,48 @@
from pvnet_summation.models.model import Model


from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.batch import BatchKey
from datetime import timedelta

from pvnet_summation.data.datamodule import DataModule


def construct_batch_by_sample_duplication(og_batch, i, n_gsps=317):
    """From a batch of data, take the i-th sample and repeat it ``n_gsps`` times
    to create a new batch.

    Args:
        og_batch: Dict-like batch mapping BatchKey -> tensor (or other metadata).
        i: Index of the sample within ``og_batch`` to duplicate.
        n_gsps: Number of copies of the sample in the returned batch
            (default 317 — presumably the number of GSP regions; TODO confirm).

    Returns:
        A new batch dict with the same keys as ``og_batch``, where every tensor
        entry has ``n_gsps`` identical rows, each a copy of sample ``i``.
    """

    def _tile_sample(value):
        # Tensors: slice out sample i (keeping the batch dim) and tile it along
        # the batch axis; non-tensor entries are passed through unchanged.
        if isinstance(value, torch.Tensor):
            repeats = (n_gsps,) + (1,) * (len(value.shape) - 1)
            return value[i : i + 1].repeat(repeats)[:n_gsps]
        return value

    new_batch = {}

    # Keys still to be copied into the new batch
    unused_keys = list(og_batch.keys())

    # NWP is nested (source -> sub-batch) so needs to be treated differently
    if BatchKey.nwp in og_batch:
        new_batch[BatchKey.nwp] = {
            nwp_source: {key: _tile_sample(value) for key, value in source_batch.items()}
            for nwp_source, source_batch in og_batch[BatchKey.nwp].items()
        }
        unused_keys.remove(BatchKey.nwp)

    for key in unused_keys:
        new_batch[key] = _tile_sample(og_batch[key])

    return new_batch


@pytest.fixture()
def sample_data():
# Copy small batches to fake 317 GSPs in each
Expand All @@ -28,27 +64,20 @@ def sample_data():
times = []

file_n = 0
for file in glob.glob("tests/data/sample_batches/train/*.pt"):
batch = torch.load(file)
for file in glob.glob("tests/test_data/sample_batches/train/*.pt"):
og_batch = torch.load(file)

this_batch = {}
for i in range(batch[BatchKey.gsp_time_utc].shape[0]):
for i in range(og_batch[BatchKey.gsp_time_utc].shape[0]):
# Duplicate sample to fake 317 GSPs
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
n_dims = len(batch[key].shape)
repeats = (317,) + tuple(1 for dim in range(n_dims - 1))
this_batch[key] = batch[key][i : i + 1].repeat(repeats)[:317]
else:
this_batch[key] = batch[key]
new_batch = construct_batch_by_sample_duplication(og_batch, i)

# Save for both train and val
torch.save(this_batch, f"{tmpdirname}/train/{file_n:06}.pt")
torch.save(this_batch, f"{tmpdirname}/val/{file_n:06}.pt")
torch.save(new_batch, f"{tmpdirname}/train/{file_n:06}.pt")
torch.save(new_batch, f"{tmpdirname}/val/{file_n:06}.pt")

file_n += 1

times += [batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")]
times += [new_batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")]

times = np.unique(np.sort(np.concatenate(times)))

Expand Down Expand Up @@ -109,7 +138,7 @@ def model_kwargs():
# These kwargs define the pvnet model which the summation model uses
kwargs = dict(
model_name="openclimatefix/pvnet_v2",
model_version="805ca9b2ee3120592b0b70b7c75a454e2b4e4bec",
model_version="22e577100d55787eb2547d701275b9bb48f7bfa0",
)
return kwargs

Expand Down
81 changes: 0 additions & 81 deletions tests/data/sample_batches/data_configuration.yaml

This file was deleted.

Binary file removed tests/data/sample_batches/train/000000.pt
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pvnet_summation.data.datamodule import DataModule
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.batch import BatchKey


def test_init(sample_data):
Expand Down
Loading

0 comments on commit 3cdd785

Please sign in to comment.