diff --git a/CHANGES.rst b/CHANGES.rst index e1b2c8f6..a5df9dfa 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,13 +6,19 @@ v0.9.0 (unreleased) ------------------- Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Pascal Bourgault (:user:`aulemahal`), Gabriel Rondeau-Genesse (:user:`RondeauG`), Juliette Lavoie (:user: `juliettelavoie`). +Breaking changes +^^^^^^^^^^^^^^^^ +* Removed support for the old instances of the `region` argument in ``spatial_mean``, ``extract_dataset``, and ``subset``. (:pull:`367`). +* Removed ``xscen.extract.clisops_subset``. (:pull:`367`). + Internal changes ^^^^^^^^^^^^^^^^ * Updated the `cookiecutter` template to the latest version. (:pull:`358`): * Addresses a handful of misconfigurations in the GitHub Workflows. * Added a few free `grep`-based hooks for finding unwanted artifacts in the code base. * Updated `ruff` to v0.2.0 and `black` to v24.2.0. -* Added tests for biasadjust. (:pull:`366`). +* Added more tests. (:pull:`366`, :pull:`367`). +* Refactored ``xs.spatial.subset`` into smaller functions. (:pull:`367`). Bug fixes ^^^^^^^^^ @@ -21,7 +27,7 @@ Bug fixes * Loading of training in `adjust` is now done outside of the periods loop. (:pull:`366`). * Fixed bug for adding the preprocessing attributes inside the `adjust` function. (:pull:`366`). * Fixed a bug to accept `group = False` in `adjust` function. (:pull:`366`). - +* `creep_weights` now correctly handles the case where the grid is small, `n` is large, and `mode=wrap`. (:pull:`367`). 
v0.8.3 (2024-02-28) ------------------- diff --git a/tests/test_spatial.py b/tests/test_spatial.py new file mode 100644 index 00000000..db89704e --- /dev/null +++ b/tests/test_spatial.py @@ -0,0 +1,347 @@ +import dask.array +import geopandas as gpd +import numpy as np +import pytest +import xarray as xr +import xclim as xc +from shapely.geometry import Polygon + +import xscen as xs +from xscen.spatial import _estimate_grid_resolution, _load_lon_lat +from xscen.testing import datablock_3d + + +class TestCreepFill: + # Create a 3D datablock + ds = datablock_3d( + np.tile(np.arange(1, 37).reshape(6, 6), (3, 1, 1)), + "tas", + "lon", + -70, + "lat", + 45, + 1, + 1, + "2000-01-01", + as_dataset=True, + ) + ds["mask"] = ds["tas"].isel(time=0) > 0 + # Place a few False values in the mask + ds["mask"][0, 0] = False + ds["mask"][3, 3] = False + + @pytest.mark.parametrize( + ("n", "mode"), + [(1, "clip"), (2, "clip"), (3, "clip"), (1, "wrap"), (2, "wrap"), (3, "wrap")], + ) + def test_n(self, n, mode): + w = xs.spatial.creep_weights(self.ds["mask"], n=n, mode=mode) + out = xs.spatial.creep_fill(self.ds["tas"], w) + + if mode == "clip": + neighbours_0 = { + 1: [2, 7, 8], + 2: [2, 7, 8, 3, 9, 13, 14, 15], + 3: [ + 2, + 7, + 8, + 3, + 9, + 13, + 14, + 15, + 4, + 10, + 16, + 19, + 20, + 21, + ], # 22 is False, thus not included + } + neighbours_3 = { + # For these n, the average is the same as the original value + 1: [22], + 2: [22], + # Here all the values are included, except the False ones + 3: [ + (self.ds["tas"].isel(time=0).sum().values - 22 - 1) + / (self.ds["mask"].count().values - 2) + ], + } + else: + neighbours_0 = { + 1: [36, 31, 32, 6, 2, 12, 7, 8], + 2: [ + 29, + 30, + 25, + 26, + 27, + 35, + 36, + 31, + 32, + 33, + 5, + 6, + 2, + 3, + 11, + 12, + 7, + 8, + 9, + 17, + 18, + 13, + 14, + 15, + ], + 3: [ + ( + np.sum(np.arange(1, 37)) + - 22 + - 1 + + np.sum([19, 20, 21, 23, 24]) + + np.sum([4, 10, 16, 28, 34]) + ) + / (36 - 2 + 10) + ], + } + neighbours_3 = { + # For 
these n, the average is the same as the original value + 1: [22], + 2: [22], + # Here all the values are included, except the False ones + 3: [ + ( + np.sum(np.arange(1, 37)) + - 22 + - 1 + + np.sum(np.arange(2, 7)) + + np.sum([7, 13, 19, 25, 31]) + ) + / (36 - 2 + 10) + ], + } + + np.testing.assert_allclose( + out.isel(lat=0, lon=0), np.tile(np.mean(neighbours_0[n]), 3) + ) + np.testing.assert_allclose( + out.isel(lat=3, lon=3), np.tile(np.mean(neighbours_3[n]), 3) + ) + + def test_wrong_mode(self): + with pytest.raises(ValueError, match="mode must be either"): + xs.spatial.creep_weights(self.ds["mask"], n=1, mode="wrong") + + def test_n0(self): + w = xs.spatial.creep_weights(self.ds["mask"], n=0, mode="clip") + out = xs.spatial.creep_fill(self.ds["tas"], w) + np.testing.assert_equal(out.isel(lat=0, lon=0), np.tile(np.nan, 3)) + np.testing.assert_equal(out.isel(lat=3, lon=3), np.tile(np.nan, 3)) + + +class TestSubset: + ds = datablock_3d( + np.ones((3, 50, 50)), + "tas", + "lon", + -70, + "lat", + 45, + 1, + 1, + "2000-01-01", + as_dataset=True, + ) + + @pytest.mark.parametrize( + ("kwargs", "name"), + [ + ({"lon": -70, "lat": 45}, None), + ({"lon": [-53.3, -69.6], "lat": [49.3, 46.6]}, "foo"), + ], + ) + def test_subset_gridpoint(self, kwargs, name): + with pytest.warns(UserWarning, match="tile_buffer is not used"): + out = xs.spatial.subset( + self.ds, "gridpoint", name=name, tile_buffer=5, **kwargs + ) + + if isinstance(kwargs["lon"], list): + expected = { + "lon": [np.round(k) for k in kwargs["lon"]], + "lat": [np.round(k) for k in kwargs["lat"]], + } + else: + expected = { + "lon": [np.round(kwargs["lon"])], + "lat": [np.round(kwargs["lat"])], + } + + assert ( + f"gridpoint spatial subsetting on {len(expected['lon'])} coordinates" + in out.attrs["history"] + ) + np.testing.assert_array_equal(out["lon"], expected["lon"]) + np.testing.assert_array_equal(out["lat"], expected["lat"]) + if name: + assert out.attrs["cat:domain"] == name + else: + assert 
"cat:domain" not in out.attrs + + @pytest.mark.parametrize( + ("kwargs", "tile_buffer", "method"), + [ + ({"lon_bnds": [-63, -60], "lat_bnds": [47, 50]}, 0, "bbox"), + ({"lon_bnds": [-63, -60], "lat_bnds": [47, 50]}, 5, "bbox"), + ({}, 0, "shape"), + ({}, 5, "shape"), + ({"buffer": 3}, 5, "shape"), + ], + ) + def test_subset_bboxshape(self, kwargs, tile_buffer, method): + if method == "shape": + gdf = gpd.GeoDataFrame( + {"geometry": [Polygon([(-63, 47), (-63, 50), (-60, 50), (-60, 47)])]} + ) + kwargs["shape"] = gdf + + if "buffer" in kwargs: + with pytest.raises( + ValueError, match="Both tile_buffer and clisops' buffer were requested." + ): + xs.spatial.subset(self.ds, method, tile_buffer=tile_buffer, **kwargs) + else: + out = xs.spatial.subset(self.ds, method, tile_buffer=tile_buffer, **kwargs) + + assert f"{method} spatial subsetting" in out.attrs["history"] + if tile_buffer: + assert f"with buffer={tile_buffer}" in out.attrs["history"] + np.testing.assert_array_equal( + out["lon"], + np.arange( + np.max([-63 - tile_buffer, self.ds.lon.min()]), + np.min([-60 + tile_buffer + 1, self.ds.lon.max()]), + ), + ) + np.testing.assert_array_equal( + out["lat"], + np.arange( + np.max([47 - tile_buffer, self.ds.lat.min()]), + np.min([50 + tile_buffer + 1, self.ds.lat.max()]), + ), + ) + + else: + assert "with no buffer" in out.attrs["history"] + np.testing.assert_array_equal(out["lon"], np.arange(-63, -59)) + np.testing.assert_array_equal(out["lat"], np.arange(47, 51)) + + def test_subset_sel(self): + ds = datablock_3d( + np.ones((3, 50, 50)), + "tas", + "rlon", + -10, + "rlat", + 0, + 1, + 1, + "2000-01-01", + as_dataset=True, + ) + + with pytest.raises(KeyError): + xs.spatial.subset(ds, "sel", lon=[-75, -70], lat=[-5, 0]) + out = xs.spatial.subset(ds, "sel", rlon=[-5, 5], rlat=[0, 3]) + + assert "sel subsetting" in out.attrs["history"] + np.testing.assert_array_equal(out["rlon"], np.arange(-5, 6)) + np.testing.assert_array_equal(out["rlat"], np.arange(0, 4)) + + def 
test_history(self): + ds = self.ds.copy(deep=True) + ds.attrs["history"] = "this is previous history" + out = xs.spatial.subset(ds, "gridpoint", lon=-70, lat=45) + + assert "this is previous history" in out.attrs["history"].split("\n")[1] + assert "gridpoint spatial subsetting" in out.attrs["history"].split("\n")[0] + + def test_subset_wrong_method(self): + with pytest.raises(ValueError, match="Subsetting type not recognized"): + xs.spatial.subset(self.ds, "wrong", lon=-70, lat=45) + + +def test_dask_coords(): + ds = datablock_3d( + np.ones((3, 50, 50)), + "tas", + "rlon", + -10, + "rlat", + 0, + 1, + 1, + "2000-01-01", + as_dataset=True, + ) + # Transform the coordinates to dask arrays + lon_attrs = ds["lon"].attrs + ds["lon"] = xr.DataArray( + dask.array.from_array(ds["lon"].data, chunks=(1, 1)), + dims=ds["lon"].dims, + attrs=lon_attrs, + ) + lat_attrs = ds["lat"].attrs + ds["lat"] = xr.DataArray( + dask.array.from_array(ds["lat"].data, chunks=(1, 1)), + dims=ds["lat"].dims, + attrs=lat_attrs, + ) + assert xc.core.utils.uses_dask(ds.cf["longitude"]) + + ds = _load_lon_lat(ds) + assert not xc.core.utils.uses_dask(ds.cf["longitude"]) + assert not xc.core.utils.uses_dask(ds.cf["latitude"]) + + +@pytest.mark.parametrize(("lon_res", "lat_res"), [(0.5, 1), (1, 0.5)]) +def test_estimate_res_1d(lon_res, lat_res): + ds = datablock_3d( + np.ones((3, 5, 5)), + "tas", + "lon", + -10, + "lat", + 0, + lon_res, + lat_res, + "2000-01-01", + as_dataset=True, + ) + lon_res_est, lat_res_est = _estimate_grid_resolution(ds) + assert lon_res_est == lon_res + assert lat_res_est == lat_res + + +@pytest.mark.parametrize(("lon_res", "lat_res"), [(0.5, 1), (1, 0.5)]) +def test_estimate_res_2d(lon_res, lat_res): + ds = datablock_3d( + np.ones((3, 5, 5)), + "tas", + "rlon", + -10, + "rlat", + 0, + lon_res, + lat_res, + "2000-01-01", + as_dataset=True, + ) + lon_res_est, lat_res_est = _estimate_grid_resolution(ds) + np.testing.assert_allclose(lon_res_est, ds.lon.diff("rlon").max()) + 
np.testing.assert_allclose(lat_res_est, ds.lat.diff("rlat").max()) diff --git a/xscen/extract.py b/xscen/extract.py index 8f4fc4ca..f1341039 100644 --- a/xscen/extract.py +++ b/xscen/extract.py @@ -45,46 +45,6 @@ ] -def clisops_subset(ds: xr.Dataset, region: dict) -> xr.Dataset: - """Customize a call to clisops.subset() that allows for an automatic buffer around the region. - - Parameters - ---------- - ds : xr.Dataset - Dataset to be subsetted - region : dict - Description of the region and the subsetting method (required fields listed in the Notes) - - Notes - ----- - 'region' fields: - method: str - ['gridpoint', 'bbox', shape'] - : dict - Arguments specific to the method used. - buffer: float, optional - Multiplier to apply to the model resolution. - - Returns - ------- - xr.Dataset - Subsetted Dataset. - - See Also - -------- - clisops.core.subset.subset_gridpoint, clisops.core.subset.subset_bbox, clisops.core.subset.subset_shape - """ - warnings.warn( - "clisops_subset is deprecated and will not be available in future versions. " - "Use xscen.spatial.subset instead.", - category=FutureWarning, - ) - - ds_subset = subset(ds, region=region) - - return ds_subset - - @parse_config def extract_dataset( # noqa: C901 catalog: DataCatalog, @@ -326,19 +286,6 @@ def extract_dataset( # noqa: C901 # subset to the region if region is not None: - if (region["method"] in region) and ( - isinstance(region[region["method"]], dict) - ): - warnings.warn( - "You seem to be using a deprecated version of region. 
Please use the new formatting.", - category=FutureWarning, - ) - region = deepcopy(region) - if "buffer" in region: - region["tile_buffer"] = region.pop("buffer") - _kwargs = region.pop(region["method"]) - region.update(_kwargs) - ds = subset(ds, **region) # add relevant attrs diff --git a/xscen/spatial.py b/xscen/spatial.py index 48ffe9b3..0fa66666 100644 --- a/xscen/spatial.py +++ b/xscen/spatial.py @@ -2,21 +2,24 @@ import datetime import itertools +import logging import warnings -from copy import deepcopy +from collections.abc import Sequence from pathlib import Path -from typing import Optional +from typing import Optional, Union import clisops.core.subset import dask +import geopandas as gpd import numpy as np import sparse as sp import xarray as xr import xclim as xc -from xclim.core.utils import uses_dask from .config import parse_config +logger = logging.getLogger(__name__) + __all__ = [ "creep_fill", "creep_weights", @@ -61,7 +64,14 @@ def creep_weights(mask: xr.DataArray, n: int = 1, mode: str = "clip") -> xr.Data neigh_idx_1d = np.ravel_multi_index( neigh_idx_2d, mask.shape, order="C", mode=mode ) - neigh_idx = np.unravel_index(np.unique(neigh_idx_1d), mask.shape, order="C") + if mode == "clip": + neigh_idx = np.unravel_index( + np.unique(neigh_idx_1d), mask.shape, order="C" + ) + elif mode == "wrap": + neigh_idx = np.unravel_index(neigh_idx_1d, mask.shape, order="C") + else: + raise ValueError("mode must be either 'clip' or 'wrap'") neigh = mask[neigh_idx] N = (neigh).sum() if N > 0: @@ -125,18 +135,15 @@ def _dot(arr, wei): ) -def subset( # noqa: C901 +def subset( ds: xr.Dataset, - region: Optional[dict] = None, + method: str, *, name: Optional[str] = None, - method: Optional[ - str - ] = None, # FIXME: Once the region argument is removed, this should be made mandatory. tile_buffer: float = 0, **kwargs, ) -> xr.Dataset: - """ + r""" Subset the data to a region. 
Either creates a slice and uses the .sel() method, or customizes a call to @@ -146,19 +153,20 @@ def subset( # noqa: C901 ---------- ds : xr.Dataset Dataset to be subsetted. - region: dict - Deprecated argument that is there for legacy reasons and will be abandoned eventually. - name: str, optional - Used to rename the 'cat:domain' attribute. method : str - ['gridpoint', 'bbox', shape','sel'] + ['gridpoint', 'bbox', 'shape', 'sel'] If the method is `sel`, this is not a call to clisops but only a subsetting with the xarray .sel() fonction. + name: str, optional + Used to rename the 'cat:domain' attribute. tile_buffer : float For ['bbox', shape'], uses an approximation of the grid cell size to add a buffer around the requested region. This differs from clisops' 'buffer' argument in subset_shape(). - kwargs : dict - Arguments to be sent to clisops. - If the method is `sel`, the keys are the dimensions to subset and the values are turned into a slice. + \*\*kwargs : dict + Arguments to be sent to clisops. See relevant function for details. Depending on the method, required kwargs are: + - gridpoint: lon, lat + - bbox: lon_bnds, lat_bnds + - shape: shape + - sel: slices for each dimension Returns ------- @@ -169,97 +177,254 @@ def subset( # noqa: C901 -------- clisops.core.subset.subset_gridpoint, clisops.core.subset.subset_bbox, clisops.core.subset.subset_shape """ - if region is not None: + if tile_buffer > 0 and method in ["gridpoint", "sel"]: warnings.warn( - "The argument 'region' has been deprecated and will be abandoned in a future release.", - category=FutureWarning, + f"tile_buffer is not used for the '{method}' method. 
Ignoring the argument.", + UserWarning, ) - method = method or region.get("method") - if ("buffer" in region) and ("shape" in region): - warnings.warn( - "To avoid confusion with clisops' buffer argument, xscen's 'buffer' has been renamed 'tile_buffer'.", - category=FutureWarning, - ) - tile_buffer = tile_buffer or region.get("buffer", 0) - else: - tile_buffer = tile_buffer or region.get("tile_buffer", 0) - kwargs = deepcopy(region[region["method"]]) - if uses_dask(ds.lon) or uses_dask(ds.lat): - warnings.warn("Loading longitude and latitude for more efficient subsetting.") - ds["lon"], ds["lat"] = dask.compute(ds.lon, ds.lat) + if method == "gridpoint": + ds_subset = _subset_gridpoint(ds, name=name, **kwargs) + elif method == "bbox": + ds_subset = _subset_bbox(ds, name=name, tile_buffer=tile_buffer, **kwargs) + elif method == "shape": + ds_subset = _subset_shape(ds, name=name, tile_buffer=tile_buffer, **kwargs) + elif method == "sel": + ds_subset = _subset_sel(ds, name=name, **kwargs) + else: + raise ValueError( + "Subsetting type not recognized. Use 'gridpoint', 'bbox', 'shape' or 'sel'." + ) + + return ds_subset + + +def _subset_gridpoint( + ds: xr.Dataset, + lon: Union[float, Sequence[float], xr.DataArray], + lat: Union[float, Sequence[float], xr.DataArray], + *, + name: Optional[str] = None, + **kwargs, +) -> xr.Dataset: + r"""Subset the data to a gridpoint. + + Parameters + ---------- + ds : xr.Dataset + Dataset to be subsetted. + lon : float or Sequence[float] or xr.DataArray + Longitude coordinate(s). Must be of the same length as lat. + lat : float or Sequence[float] or xr.DataArray + Latitude coordinate(s). Must be of the same length as lon. + name: str, optional + Used to rename the 'cat:domain' attribute. + \*\*kwargs : dict + Other arguments to be sent to clisops. Possible kwargs are: + - start_date (str): Start date for the subset in the format 'YYYY-MM-DD'. + - end_date (str): End date for the subset in the format 'YYYY-MM-DD'. 
+ - first_level (int or float): First level of the subset. + - last_level (int or float): Last level of the subset. + - tolerance (float): Masks values if the distance to the nearest gridpoint is larger than tolerance in meters. + - add_distance (bool): If True, adds a variable with the distance to the nearest gridpoint. + + Returns + ------- + xr.Dataset + Subsetted Dataset. + """ + ds = _load_lon_lat(ds) + if not hasattr(lon, "__iter__"): + lon = [lon] + if not hasattr(lat, "__iter__"): + lat = [lat] + + ds_subset = clisops.core.subset_gridpoint(ds, lon=lon, lat=lat, **kwargs) + new_history = ( + f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"gridpoint spatial subsetting on {len(lon)} coordinates - clisops v{clisops.__version__}" + ) + + return update_history_and_name(ds_subset, new_history, name) + + +def _subset_bbox( + ds: xr.Dataset, + lon_bnds: Union[tuple[float, float], list[float]], + lat_bnds: Union[tuple[float, float], list[float]], + *, + name: Optional[str] = None, + tile_buffer: float = 0, + **kwargs, +) -> xr.Dataset: + r"""Subset the data to a bounding box. + + Parameters + ---------- + ds : xr.Dataset + Dataset to be subsetted. + lon_bnds : tuple or list of two floats + Longitude boundaries of the bounding box. + lat_bnds : tuple or list of two floats + Latitude boundaries of the bounding box. + name: str, optional + Used to rename the 'cat:domain' attribute. + tile_buffer: float + Uses an approximation of the grid cell size to add a dynamic buffer around the requested region. + \*\*kwargs : dict + Other arguments to be sent to clisops. Possible kwargs are: + - start_date (str): Start date for the subset in the format 'YYYY-MM-DD'. + - end_date (str): End date for the subset in the format 'YYYY-MM-DD'. + - first_level (int or float): First level of the subset. + - last_level (int or float): Last level of the subset. + - time_values (Sequence[str]): A list of datetime strings to subset. 
+ - level_values (Sequence[int or float]): A list of levels to subset. + + Returns + ------- + xr.Dataset + Subsetted Dataset. + """ + ds = _load_lon_lat(ds) + if tile_buffer > 0: - if method not in ["bbox", "shape"]: - warnings.warn( - "tile_buffer has been specified, but is not used for the requested subsetting method.", - ) - # estimate the model resolution - if len(ds.lon.dims) == 1: # 1D lat-lon - lon_res = np.abs(ds.lon.diff("lon")[0].values) - lat_res = np.abs(ds.lat.diff("lat")[0].values) - else: - lon_res = np.abs(ds.lon[0, 0].values - ds.lon[0, 1].values) - lat_res = np.abs(ds.lat[0, 0].values - ds.lat[1, 0].values) - - if method in ["gridpoint"]: - ds_subset = clisops.core.subset_gridpoint(ds, **kwargs) - new_history = ( - f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"{method} spatial subsetting on {len(kwargs['lon'])} coordinates - clisops v{clisops.__version__}" + lon_res, lat_res = _estimate_grid_resolution(ds) + lon_bnds = ( + lon_bnds[0] - lon_res * tile_buffer, + lon_bnds[1] + lon_res * tile_buffer, + ) + lat_bnds = ( + lat_bnds[0] - lat_res * tile_buffer, + lat_bnds[1] + lat_res * tile_buffer, ) - elif method in ["bbox"]: - if tile_buffer > 0: - # adjust the boundaries - kwargs["lon_bnds"] = ( - kwargs["lon_bnds"][0] - lon_res * tile_buffer, - kwargs["lon_bnds"][1] + lon_res * tile_buffer, - ) - kwargs["lat_bnds"] = ( - kwargs["lat_bnds"][0] - lat_res * tile_buffer, - kwargs["lat_bnds"][1] + lat_res * tile_buffer, + ds_subset = clisops.core.subset_bbox( + ds, lon_bnds=lon_bnds, lat_bnds=lat_bnds, **kwargs + ) + new_history = ( + f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"bbox spatial subsetting with {'buffer=' + str(tile_buffer) if tile_buffer > 0 else 'no buffer'}" + f", lon_bnds={np.array(lon_bnds)}, lat_bnds={np.array(lat_bnds)}" + f" - clisops v{clisops.__version__}" + ) + + return update_history_and_name(ds_subset, new_history, name) + + +def _subset_shape( + ds: xr.Dataset, + shape: Union[str, Path, 
gpd.GeoDataFrame], + *, + name: Optional[str] = None, + tile_buffer: float = 0, + **kwargs, +) -> xr.Dataset: + r"""Subset the data to a shape. + + Parameters + ---------- + ds : xr.Dataset + Dataset to be subsetted. + shape : str or Path or gpd.GeoDataFrame + Path to the shapefile or GeoDataFrame. + name: str, optional + Used to rename the 'cat:domain' attribute. + tile_buffer: float + Uses an approximation of the grid cell size to add a buffer around the requested region. + \*\*kwargs : dict + Other arguments to be sent to clisops. Possible kwargs are: + - raster_crs (str or int): EPSG number or PROJ4 string. + - shape_crs (str or int): EPSG number or PROJ4 string. + - buffer (float): Buffer size to add around the shape. Units are based on the shape degrees/metres. + - start_date (str): Start date for the subset in the format 'YYYY-MM-DD'. + - end_date (str): End date for the subset in the format 'YYYY-MM-DD'. + - first_level (int or float): First level of the subset. + - last_level (int or float): Last level of the subset. + + Returns + ------- + xr.Dataset + Subsetted Dataset. + """ + ds = _load_lon_lat(ds) + + if tile_buffer > 0: + if kwargs.get("buffer") is not None: + raise ValueError( + "Both tile_buffer and clisops' buffer were requested. Use only one." 
) + lon_res, lat_res = _estimate_grid_resolution(ds) + kwargs["buffer"] = np.max([lon_res, lat_res]) * tile_buffer + + ds_subset = clisops.core.subset_shape(ds, shape=shape, **kwargs) + new_history = ( + f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"shape spatial subsetting with {'buffer=' + str(tile_buffer) if tile_buffer > 0 else 'no buffer'}" + f", shape={Path(shape).name if isinstance(shape, (str, Path)) else 'gpd.GeoDataFrame'}" + f" - clisops v{clisops.__version__}" + ) - if xc.core.utils.uses_dask(ds.cf["longitude"]): - ds[ds.cf["longitude"].name].load() - if xc.core.utils.uses_dask(ds.cf["latitude"]): - ds[ds.cf["latitude"].name].load() - - ds_subset = clisops.core.subset_bbox(ds, **kwargs) - new_history = ( - f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"{method} spatial subsetting with {'buffer=' + str(tile_buffer) if tile_buffer > 0 else 'no buffer'}" - f", lon_bnds={np.array(kwargs['lon_bnds'])}, lat_bnds={np.array(kwargs['lat_bnds'])}" - f" - clisops v{clisops.__version__}" - ) + return update_history_and_name(ds_subset, new_history, name) - elif method in ["shape"]: - if tile_buffer > 0: - if kwargs.get("buffer") is not None: - raise NotImplementedError( - "Both tile_buffer and clisops' buffer were requested. Use only one." 
- ) - kwargs["buffer"] = np.max([lon_res, lat_res]) * tile_buffer - - ds_subset = clisops.core.subset_shape(ds, **kwargs) - new_history = ( - f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"{method} spatial subsetting with {'buffer=' + str(tile_buffer) if tile_buffer > 0 else 'no buffer'}" - f", shape={Path(kwargs['shape']).name if isinstance(kwargs['shape'], (str, Path)) else 'gpd.GeoDataFrame'}" - f" - clisops v{clisops.__version__}" - ) - elif method in ["sel"]: - arg_sel = {dim: slice(*map(float, bounds)) for dim, bounds in kwargs.items()} - ds_subset = ds.sel(**arg_sel) - new_history = ( - f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"{method} subsetting with arguments {arg_sel}" - ) +def _subset_sel(ds: xr.Dataset, *, name: Optional[str] = None, **kwargs) -> xr.Dataset: + r"""Subset the data using the .sel() method. + Parameters + ---------- + ds : xr.Dataset + Dataset to be subsetted. + name: str, optional + Used to rename the 'cat:domain' attribute. + \*\*kwargs : dict + The keys are the dimensions to subset and the values are turned into a slice. + + Returns + ------- + xr.Dataset + Subsetted Dataset. 
+ """ + # Create a dictionary with slices for each dimension + arg_sel = {dim: slice(*map(float, bounds)) for dim, bounds in kwargs.items()} + + # Subset the dataset + ds_subset = ds.sel(**arg_sel) + + # Update the history attribute + new_history = ( + f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"sel subsetting with arguments {arg_sel}" + ) + + return update_history_and_name(ds_subset, new_history, name) + + +def _load_lon_lat(ds: xr.Dataset) -> xr.Dataset: + """Load longitude and latitude for more efficient subsetting.""" + if xc.core.utils.uses_dask(ds.cf["longitude"]): + logger.info("Loading longitude for more efficient subsetting.") + (ds[ds.cf["longitude"].name],) = dask.compute(ds[ds.cf["longitude"].name]) + if xc.core.utils.uses_dask(ds.cf["latitude"]): + logger.info("Loading latitude for more efficient subsetting.") + (ds[ds.cf["latitude"].name],) = dask.compute(ds[ds.cf["latitude"].name]) + + return ds + + +def _estimate_grid_resolution(ds: xr.Dataset) -> tuple[float, float]: + # Since this is to compute a buffer, we take the maximum difference as an approximation. + # Estimate the grid resolution + if len(ds.lon.dims) == 1: # 1D lat-lon + lon_res = np.abs(ds.lon.diff("lon").max().values) + lat_res = np.abs(ds.lat.diff("lat").max().values) else: - raise ValueError("Subsetting type not recognized") + lon_res = np.abs(ds.lon.diff(ds.cf["X"].name).max().values) + lat_res = np.abs(ds.lat.diff(ds.cf["Y"].name).max().values) + return lon_res, lat_res + + +def update_history_and_name(ds_subset, new_history, name): history = ( new_history + " \n " + ds_subset.attrs["history"] if "history" in ds_subset.attrs @@ -268,5 +433,4 @@ def subset( # noqa: C901 ds_subset.attrs["history"] = history if name is not None: ds_subset.attrs["cat:domain"] = name - return ds_subset