diff --git a/src/pypromice/process/value_clipping.py b/src/pypromice/process/value_clipping.py index 6ff60c54..020e7317 100644 --- a/src/pypromice/process/value_clipping.py +++ b/src/pypromice/process/value_clipping.py @@ -1,8 +1,5 @@ -from typing import Dict, Set, Mapping - import numpy as np import pandas -import pandas as pd import xarray from pypromice.utilities.dependency_graph import DependencyGraph @@ -31,8 +28,6 @@ def clip_values( """ cols = ["lo", "hi", "OOL"] assert set(cols) <= set(var_configurations.columns) - # TODO: Check if this is necessary - # variable_limits = var_configurations[cols].dropna(how="all") variable_limits = var_configurations[cols].assign( dependents=lambda df: df.OOL.fillna("").str.split(), @@ -45,19 +40,16 @@ def clip_values( for var, row in variable_limits.iterrows(): if var not in list(ds.variables): continue - # TODO: Check if this is necessary - # I guess the nan flagging is already handled below - # What if rh_u_cor is nan? - # What if row.lo/hi is nan? + # This is a special case for rh_u_cor and rh_l_cor where values are clipped to 0 and 100. if var in ["rh_u_cor", "rh_l_cor"]: - ds[var] = ds[var].where(ds[var] >= row.lo, other=0) - ds[var] = ds[var].where(ds[var] <= row.hi, other=100) - - # Mask out invalid corrections based on uncorrected var - var_uncor = var.rstrip("_cor") - ds[var] = ds[var].where(~np.isnan(ds[var_uncor]), other=np.nan) - + # Nan inputs should stay nan + was_nan = ds[var].isnull() + if ~np.isnan(row.lo): + ds[var] = ds[var].where(ds[var] >= row.lo, other=0) + if ~np.isnan(row.hi): + ds[var] = ds[var].where( ds[var] <= row.hi, other=100) + ds[var] = ds[var].where(~was_nan) else: if ~np.isnan(row.lo): ds[var] = ds[var].where(ds[var] >= row.lo) diff --git a/tests/unit/test_value_clippping.py b/tests/unit/test_value_clippping.py index ceb8e2ef..063242fc 100644 --- a/tests/unit/test_value_clippping.py +++ b/tests/unit/test_value_clippping.py @@ -198,3 +198,78 @@ def test_circular_dependencies(self): check_dtype=True, ) + def test_rh_corrected_case(self): + """ + The rh corrected variables are treated differently in the clipping function. + """ + data_index = pd.RangeIndex(2) + rh_u = pd.Series(index=data_index, data=[0, 54], name="rh_u") + rh_u_cor = pd.Series(index=data_index, data=[0, np.nan], name="rh_u_cor") + rh_l = pd.Series(index=data_index, data=[-20, 54], name="rh_l") + rh_l_cor = pd.Series(index=data_index, data=[0, 254], name="rh_l_cor") + data = pd.concat([rh_u, rh_u_cor, rh_l, rh_l_cor], axis=1) + variable_config = pd.DataFrame( + columns=["field", "lo", "hi", "OOL"], + data=[ + ["rh_u", 0, 100, "rh_u_cor"], + ["rh_u_cor", 0, 150, ""], + ["rh_l_cor", np.nan, np.nan, ""], + ["rh_l", 0, 100, "rh_l_cor"], + ], + ).set_index("field") + data_set = xr.Dataset(data) + + data_set_out = clip_values(data_set, variable_config) + + # Convert to dataframe for easier comparison + data_frame_out = data_set_out.to_dataframe() + # The value of rh_u_cor should be nan since the input was nan + self.assertTrue(np.isnan(data_frame_out.iloc[1]["rh_u_cor"])) + # The value of rh_l_cor should not be changed since the hi threshold is not nan + self.assertEqual(data_frame_out.iloc[1]["rh_l_cor"], 254) + # The value of rh_l_cor should be nan since rh_l is below its threshold + self.assertTrue(np.isnan(data_frame_out.iloc[0]["rh_l_cor"])) + + def test_nan_input(self): + """ + Test that the function handles the case where nan input should cascade to child variables. + """ + fields = ["a", "b"] + variable_config = pd.DataFrame( + columns=["field", "lo", "hi", "OOL"], + data=[ + ["a", 0, 10, "b"], + ["b", 100, 110, ""], + ], + ).set_index("field") + data_index = pd.RangeIndex(2) + data = pd.DataFrame( + columns=fields, + data=[ + [0, 100], # All a withing range + [np.nan, 100], # a is nan + ], + dtype=float, + index=data_index, + ) + expected_output = pd.DataFrame( + columns=fields, + data=[ + [0, 100], + [np.nan, np.nan], # a is nan -> b + ], + dtype=float, + index=data.index, + ) + + data_set = xr.Dataset(data) + + data_set_out = clip_values(data_set, variable_config) + data_frame_out = data_set_out.to_dataframe() + + pd.testing.assert_frame_equal( + data_frame_out, + expected_output, + check_names=False, + check_dtype=True, + )