
Commit

update test batches
dfulu committed Dec 21, 2023
1 parent d7cd455 commit d5fb3ab
Showing 6 changed files with 473 additions and 100 deletions.
66 changes: 48 additions & 18 deletions tests/conftest.py
@@ -17,6 +17,42 @@
from pvnet_summation.data.datamodule import DataModule


def construct_batch_by_sample_duplication(og_batch, i):
    """From a batch of data, take the ith sample and repeat it 317 times to create a new batch"""
    new_batch = {}

    # Need to loop through these keys and add to batch
    unused_keys = list(og_batch.keys())

    # NWP is nested so needs to be treated differently
    if BatchKey.nwp in og_batch:
        og_nwp_batch = og_batch[BatchKey.nwp]
        new_nwp_batch = {}
        for nwp_source, og_nwp_source_batch in og_nwp_batch.items():
            new_nwp_source_batch = {}
            for key, value in og_nwp_source_batch.items():
                if isinstance(value, torch.Tensor):
                    n_dims = len(value.shape)
                    repeats = (317,) + tuple(1 for dim in range(n_dims - 1))
                    new_nwp_source_batch[key] = value[i : i + 1].repeat(repeats)[:317]
                else:
                    new_nwp_source_batch[key] = value
            new_nwp_batch[nwp_source] = new_nwp_source_batch

        new_batch[BatchKey.nwp] = new_nwp_batch
        unused_keys.remove(BatchKey.nwp)

    for key in unused_keys:
        if isinstance(og_batch[key], torch.Tensor):
            n_dims = len(og_batch[key].shape)
            repeats = (317,) + tuple(1 for dim in range(n_dims - 1))
            new_batch[key] = og_batch[key][i : i + 1].repeat(repeats)[:317]
        else:
            new_batch[key] = og_batch[key]

    return new_batch


@pytest.fixture()
def sample_data():
    # Copy small batches to fake 317 GSPs in each
@@ -28,28 +64,22 @@ def sample_data():
    times = []

    file_n = 0
-    for file in glob.glob("tests/data/sample_batches/train/*.pt"):
-        batch = torch.load(file)
-
-        this_batch = {}
-        for i in range(batch[BatchKey.gsp_time_utc].shape[0]):
+    for file in glob.glob("tests/test_data/sample_batches/train/*.pt"):
+        og_batch = torch.load(file)
+        for i in range(og_batch[BatchKey.gsp_time_utc].shape[0]):
            # Duplicate sample to fake 317 GSPs
-            for key in batch.keys():
-                if isinstance(batch[key], torch.Tensor):
-                    n_dims = len(batch[key].shape)
-                    repeats = (317,) + tuple(1 for dim in range(n_dims - 1))
-                    this_batch[key] = batch[key][i : i + 1].repeat(repeats)[:317]
-                else:
-                    this_batch[key] = batch[key]
+            new_batch = construct_batch_by_sample_duplication(og_batch, i)

            # Save for both train and val
-            torch.save(this_batch, f"{tmpdirname}/train/{file_n:06}.pt")
-            torch.save(this_batch, f"{tmpdirname}/val/{file_n:06}.pt")
+            torch.save(new_batch, f"{tmpdirname}/train/{file_n:06}.pt")
+            torch.save(new_batch, f"{tmpdirname}/val/{file_n:06}.pt")

            file_n += 1

-            times += [batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")]
+            times += [new_batch[BatchKey.gsp_time_utc][i].numpy().astype("datetime64[s]")]

    times = np.unique(np.sort(np.concatenate(times)))

    da_output = xr.DataArray(
@@ -79,7 +109,7 @@ def sample_data():
    )

    ds.to_zarr(f"{tmpdirname}/gsp.zarr")

    yield tmpdirname, f"{tmpdirname}/gsp.zarr"


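For orientation, here is a minimal sketch of what the new construct_batch_by_sample_duplication helper produces. The toy keys and shapes below are illustrative only; the real sample batches use BatchKey enum keys and also carry a nested NWP dictionary, which the helper handles separately.

import torch

# Toy batch of 4 samples; keys and shapes are made up for illustration.
og_batch = {
    "gsp": torch.randn(4, 21),   # (n_samples, n_times)
    "gsp_id": torch.arange(4),
    "t0_idx": 7,                 # non-tensor entries are passed through unchanged
}

new_batch = construct_batch_by_sample_duplication(og_batch, i=1)

assert new_batch["gsp"].shape == (317, 21)                     # sample 1 repeated 317 times
assert torch.equal(new_batch["gsp"][100], og_batch["gsp"][1])  # every row is a copy of sample 1
assert new_batch["t0_idx"] == 7
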
@@ -109,7 +139,7 @@ def model_kwargs():
    # These kwargs define the pvnet model which the summation model uses
    kwargs = dict(
        model_name="openclimatefix/pvnet_v2",
-        model_version="805ca9b2ee3120592b0b70b7c75a454e2b4e4bec",
+        model_version="22e577100d55787eb2547d701275b9bb48f7bfa0",
    )
    return kwargs

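The model_kwargs fixture pins openclimatefix/pvnet_v2 to a specific Hugging Face revision. One way to pull down exactly that pinned checkpoint for inspection is the huggingface_hub client; this is an assumption for illustration only, and not how the tests themselves load the model.

from huggingface_hub import snapshot_download

# Download the exact PVNet revision pinned by the model_kwargs fixture.
local_dir = snapshot_download(
    repo_id="openclimatefix/pvnet_v2",
    revision="22e577100d55787eb2547d701275b9bb48f7bfa0",
)
print(local_dir)  # local path of the cached snapshot
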
81 changes: 0 additions & 81 deletions tests/data/sample_batches/data_configuration.yaml

This file was deleted.

Binary file removed tests/data/sample_batches/train/000000.pt

0 comments on commit d5fb3ab
