diff --git a/HISTORY.rst b/HISTORY.rst index 4bab7fbf..56c6c41b 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,7 +4,7 @@ History v0.8.0 (unreleased) ------------------- -Contributors to this version: Gabriel Rondeau-Genesse (:user:`RondeauG`). +Contributors to this version: Gabriel Rondeau-Genesse (:user:`RondeauG`), Pascal Bourgault (:user:`aulemahal`). Announcements ^^^^^^^^^^^^^ @@ -13,6 +13,8 @@ Announcements New features and enhancements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * Added the ability to search for simulations that reach a given warming level. (:pull:`251`). +* ``xs.spatial_mean`` now accepts the ``region="global"`` keyword to perform a global average (:issue:`94`, :pull:`260`). +* ``xs.spatial_mean`` with ``method='xESMF'`` will also automatically segmentize polygons (down to a 1° resolution) to ensure a correct average (:pull:`260`). Breaking changes ^^^^^^^^^^^^^^^^ @@ -34,6 +36,7 @@ Internal changes * Reduced the size of the files in /docs/notebooks/samples and changed the Notebooks and tests accordingly. (:issue:`247`, :pull:`248`). * Added a new `xscen.testing` module with the `datablock_3d` function previously located in `/tests/conftest.py`. (:pull:`248`). * New function `xscen.testing.fake_data` to generate fake data for testing. (:pull:`248`). +* xESMF 0.8 Regridder and SpatialAverager argument ``out_chunks`` is now accepted by ``xs.regrid_dataset`` and ``xs.spatial_mean``. (:pull:`260`). v0.7.1 (2023-08-23) ------------------- diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index 72afcbc8..879f346c 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -6,6 +6,7 @@ from xclim.testing.helpers import test_timeseries as timeseries import xscen as xs +from xscen.testing import datablock_3d class TestClimatologicalMean: @@ -523,3 +524,27 @@ def test_errors(self): xs.produce_horizon( self.ds, indicators=self.yaml_file, periods=[["1950", "1990"]] ) + + +class TestSpatialMean: + # We test different longitude flavors : all < 0, crossing 0, all > 0 + # the default global bbox changes because of subtleties in clisops + @pytest.mark.parametrize( + "method,exp", (["xesmf", 1.62032976], ["cos-lat", 1.63397460]) + ) + @pytest.mark.parametrize("lonstart", [-70, -30, 0]) + def test_global(self, lonstart, method, exp): + ds = datablock_3d( + np.array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]]] * 3, "float"), + "tas", + "lon", + lonstart, + "lat", + 15, + 30, + 30, + as_dataset=True, + ) + + avg = xs.aggregate.spatial_mean(ds, method=method, region="global") + np.testing.assert_allclose(avg.tas, exp) diff --git a/xscen/aggregate.py b/xscen/aggregate.py index 3ca8ade9..552531bb 100644 --- a/xscen/aggregate.py +++ b/xscen/aggregate.py @@ -374,7 +374,7 @@ def spatial_mean( *, spatial_subset: bool = None, call_clisops: bool = False, - region: dict = None, + region: Union[dict, str] = None, kwargs: dict = None, simplify_tolerance: float = None, to_domain: str = None, @@ -394,13 +394,14 @@ def spatial_mean( spatial_subset : bool If True, xscen.spatial.subset will be called prior to the other operations. This requires the 'region' argument. If None, this will automatically become True if 'region' is provided and the subsetting method is either 'cos-lat' or 'mean'. - region : dict + region : dict or str Description of the region and the subsetting method (required fields listed in the Notes). If method=='interp_centroid', this is used to find the region's centroid. If method=='xesmf', the bounding box or shapefile is given to SpatialAverager. + Can also be "global", for global averages. This is simply a shortcut for `{'name': 'global', 'method': 'bbox', 'lon_bnds' [-180, 180], 'lat_bnds': [-90, 90]}`. kwargs : dict Arguments to send to either mean(), interp() or SpatialAverager(). - For SpatialAverager, one can give `skipna` here, to be passed to the averager call itself. + For SpatialAverager, one can give `skipna` or `out_chunks` here, to be passed to the averager call itself. simplify_tolerance : float Precision (in degree) used to simplify a shapefile before sending it to SpatialAverager(). The simpler the polygons, the faster the averaging, but it will lose some precision. @@ -451,6 +452,18 @@ def spatial_mean( ) spatial_subset = call_clisops + if region == "global": + region = { + "name": "global", + "method": "bbox", + "lat_bnds": [-90 + 1e-5, 90 - 1e-5], + } + # `spatial_subset` won't wrap coords on the bbox, we need to fit the system used on ds. + if ds.cf["longitude"].min() >= 0: + region["lon_bnds"] = [0, 360] + else: + region["lon_bnds"] = [-180, 180] + if ( (region is not None) and (region["method"] in region) @@ -587,26 +600,14 @@ def spatial_mean( # Uses xesmf.SpatialAverager elif method == "xesmf": - logger.warning( - "A bug has been found with xesmf.SpatialAverager that appears to impact big regions. " - "Until this is fixed, make sure that the computation is right or use multiple smaller regions." - ) # If the region is a bounding box, call shapely and geopandas to transform it into an input compatible with xesmf if region["method"] == "bbox": - lon_point_list = [ + polygon_geom = shapely.box( region["lon_bnds"][0], - region["lon_bnds"][0], - region["lon_bnds"][1], - region["lon_bnds"][1], - ] - lat_point_list = [ region["lat_bnds"][0], + region["lon_bnds"][1], region["lat_bnds"][1], - region["lat_bnds"][1], - region["lat_bnds"][0], - ] - - polygon_geom = Polygon(zip(lon_point_list, lat_point_list)) + ) polygon = gpd.GeoDataFrame(index=[0], geometry=[polygon_geom]) # Prepare the History field @@ -619,8 +620,10 @@ def spatial_mean( elif region["method"] == "shape": if not isinstance(region["shape"], gpd.GeoDataFrame): polygon = gpd.read_file(region["shape"]) + name = Path(region["shape"]).name else: polygon = region["shape"] + name = f"{len(polygon)} polygons" # Simplify the geometries to a given tolerance, if needed. # The simpler the polygons, the faster the averaging, but it will lose some precision. @@ -632,14 +635,19 @@ def spatial_mean( # Prepare the History field new_history = ( f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"xesmf.SpatialAverager over {Path(region['shape']).name} - xESMF v{xe.__version__}" + f"xesmf.SpatialAverager over {name} - xESMF v{xe.__version__}" ) else: raise ValueError("'method' should be one of [bbox, shape].") kwargs_copy = deepcopy(kwargs) - skipna = kwargs_copy.pop("skipna", False) + call_kwargs = {"skipna": kwargs_copy.pop("skipna", False)} + if "out_chunks" in kwargs: + call_kwargs["out_chunks"] = kwargs_copy.pop("out_chunks") + + # Pre-emptive segmentization. Same threshold as xESMF, but there's not strong analysis behind this choice + geoms = shapely.segmentize(polygon.geometry, 1) if ( ds.cf["longitude"].ndim == 2 @@ -650,8 +658,8 @@ def spatial_mean( ds = ds.update(create_bounds_rotated_pole(ds)) - savg = xe.SpatialAverager(ds, polygon.geometry, **kwargs_copy) - ds_agg = savg(ds, keep_attrs=True, skipna=skipna) + savg = xe.SpatialAverager(ds, geoms, **kwargs_copy) + ds_agg = savg(ds, keep_attrs=True, **call_kwargs) extra_coords = { col: xr.DataArray(polygon[col], dims=("geom",)) for col in polygon.columns diff --git a/xscen/regrid.py b/xscen/regrid.py index ace0f201..396c56ed 100644 --- a/xscen/regrid.py +++ b/xscen/regrid.py @@ -49,8 +49,8 @@ def regrid_dataset( Destination grid. The Dataset needs to have lat/lon coordinates. Supports a 'mask' variable compatible with ESMF standards. regridder_kwargs : dict - Arguments to send xe.Regridder(). If it contains `skipna`, that - one is passed to the regridder call directly. + Arguments to send xe.Regridder(). If it contains `skipna` or `out_chunks`, those + are passed to the regridder call directly. intermediate_grids : dict This argument is used to do a regridding in many steps, regridding to regular grids before regridding to the final ds_grid. @@ -123,7 +123,13 @@ def regrid_dataset( ): kwargs["weights"] = weights_filename kwargs["reuse_weights"] = True - skipna = regridder_kwargs.pop("skipna", False) + + # Extract args that are to be given at call time. + # out_chunks is only valid for xesmf >= 0.8, so don't add it be default to the call_kwargs + call_kwargs = {"skipna": regridder_kwargs.pop("skipna", False)} + if "out_chunks" in regridder_kwargs: + call_kwargs["out_chunks"] = regridder_kwargs.pop("out_chunks") + regridder = _regridder( ds_in=ds, ds_grid=ds_grid, filename=weights_filename, **regridder_kwargs ) @@ -131,7 +137,8 @@ def regrid_dataset( # The regridder (when fed Datasets) doesn't like if 'mask' is present. if "mask" in ds: ds = ds.drop_vars(["mask"]) - out = regridder(ds, keep_attrs=True, skipna=skipna) + + out = regridder(ds, keep_attrs=True, **call_kwargs) # double-check that grid_mapping information is transferred gridmap_out = any(