From 948e35e95e6b2e65d48a4816be01dfba34f6ff70 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Dec 2023 14:41:48 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet_summation/data/datamodule.py | 2 +- tests/conftest.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pvnet_summation/data/datamodule.py b/pvnet_summation/data/datamodule.py index 175e705..14d85ca 100644 --- a/pvnet_summation/data/datamodule.py +++ b/pvnet_summation/data/datamodule.py @@ -2,9 +2,9 @@ import torch from lightning.pytorch import LightningDataModule +from ocf_datapipes.batch import BatchKey from ocf_datapipes.load import OpenGSP from ocf_datapipes.training.pvnet import normalize_gsp -from ocf_datapipes.batch import BatchKey from torch.utils.data import DataLoader from torch.utils.data.datapipes.datapipe import IterDataPipe from torch.utils.data.datapipes.iter import FileLister, Zipper diff --git a/tests/conftest.py b/tests/conftest.py index 7d5e89a..e756706 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,10 +20,10 @@ def construct_batch_by_sample_duplication(og_batch, i): """From a batch of data, take the ith sample and repeat it 317 to create a new batch""" new_batch = {} - + # Need to loop through these keys and add to batch ununsed_keys = list(og_batch.keys()) - + # NWP is nested so needs to be treated differently if BatchKey.nwp in og_batch: og_nwp_batch = og_batch[BatchKey.nwp] @@ -38,10 +38,10 @@ def construct_batch_by_sample_duplication(og_batch, i): else: new_nwp_source_batch[key] = value new_nwp_batch[nwp_source] = new_nwp_source_batch - + new_batch[BatchKey.nwp] = new_nwp_batch ununsed_keys.remove(BatchKey.nwp) - + for key in ununsed_keys: if isinstance(og_batch[key], torch.Tensor): n_dims = len(og_batch[key].shape) @@ -49,7 +49,7 @@ def construct_batch_by_sample_duplication(og_batch, i): new_batch[key] = og_batch[key][i : i + 1].repeat(repeats)[:317] else: new_batch[key] = og_batch[key] - + return new_batch @@ -66,9 +66,8 @@ def sample_data(): file_n = 0 for file in glob.glob("tests/test_data/sample_batches/train/*.pt"): og_batch = torch.load(file) - + for i in range(og_batch[BatchKey.gsp_time_utc].shape[0]): - # Duplicate sample to fake 317 GSPs new_batch = construct_batch_by_sample_duplication(og_batch, i) @@ -79,7 +78,7 @@ def sample_data(): file_n += 1 times += [new_batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")] - + times = np.unique(np.sort(np.concatenate(times))) da_output = xr.DataArray( @@ -109,7 +108,7 @@ def sample_data(): ) ds.to_zarr(f"{tmpdirname}/gsp.zarr") - + yield tmpdirname, f"{tmpdirname}/gsp.zarr"