Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add test for checking physical limits and zeroes in NWP data #… #340

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
3ee287c
chore: Add test for checking physical limits and zeroes in NWP data #…
glitch401 Jul 3, 2024
1e2df80
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2024
8105b91
changes to generate test data on the go. remove unnecessary zarr file…
glitch401 Jul 4, 2024
1eafe49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
d5bc6cf
Fix ValueError message for NWP data containing zeros and outside phys…
glitch401 Jul 4, 2024
d8cfa9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
5e68173
Fix ValueError message coding style
glitch401 Jul 4, 2024
466b710
update physical limits in according to pvnet_uk_region/data_config.yaml
glitch401 Jul 5, 2024
692500c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
0667bab
Update temperature physical limits in OpenNWPIterDataPipe
glitch401 Jul 5, 2024
246d898
Fix NaN check in stack_np_examples_into_batch function
glitch401 Jul 11, 2024
55627eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2024
7ba254d
changes made to adapt for lazy loading
glitch401 Jul 16, 2024
c6ee33d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2024
d0c4f6f
moved limits to a constant file
glitch401 Jul 24, 2024
19050c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2024
3fe89fc
Refactor test_merge_numpy_examples_to_batch.py and test_load_nwp.py t…
glitch401 Aug 15, 2024
ace0259
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ coverage.xml
.pytest_cache/
test.nc

#test data generator
tests/load/nwp/test_data_generator.py

glitch401 marked this conversation as resolved.
Show resolved Hide resolved
# Translations
*.mo
*.pot
Expand Down
12 changes: 11 additions & 1 deletion ocf_datapipes/batch/merge_numpy_examples_to_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def stack_np_examples_into_batch(dict_list: Sequence[NumpyBatch]) -> NumpyBatch:

nwp_batch[nwp_source] = nwp_source_batch

batch[BatchKey.nwp] = nwp_batch
batch[BatchKey.nwp] = check_for_nans(nwp_batch)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be wrong, but I was under the impression that we allow for nans currently to be present in batches, which then get filled with zeroes during training? @peterdudfield is this a gsp thing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we move this to when the NWP gets opened? And have an option to check it or not?
I think that would make it safer and clearer whats going on.

We could have a different issue that checks for nans in the batches, but we need to think how we turn that on and off .e.tc

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@glitch401 would you mind moving this to when the nwp is opened? with an option to do this or not.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean, when data element NWP is opened?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like below, in the load stage

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gotcha, will append changes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was hoping this would be removed, and it would be mvoed to below

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is still open

else:
batch[batch_key] = stack_data_list(
Expand All @@ -101,6 +101,16 @@ def stack_np_examples_into_batch(dict_list: Sequence[NumpyBatch]) -> NumpyBatch:
return batch


def check_for_nans(batch: dict[str, NWPNumpyBatch]):
"""Check for NaNs in a batch"""
for keys in batch.keys():
for keys2 in batch[keys].keys():
if keys2 == NWPBatchKey.nwp:
if np.isnan(batch[keys][keys2]).any():
raise ValueError(f"NaNs found in {keys2}")
return batch


def unstack_np_batch_into_examples(batch: NumpyBatch):
"""Splits a single batch into samples.

Expand Down
97 changes: 96 additions & 1 deletion ocf_datapipes/load/nwp/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Union

import dask
import xarray as xr
from ocf_blosc2 import Blosc2 # noqa: F401
from torch.utils.data import IterDataPipe, functional_datapipe
Expand All @@ -26,15 +27,54 @@ def __init__(
self,
zarr_path: Union[Path, str, list[Path], list[str]],
provider: str = "ukv",
check_for_zeros: bool = False,
check_physical_limits: bool = False,
):
"""
Opens NWP Zarr and yields it

Args:
zarr_path: Path to the Zarr file
provider: NWP provider
check_for_zeros: Check for zeros in the NWP data
check_physical_limits: Check the physical limits of nwp data (e.g. -100<temperature<100)
"""
self.zarr_path = zarr_path
self.check_for_zeros = check_for_zeros
self.check_physical_limits = check_physical_limits

# limits for NWP data in accordance with https://huggingface.co/openclimatefix/pvnet_uk_region/blob/main/data_config.yaml
self.limits = {
"t2m": (200, 350), # Temperature in Kelvin (-100°C to 60°C)
glitch401 marked this conversation as resolved.
Show resolved Hide resolved
"dswrf": (0, 1500), # Downward short-wave radiation flux, W/m^2
"dlwrf": (0, 750), # Downward long-wave radiation flux, W/m^2
"hcc": (0, 100), # High cloud cover, %
"mcc": (0, 100), # Medium cloud cover, %
"lcc": (0, 100), # Low cloud cover, %
"tcc": (0, 100), # Total cloud cover, %
"sde": (0, 1000), # Snowfall depth, meters
"duvrs": (0, 500), # Direct UV radiation at surface, W/m^2 (positive values only)
"u10": (-200, 200), # U component of 10m wind, m/s
"v10": (-200, 200), # V component of 10m wind, m/s
# UKV NWP channels (additional to ECMWF)
"prate": (0, 2000), # Precipitation rate, , kg/m^2/s (equivalent to 0-2000 mm/day)
"r": (0, 100), # Relative humidity, %
"si10": (0, 250), # Wind speed at 10m, m/s
"t": (200, 350), # Temperature in Kelvin (-100°C to 60°C)
"vis": (0, 100000), # Visibility, meters
# Satellite channels (no direct mapping to physical limits, using placeholder values)
"IR_016": (0, 1000), # Infrared channel
"IR_039": (0, 1000), # Infrared channel
"IR_087": (0, 1000), # Infrared channel
"IR_097": (0, 1000), # Infrared channel
"IR_108": (0, 1000), # Infrared channel
"IR_120": (0, 1000), # Infrared channel
"IR_134": (0, 1000), # Infrared channel
"VIS006": (0, 1000), # Visible channel
"VIS008": (0, 1000), # Visible channel
"WV_062": (0, 1000), # Water vapor channel
"WV_073": (0, 1000), # Water vapor channel
}
logger.info(f"Using {provider.lower()}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very much just a suggestion, but it would be nice to have some control over which variables receive the checks. Intuitively, that should probably be possible by just passing a list of keys to be checked instead of True to check_for_zeroes/check_physical_limits

if provider.lower() == "ukv":
self.open_nwp = open_ukv
Expand All @@ -53,9 +93,64 @@ def __init__(
else:
raise ValueError(f"Unknown provider: {provider}")

def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: # type: ignore
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
"""Opens the NWP data"""
logger.debug("Opening NWP data: %s", self.zarr_path)
nwp = self.open_nwp(self.zarr_path)
if self.check_for_zeros:
self.check_if_zeros(nwp)
if self.check_physical_limits:
self.check_if_physical_limits(nwp)
while True:
yield nwp

def check_if_zeros(self, nwp: Union[xr.DataArray, xr.Dataset]):
"""Checks if the NWP data contains zeros"""

def count_zeros(block):
return (block == 0).sum()

def check_zeros(result):
if result > 0:
raise ValueError(f"NWP data contains {result*100/nwp.size}% zeros")

if isinstance(nwp, xr.DataArray):
if dask.is_dask_collection(nwp.data):
zero_count = nwp.data.map_blocks(count_zeros, dtype=int).compute()
check_zeros(zero_count)
else:
if (nwp.values == 0).any():
raise ValueError(
f"NWP DataArray contains{(nwp.values == 0).sum()*100/nwp.values.size}% "
"zeros"
)
elif isinstance(nwp, xr.Dataset):
for var in nwp:
if dask.is_dask_collection(nwp[var].data):
zero_count = nwp[var].data.map_blocks(count_zeros, dtype=int).compute()
check_zeros(zero_count)
else:
if (nwp[var].values == 0).any():
raise ValueError(
f"NWP Dataset variable{var} "
f"contains {(nwp[var].values == 0).sum()*100/nwp[var].values.size}% "
"zeros"
)

def check_if_physical_limits(self, nwp: Union[xr.DataArray, xr.Dataset]):
"""Checks if the NWP data is within physical limits"""
if isinstance(nwp, xr.DataArray):
var_name = nwp.channel.values[0]
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
if var_name in self.limits:
lower, upper = self.limits[var_name]
if (nwp < lower).any() or (nwp > upper).any():
raise ValueError(
f"NWP data {var_name} is outside physical limits: ({lower},{upper})"
)
elif isinstance(nwp, xr.Dataset):
for var_name, (lower, upper) in self.limits.items():
if var_name in nwp.channel:
peterdudfield marked this conversation as resolved.
Show resolved Hide resolved
if not ((nwp[var_name] >= lower).all() and (nwp[var_name] <= upper).all()):
raise ValueError(
f"NWP data {var_name} is outside physical limits: ({lower},{upper})"
)
46 changes: 46 additions & 0 deletions tests/batch/test_merge_numpy_examples_to_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,49 @@ def _single_batch_sample(fill_value):
return sample


def _single_batch_sample_nan(fill_value):
"""This function allows us to create batches with different filled values"""

sample: NumpyBatch = {}
sample[BatchKey.satellite_actual] = np.full(
(12, 10, 24, 24), fill_value, dtype=np.float32
) # shape: (time, channel, x, y)
sample[BatchKey.gsp_id] = np.full((1,), fill_value) # shape: (1,)
sample[BatchKey.gsp_t0_idx] = 4 # scalar and constant across all samples

sample_nwp_ukv: NWPNumpyBatch = {}
sample_nwp_ukv[NWPBatchKey.nwp] = np.full(
(8, 2, 24, 24), fill_value, dtype=np.float32
) # shape: (time, variable, x, y)
sample_nwp_ukv[NWPBatchKey.nwp][0, 0, 0, 0] = np.nan

sample_nwp_ukv[NWPBatchKey.nwp_channel_names] = ["a", "b"] # shape: (variable,)

sample_nwp_ecmwf: NWPNumpyBatch = {}
sample_nwp_ecmwf[NWPBatchKey.nwp] = np.full(
(8, 4, 12, 12), fill_value
) # shape: (time, variable, x, y)

sample[BatchKey.nwp] = {
"ukv": sample_nwp_ukv,
"ecmwf": sample_nwp_ecmwf,
}
# print(sample[BatchKey.nwp]["ukv"])
return sample


@pytest.fixture
def numpy_sample_datapipe():
dp = IterableWrapper([_single_batch_sample(i) for i in range(8)])
return dp


@pytest.fixture
def numpy_nan_sample_datapipe():
dp = IterableWrapper([_single_batch_sample_nan(i) for i in range(8)])
return dp


def test_merge_numpy_batch(numpy_sample_datapipe):
dp = MergeNumpyBatchIterDataPipe(numpy_sample_datapipe.batch(4))
dp_iter = iter(dp)
Expand All @@ -62,6 +99,15 @@ def test_merge_numpy_batch(numpy_sample_datapipe):
assert nwp_batch[NWPBatchKey.nwp_channel_names] == ["a", "b"]


def test_merge_numpy_batch_for_nans(numpy_nan_sample_datapipe):
with pytest.raises(
ValueError
): # checks for Error raised if NWP/BatchKey DataArray contains Nans
dp = MergeNumpyBatchIterDataPipe(numpy_nan_sample_datapipe.batch(4))
dp_iter = iter(dp)
metadata = next(dp_iter)


def test_merge_numpy_examples_to_batch(numpy_sample_datapipe):
dp = MergeNumpyExamplesToBatchIterDataPipe(numpy_sample_datapipe, n_examples_per_batch=4)
dp_iter = iter(dp)
Expand Down
57 changes: 56 additions & 1 deletion tests/load/nwp/test_load_nwp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import zarr
import shutil
import numpy as np
import pandas as pd
from xarray import DataArray

import pytest
from ocf_datapipes.load import OpenNWP


Expand Down Expand Up @@ -102,3 +105,55 @@ def test_load_excarta_local():
raise ValueError(
"The following dimensions are missing: %s" % (str(dim_keys - set(metadata.dims)))
)


def test_check_for_zeros():
# to generate data with zeros and limits:
original_store_path = "tests/data/nwp_data/test.zarr"
original_store = zarr.open(original_store_path, mode="r")
new_store_path = "tests/data/nwp_data/test_with_zeros_n_limits.zarr"
# Optionally, clear the destination store if it already exists
shutil.rmtree(new_store_path, ignore_errors=True)
with zarr.open(new_store_path, mode="w") as new_store:
for item in original_store:
zarr.copy(original_store[item], new_store, name=item)

new_store["UKV"][0, 0, 0, 0] = 0
new_store["UKV"][0, 0, 0, 1] = np.random.uniform(190, 360, size=(548,))
shutil.copy(
"tests/data/nwp_data/test.zarr/.zmetadata",
"tests/data/nwp_data/test_with_zeros_n_limits.zarr/.zmetadata",
)

# positive test case
nwp_datapipe1 = OpenNWP(
zarr_path=new_store_path,
check_for_zeros=True,
)
with pytest.raises(ValueError): # checks for Error raised if NWP DataArray contains zeros
metadata = next(iter(nwp_datapipe1))

# negative test case
nwp_datapipe2 = OpenNWP(zarr_path=original_store_path, check_for_zeros=True)
metadata = next(iter(nwp_datapipe2))
assert metadata is not None


def test_check_physical_limits():
# positive test case
nwp_datapipe1 = OpenNWP(
zarr_path="tests/data/nwp_data/test_with_zeros_n_limits.zarr", check_physical_limits=True
)
with pytest.raises(
ValueError
): # checks for Error raised if NWP data UKV is outside physical limits
metadata = next(iter(nwp_datapipe1))

# negative test case
nwp_datapipe2 = OpenNWP(zarr_path="tests/data/nwp_data/test.zarr", check_physical_limits=True)
metadata = next(iter(nwp_datapipe2))
assert metadata is not None

shutil.rmtree(
"tests/data/nwp_data/test_with_zeros_n_limits.zarr"
) # removes the zarr file created for testing