diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py
index 175e705..14d85ca 100644
--- a/pvnet_summation/data/datamodule.py
+++ b/pvnet_summation/data/datamodule.py
@@ -2,9 +2,9 @@
 import torch
 from lightning.pytorch import LightningDataModule
+from ocf_datapipes.batch import BatchKey
 from ocf_datapipes.load import OpenGSP
 from ocf_datapipes.training.pvnet import normalize_gsp
-from ocf_datapipes.batch import BatchKey
 from torch.utils.data import DataLoader
 from torch.utils.data.datapipes.datapipe import IterDataPipe
 from torch.utils.data.datapipes.iter import FileLister, Zipper
diff --git a/tests/conftest.py b/tests/conftest.py
index 7d5e89a..e756706 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -20,10 +20,10 @@ def construct_batch_by_sample_duplication(og_batch, i):
     """From a batch of data, take the ith sample and repeat it 317 to create a new batch"""
     new_batch = {}
-
+
     # Need to loop through these keys and add to batch
     ununsed_keys = list(og_batch.keys())
-
+
     # NWP is nested so needs to be treated differently
     if BatchKey.nwp in og_batch:
         og_nwp_batch = og_batch[BatchKey.nwp]
@@ -38,10 +38,10 @@ def construct_batch_by_sample_duplication(og_batch, i):
                 else:
                     new_nwp_source_batch[key] = value
             new_nwp_batch[nwp_source] = new_nwp_source_batch
-
+
         new_batch[BatchKey.nwp] = new_nwp_batch
         ununsed_keys.remove(BatchKey.nwp)
-
+
     for key in ununsed_keys:
         if isinstance(og_batch[key], torch.Tensor):
             n_dims = len(og_batch[key].shape)
@@ -49,7 +49,7 @@
             new_batch[key] = og_batch[key][i : i + 1].repeat(repeats)[:317]
         else:
             new_batch[key] = og_batch[key]
-
+
     return new_batch
@@ -66,9 +66,8 @@ def sample_data():
         file_n = 0
         for file in glob.glob("tests/test_data/sample_batches/train/*.pt"):
             og_batch = torch.load(file)
-
+
             for i in range(og_batch[BatchKey.gsp_time_utc].shape[0]):
-
                 # Duplicate sample to fake 317 GSPs
                 new_batch = construct_batch_by_sample_duplication(og_batch, i)
@@ -79,7 +78,7 @@
                 file_n += 1
 
                 times += [new_batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")]
-
+
         times = np.unique(np.sort(np.concatenate(times)))
 
         da_output = xr.DataArray(
@@ -109,7 +108,7 @@
         )
 
         ds.to_zarr(f"{tmpdirname}/gsp.zarr")
-
+
         yield tmpdirname, f"{tmpdirname}/gsp.zarr"
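
Note for reviewers: the conftest.py hunks above are whitespace-only, but they sit inside `construct_batch_by_sample_duplication`, whose core move is tiling one sample of a batched tensor across a batch axis of 317 (one entry per GSP). Below is a minimal sketch of that pattern. The tensor shape and the exact `repeats` construction are assumptions here, since the `repeats = ...` line falls between the hunks and is not shown in this diff:

```python
import torch

# Sketch of the duplication pattern from construct_batch_by_sample_duplication.
# The diff shows n_dims and the .repeat(repeats)[:317] call but elides how
# `repeats` is built, so the definition below is an assumption.
og = torch.randn(8, 12, 4)  # hypothetical batch: 8 samples, 12 timesteps, 4 features
i = 3                       # sample to duplicate

n_dims = len(og.shape)
repeats = [317] + [1] * (n_dims - 1)       # assumed: tile only the batch axis
new = og[i : i + 1].repeat(repeats)[:317]  # (1, 12, 4) -> (317, 12, 4)

assert new.shape == (317, 12, 4)
assert torch.equal(new[5], og[i])  # every row is a copy of sample i
```

Under this assumed `repeats`, the trailing `[:317]` slice is a no-op that simply caps the batch dimension in case the tile count ever exceeds the target.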