Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to datapipes v3 #14

Merged
merged 5 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/workflows.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ jobs:
sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin"
# brew_install: "proj geos librttopo"
os_list: '["ubuntu-latest"]'
python-version: "['3.10', '3.11']"
2 changes: 1 addition & 1 deletion pvnet_summation/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import torch
from lightning.pytorch import LightningDataModule
from ocf_datapipes.batch import BatchKey
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet import normalize_gsp
from ocf_datapipes.utils.consts import BatchKey
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter import FileLister, Zipper
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
nowcasting_utils
ocf_datapipes>=1.2.44
pvnet>=2.4.0
ocf_datapipes>=3.0.0
pvnet>=2.6.2
ocf_ml_metrics
numpy
pandas
matplotlib
xarray
ipykernel
h5netcdf
torch>=2.1.1
torch>=2.0.0
lightning>=2.0.1
torchdata
pytest
Expand Down
61 changes: 45 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,48 @@
from pvnet_summation.models.model import Model


from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.batch import BatchKey
from datetime import timedelta

from pvnet_summation.data.datamodule import DataModule


def construct_batch_by_sample_duplication(og_batch, i, num_gsps=317):
    """From a batch of data, take the ith sample and repeat it to fake a batch of many GSPs.

    Args:
        og_batch: Original batch dict. May contain a nested per-source NWP sub-batch under
            ``BatchKey.nwp``; all other entries are flat key -> tensor/non-tensor values.
        i: Index of the sample within ``og_batch`` to duplicate.
        num_gsps: Number of copies of the sample in the returned batch. Defaults to 317
            (the number of GSPs being faked by the callers in this file).

    Returns:
        A new batch dict with the same structure as ``og_batch``, in which every tensor's
        leading (batch) dimension holds ``num_gsps`` copies of sample ``i``. Non-tensor
        values are passed through unchanged.
    """

    def _repeat_sample(value):
        # Repeat the ith slice along the batch dimension; non-tensors pass through as-is.
        if not isinstance(value, torch.Tensor):
            return value
        repeats = (num_gsps,) + (1,) * (value.ndim - 1)
        return value[i : i + 1].repeat(repeats)[:num_gsps]

    new_batch = {}

    # Keys not yet copied into the new batch
    unused_keys = list(og_batch.keys())

    # NWP is nested one level deeper (a dict of per-source sub-batches) so it is
    # handled separately from the flat keys below.
    if BatchKey.nwp in og_batch:
        new_batch[BatchKey.nwp] = {
            nwp_source: {
                key: _repeat_sample(value) for key, value in og_nwp_source_batch.items()
            }
            for nwp_source, og_nwp_source_batch in og_batch[BatchKey.nwp].items()
        }
        unused_keys.remove(BatchKey.nwp)

    for key in unused_keys:
        new_batch[key] = _repeat_sample(og_batch[key])

    return new_batch


@pytest.fixture()
def sample_data():
# Copy small batches to fake 317 GSPs in each
Expand All @@ -28,27 +64,20 @@ def sample_data():
times = []

file_n = 0
for file in glob.glob("tests/data/sample_batches/train/*.pt"):
batch = torch.load(file)
for file in glob.glob("tests/test_data/sample_batches/train/*.pt"):
og_batch = torch.load(file)

this_batch = {}
for i in range(batch[BatchKey.gsp_time_utc].shape[0]):
for i in range(og_batch[BatchKey.gsp_time_utc].shape[0]):
# Duplicate sample to fake 317 GSPs
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
n_dims = len(batch[key].shape)
repeats = (317,) + tuple(1 for dim in range(n_dims - 1))
this_batch[key] = batch[key][i : i + 1].repeat(repeats)[:317]
else:
this_batch[key] = batch[key]
new_batch = construct_batch_by_sample_duplication(og_batch, i)

# Save for both train and val
torch.save(this_batch, f"{tmpdirname}/train/{file_n:06}.pt")
torch.save(this_batch, f"{tmpdirname}/val/{file_n:06}.pt")
torch.save(new_batch, f"{tmpdirname}/train/{file_n:06}.pt")
torch.save(new_batch, f"{tmpdirname}/val/{file_n:06}.pt")

file_n += 1

times += [batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")]
times += [new_batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")]

times = np.unique(np.sort(np.concatenate(times)))

Expand Down Expand Up @@ -109,7 +138,7 @@ def model_kwargs():
# These kwargs define the pvnet model which the summation model uses
kwargs = dict(
model_name="openclimatefix/pvnet_v2",
model_version="805ca9b2ee3120592b0b70b7c75a454e2b4e4bec",
model_version="22e577100d55787eb2547d701275b9bb48f7bfa0",
)
return kwargs

Expand Down
81 changes: 0 additions & 81 deletions tests/data/sample_batches/data_configuration.yaml

This file was deleted.

Binary file removed tests/data/sample_batches/train/000000.pt
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pvnet_summation.data.datamodule import DataModule
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.batch import BatchKey


def test_init(sample_data):
Expand Down
Loading
Loading