From c5c4a848d7fedb50df908325e044885d5585569b Mon Sep 17 00:00:00 2001
From: RondeauG
Date: Tue, 10 Sep 2024 14:29:07 -0400
Subject: [PATCH] unstack_fill_nan_coords

---
 src/xscen/utils.py  | 105 +++++++++++++++++++++++++++++---------------
 tests/test_utils.py |  52 ++++++++++++++++++++++
 2 files changed, 121 insertions(+), 36 deletions(-)

diff --git a/src/xscen/utils.py b/src/xscen/utils.py
index cae9a7c3..671dbeb2 100644
--- a/src/xscen/utils.py
+++ b/src/xscen/utils.py
@@ -9,6 +9,7 @@
 import warnings
 from collections import defaultdict
 from collections.abc import Sequence
+from copy import deepcopy
 from datetime import datetime
 from io import StringIO
 from itertools import chain
@@ -446,7 +447,9 @@ def unstack_fill_nan(
     *,
     dim: str = "loc",
     coords: Optional[
-        Union[str, os.PathLike, Sequence[Union[str, os.PathLike]], dict]
+        Union[
+            str, os.PathLike, Sequence[Union[str, os.PathLike]], dict[str, xr.DataArray]
+        ]
     ] = None,
 ):
     """Unstack a Dataset that was stacked by :py:func:`stack_drop_nans`.
@@ -454,38 +457,47 @@
     Parameters
     ----------
     ds : xr.Dataset
-        A dataset with some dims stacked by `stack_drop_nans`.
+        A dataset with some dimensions stacked by `stack_drop_nans`.
     dim : str
         The dimension to unstack, same as `new_dim` in `stack_drop_nans`.
-    coords : Sequence of strings, Mapping of str to array, str, optional
-        If a sequence : if the dataset has coords along `dim` that are not original
-        dimensions, those original dimensions must be listed here.
-        If a dict : a mapping from the name to the array of the coords to unstack
-        If a str : a filename to a dataset containing only those coords (as coords).
-        If given a string with {shape} and {domain}, the formatting will fill them with
-        the original shape of the dataset (that should have been store in the
-        attributes of the stacked dimensions) by `stack_drop_nans` and the global attributes 'cat:domain'.
-        It is recommended to fill this argument in the config. It will be parsed automatically.
-        E.g.:
-
-            utils:
-              stack_drop_nans:
-                to_file: /some_path/coords/coords_{domain}_{shape}.nc
-              unstack_fill_nan:
-                coords: /some_path/coords/coords_{domain}_{shape}.nc
-
-        If None (default), all coords that have `dim` a single dimension are used as the
-        new dimensions/coords in the unstacked output.
-        Coordinates will be loaded within this function.
+    coords : string or os.PathLike or Sequence or dict, optional
+        Additional information used to reconstruct coordinates that might have been lost in the stacking (e.g., if a longitude was NaN across all latitudes).
+        If a string or os.PathLike : Path to a dataset containing only those coordinates, such as the output of `to_file` in `stack_drop_nans`.
+        This is the recommended option.
+        If a dictionary : A mapping from the name of the coordinate that was stacked to a DataArray. Better alternative if no file is available.
+        If a sequence : The names of the original dimensions that were stacked. Worst option.
+        If None (default), same as a sequence, but all coordinates that have `dim` as their only dimension are used as the new dimensions.
+        See Notes for more information.
 
     Returns
     -------
     xr.Dataset
         Same as `ds`, but `dim` has been unstacked to coordinates in `coords`.
         Missing elements are filled according to the defaults of `fill_value` of :py:meth:`xarray.Dataset.unstack`.
+
+    Notes
+    -----
+    Some information might have been completely lost in the stacking process, for example, if a longitude is NaN across all latitudes.
+    It is impossible to recover that information when using `coords` as a list, which is why it is recommended to use a file or a dictionary instead.
+
+    If a dictionary is used, the keys must be the names of the coordinates that were stacked and the values must be the DataArrays.
+    This method can recover both dimensions and additional coordinates that were not dimensions in the original dataset, but were stacked.
+
+    If the original stacking was done with `stack_drop_nans` and the `to_file` argument was used, the `coords` argument should be a string with
+    the path to the file. Additionally, the file name can contain the formatting fields {shape} and {domain}, which will be automatically filled
+    with the original shape of the dataset and the global attribute 'cat:domain'. If using that dynamic path, it is recommended to fill the
+    argument in the xscen config.
+    E.g.:
+
+        utils:
+          stack_drop_nans:
+            to_file: /some_path/coords/coords_{domain}_{shape}.nc
+          unstack_fill_nan:
+            coords: /some_path/coords/coords_{domain}_{shape}.nc
     """
     if coords is None:
         logger.info("Dataset unstacked using no coords argument.")
+        coords = [d for d in ds.coords if ds[d].dims == (dim,)]
 
     if isinstance(coords, (str, os.PathLike)):
         # find original shape in the attrs of one of the dimension
@@ -520,29 +532,50 @@
         # only reindex with the dims
         out = out.reindex(**coords_and_dims)
-        # add back the coords that arent dims
+        # add back the coords that aren't dims
         for c in coords_not_dims:
            out[c] = coords[c]
     else:
-        if isinstance(coords, (list, tuple)):
-            dims, crds = zip(*[(name, ds[name].load().values) for name in coords])
-        else:
-            dims, crds = zip(
-                *[
-                    (name, crd.load().values)
-                    for name, crd in ds.coords.items()
-                    if crd.dims == (dim,)
-                ]
-            )
+        coord_not_dim = {}
+        # Special case where the dictionary contains both dimensions and other coordinates
+        if isinstance(coords, dict):
+            coord_not_dim = {
+                k: v
+                for k, v in coords.items()
+                if len(set(v.dims).intersection(list(coords))) != 1
+            }
+            coords = deepcopy(coords)
+            coords = {
+                k: v
+                for k, v in coords.items()
+                if k in set(coords).difference(coord_not_dim)
+            }
+
+        dims, crds = zip(
+            *[
+                (name, crd.load().values)
+                for name, crd in ds.coords.items()
+                if (crd.dims == (dim,) and name in set(coords))
+            ]
+        )
 
-        # explicitly get lat and lon
+        # Reconstruct the dimensions
         mindex_obj = pd.MultiIndex.from_arrays(crds, names=dims)
         mindex_coords = xr.Coordinates.from_pandas_multiindex(mindex_obj, dim)
         out = ds.drop_vars(dims).assign_coords(mindex_coords).unstack(dim)
 
-        if not isinstance(coords, (list, tuple)) and coords is not None:
-            out = out.reindex(**coords.coords)
+        if isinstance(coords, dict):
+            # Reindex with the coords that were dimensions
+            out = out.reindex(**coords)
+            # Add back the coordinates that aren't dimensions
+            for c in coord_not_dim:
+                out[c] = coord_not_dim[c]
+
+    # Reorder the dimensions to match the CF conventions
+    order = [out.cf.axes.get(d, [""])[0] for d in ["T", "Z", "Y", "X"]]
+    order = [d for d in order if d] + [d for d in out.dims if d not in order]
+    out = out.transpose(*order)
 
     for dim in dims:
         out[dim].attrs.update(ds[dim].attrs)
 
     return out
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 889d555d..b52e857a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -296,6 +296,58 @@ def test_nan(self, tmp_path):
         )
         assert ds_unstack.equals(ds)
 
+    @pytest.mark.parametrize("coords", ["file.nc", ["lon", "lat"], "dict", None])
+    def test_fillnan_coords(self, tmpdir, coords):
+        data = np.zeros((20, 10, 10))
+        data[:, 1, 0] = [np.nan] * 20
+        data[:, 0, :] = [np.nan] * 10
+        ds = datablock_3d(
+            data,
+            "tas",
+            "lon",
+            -5,
+            "lat",
+            80.5,
+            1,
+            1,
+            "2000-01-01",
+            as_dataset=True,
+        )
+        ds.attrs["cat:domain"] = "RegionEssai"
+
+        mask = xr.where(ds.tas.isel(time=0).isnull(), False, True).drop_vars("time")
+        # Add mask as a coordinate
+        ds = ds.assign_coords(z=mask.astype(int))
+        ds.z.attrs["foo"] = "bar"
+
+        if coords == "dict":
+            coords = {"lon": ds.lon, "lat": ds.lat, "z": ds.z}
+        elif coords == "file.nc":
+            coords = str(tmpdir / "coords_{domain}_{shape}.nc")
+
+        ds_stack = xs.utils.stack_drop_nans(
+            ds, mask=mask, to_file=coords if isinstance(coords, str) else None
+        )
+        ds_unstack = xs.utils.unstack_fill_nan(
+            ds_stack,
+            coords=coords,
+        )
+
+        if isinstance(coords, list):
+            # Cannot fully recover the original dataset.
+            ds_unstack["z"] = ds_unstack["z"].fillna(0)
+            assert ds_unstack.equals(ds.isel(lat=slice(1, None)))
+        elif coords is None:
+            # 'z' gets completely assigned as a dimension.
+            assert "z" in ds_unstack.dims
+            assert (
+                ds_unstack.isel(z=0)
+                .drop_vars("z")
+                .equals(ds.isel(lat=slice(1, None)).drop_vars("z"))
+            )
+        else:
+            assert ds_unstack.equals(ds)
+
     def test_maybe(self, tmp_path):
         data = np.zeros((20, 10, 10))
         data[:, 0, 0] = [np.nan] * 20