Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 21, 2023
1 parent d5fb3ab commit 948e35e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pvnet_summation/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import torch
from lightning.pytorch import LightningDataModule
from ocf_datapipes.batch import BatchKey
from ocf_datapipes.load import OpenGSP
from ocf_datapipes.training.pvnet import normalize_gsp
from ocf_datapipes.batch import BatchKey
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter import FileLister, Zipper
Expand Down
17 changes: 8 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
def construct_batch_by_sample_duplication(og_batch, i):
    """From a batch of data, take the ith sample and repeat it 317 times to create a new batch"""
new_batch = {}

# Need to loop through these keys and add to batch
ununsed_keys = list(og_batch.keys())

# NWP is nested so needs to be treated differently
if BatchKey.nwp in og_batch:
og_nwp_batch = og_batch[BatchKey.nwp]
Expand All @@ -38,18 +38,18 @@ def construct_batch_by_sample_duplication(og_batch, i):
else:
new_nwp_source_batch[key] = value
new_nwp_batch[nwp_source] = new_nwp_source_batch

new_batch[BatchKey.nwp] = new_nwp_batch
ununsed_keys.remove(BatchKey.nwp)

for key in ununsed_keys:
if isinstance(og_batch[key], torch.Tensor):
n_dims = len(og_batch[key].shape)
repeats = (317,) + tuple(1 for dim in range(n_dims - 1))
new_batch[key] = og_batch[key][i : i + 1].repeat(repeats)[:317]
else:
new_batch[key] = og_batch[key]

return new_batch


Expand All @@ -66,9 +66,8 @@ def sample_data():
file_n = 0
for file in glob.glob("tests/test_data/sample_batches/train/*.pt"):
og_batch = torch.load(file)

for i in range(og_batch[BatchKey.gsp_time_utc].shape[0]):

# Duplicate sample to fake 317 GSPs
new_batch = construct_batch_by_sample_duplication(og_batch, i)

Expand All @@ -79,7 +78,7 @@ def sample_data():
file_n += 1

times += [new_batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")]

times = np.unique(np.sort(np.concatenate(times)))

da_output = xr.DataArray(
Expand Down Expand Up @@ -109,7 +108,7 @@ def sample_data():
)

ds.to_zarr(f"{tmpdirname}/gsp.zarr")

yield tmpdirname, f"{tmpdirname}/gsp.zarr"


Expand Down

0 comments on commit 948e35e

Please sign in to comment.