Skip to content

Commit

Permalink
Update check_for_nans.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker authored Nov 7, 2023
1 parent 84908d3 commit 7e9bcfa
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions ocf_datapipes/validation/check_for_nans.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class CheckNaNsIterDataPipe(IterDataPipe):
"""Checks, and optionally fills, NaNs in Xarray Dataset"""

def __init__(
self, source_datapipe: IterDataPipe, dataset_name: str = None, fill_nans: bool = False
self, source_datapipe: IterDataPipe, dataset_name: str = None, fill_nans: bool = False, fill_value: float = 0.0
):
"""
Checks and optionally fills NaNs in the data
Expand All @@ -26,6 +26,7 @@ def __init__(
self.dataset_name = dataset_name
self.fill_nans = fill_nans
self.source_datapipe_name = source_datapipe.__repr__()
self.fill_value = fill_value

def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
"""
Expand All @@ -37,10 +38,10 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
for xr_data in self.source_datapipe:
if self.fill_nans:
if self.dataset_name is None:
xr_data = self.check_nan_and_fill_warning(data=xr_data)
xr_data = self.fill_nan(data=xr_data, fill_value=self.fill_value)
else:
xr_data[self.dataset_name] = self.check_nan_and_fill_warning(
data=xr_data[self.dataset_name]
xr_data[self.dataset_name] = self.fill_nan(

Check warning on line 43 in ocf_datapipes/validation/check_for_nans.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/validation/check_for_nans.py#L43

Added line #L43 was not covered by tests
data=xr_data[self.dataset_name], fill_value=self.fill_value,
)
self.check_nan_and_inf(
data=xr_data if self.dataset_name is None else xr_data[self.dataset_name],
Expand Down Expand Up @@ -69,10 +70,10 @@ def check_nan_and_inf(self, data: xr.Dataset) -> None:
message = f"Some data values are Infinite in datapipe {self.datapipe_name}."
raise Warning(message)

Check warning on line 71 in ocf_datapipes/validation/check_for_nans.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/validation/check_for_nans.py#L70-L71

Added lines #L70 - L71 were not covered by tests

def check_nan_and_fill_warning(self, data: xr.Dataset) -> xr.Dataset:
def fill_nan(self, data: xr.Dataset, fill_value: float = 0.0) -> xr.Dataset:
"""Check that all values are non NaNs and not infinite"""

if np.isnan(data).any():
data = data.fillna(0)
data = data.fillna(fill_value)

return data

0 comments on commit 7e9bcfa

Please sign in to comment.