Skip to content

Commit

Permalink
Merge pull request #260 from Ouranosinc/better-spaaverage
Browse files Browse the repository at this point in the history
Better spatial mean
  • Loading branch information
aulemahal authored Sep 26, 2023
2 parents 9342bfc + 616bd4d commit fec1a0d
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 27 deletions.
5 changes: 4 additions & 1 deletion HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ History

v0.8.0 (unreleased)
-------------------
Contributors to this version: Gabriel Rondeau-Genesse (:user:`RondeauG`).
Contributors to this version: Gabriel Rondeau-Genesse (:user:`RondeauG`), Pascal Bourgault (:user:`aulemahal`).

Announcements
^^^^^^^^^^^^^
Expand All @@ -13,6 +13,8 @@ Announcements
New features and enhancements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
* Added the ability to search for simulations that reach a given warming level. (:pull:`251`).
* ``xs.spatial_mean`` now accepts the ``region="global"`` keyword to perform a global average (:issue:`94`, :pull:`260`).
* ``xs.spatial_mean`` with ``method='xESMF'`` will also automatically segmentize polygons (down to a 1° resolution) to ensure a correct average (:pull:`260`).

Breaking changes
^^^^^^^^^^^^^^^^
Expand All @@ -34,6 +36,7 @@ Internal changes
* Reduced the size of the files in /docs/notebooks/samples and changed the Notebooks and tests accordingly. (:issue:`247`, :pull:`248`).
* Added a new `xscen.testing` module with the `datablock_3d` function previously located in `/tests/conftest.py`. (:pull:`248`).
* New function `xscen.testing.fake_data` to generate fake data for testing. (:pull:`248`).
* xESMF 0.8 Regridder and SpatialAverager argument ``out_chunks`` is now accepted by ``xs.regrid_dataset`` and ``xs.spatial_mean``. (:pull:`260`).

v0.7.1 (2023-08-23)
-------------------
Expand Down
25 changes: 25 additions & 0 deletions tests/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from xclim.testing.helpers import test_timeseries as timeseries

import xscen as xs
from xscen.testing import datablock_3d


class TestClimatologicalMean:
Expand Down Expand Up @@ -523,3 +524,27 @@ def test_errors(self):
xs.produce_horizon(
self.ds, indicators=self.yaml_file, periods=[["1950", "1990"]]
)


class TestSpatialMean:
    """Checks for ``xs.aggregate.spatial_mean`` called with ``region="global"``."""

    # Several longitude flavors are exercised: all < 0, crossing 0, all > 0.
    # The default global bbox changes between them because of subtleties in clisops.
    @pytest.mark.parametrize(
        ("method", "exp"),
        [("xesmf", 1.62032976), ("cos-lat", 1.63397460)],
    )
    @pytest.mark.parametrize("lonstart", [-70, -30, 0])
    def test_global(self, lonstart, method, exp):
        """Compare the global average of a small synthetic block with the value expected for `method`."""
        # The same 3 x 3 plane is stacked three times to build the 3 x 3 x 3 input values.
        plane = [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
        values = np.array([plane] * 3, "float")

        ds = datablock_3d(
            values, "tas", "lon", lonstart, "lat", 15, 30, 30, as_dataset=True
        )

        result = xs.aggregate.spatial_mean(ds, method=method, region="global")
        np.testing.assert_allclose(result.tas, exp)
52 changes: 30 additions & 22 deletions xscen/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def spatial_mean(
*,
spatial_subset: bool = None,
call_clisops: bool = False,
region: dict = None,
region: Union[dict, str] = None,
kwargs: dict = None,
simplify_tolerance: float = None,
to_domain: str = None,
Expand All @@ -394,13 +394,14 @@ def spatial_mean(
spatial_subset : bool
If True, xscen.spatial.subset will be called prior to the other operations. This requires the 'region' argument.
If None, this will automatically become True if 'region' is provided and the subsetting method is either 'cos-lat' or 'mean'.
region : dict
region : dict or str
Description of the region and the subsetting method (required fields listed in the Notes).
If method=='interp_centroid', this is used to find the region's centroid.
If method=='xesmf', the bounding box or shapefile is given to SpatialAverager.
Can also be "global", for global averages. This is simply a shortcut for `{'name': 'global', 'method': 'bbox', 'lon_bnds': [-180, 180], 'lat_bnds': [-90, 90]}`.
kwargs : dict
Arguments to send to either mean(), interp() or SpatialAverager().
For SpatialAverager, one can give `skipna` here, to be passed to the averager call itself.
For SpatialAverager, one can give `skipna` or `out_chunks` here, to be passed to the averager call itself.
simplify_tolerance : float
Precision (in degree) used to simplify a shapefile before sending it to SpatialAverager().
The simpler the polygons, the faster the averaging, but it will lose some precision.
Expand Down Expand Up @@ -451,6 +452,18 @@ def spatial_mean(
)
spatial_subset = call_clisops

if region == "global":
region = {
"name": "global",
"method": "bbox",
"lat_bnds": [-90 + 1e-5, 90 - 1e-5],
}
# `spatial_subset` won't wrap coords on the bbox, we need to fit the system used on ds.
if ds.cf["longitude"].min() >= 0:
region["lon_bnds"] = [0, 360]
else:
region["lon_bnds"] = [-180, 180]

if (
(region is not None)
and (region["method"] in region)
Expand Down Expand Up @@ -587,26 +600,14 @@ def spatial_mean(

# Uses xesmf.SpatialAverager
elif method == "xesmf":
logger.warning(
"A bug has been found with xesmf.SpatialAverager that appears to impact big regions. "
"Until this is fixed, make sure that the computation is right or use multiple smaller regions."
)
# If the region is a bounding box, call shapely and geopandas to transform it into an input compatible with xesmf
if region["method"] == "bbox":
lon_point_list = [
polygon_geom = shapely.box(
region["lon_bnds"][0],
region["lon_bnds"][0],
region["lon_bnds"][1],
region["lon_bnds"][1],
]
lat_point_list = [
region["lat_bnds"][0],
region["lon_bnds"][1],
region["lat_bnds"][1],
region["lat_bnds"][1],
region["lat_bnds"][0],
]

polygon_geom = Polygon(zip(lon_point_list, lat_point_list))
)
polygon = gpd.GeoDataFrame(index=[0], geometry=[polygon_geom])

# Prepare the History field
Expand All @@ -619,8 +620,10 @@ def spatial_mean(
elif region["method"] == "shape":
if not isinstance(region["shape"], gpd.GeoDataFrame):
polygon = gpd.read_file(region["shape"])
name = Path(region["shape"]).name
else:
polygon = region["shape"]
name = f"{len(polygon)} polygons"

# Simplify the geometries to a given tolerance, if needed.
# The simpler the polygons, the faster the averaging, but it will lose some precision.
Expand All @@ -632,14 +635,19 @@ def spatial_mean(
# Prepare the History field
new_history = (
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"xesmf.SpatialAverager over {Path(region['shape']).name} - xESMF v{xe.__version__}"
f"xesmf.SpatialAverager over {name} - xESMF v{xe.__version__}"
)

else:
raise ValueError("'method' should be one of [bbox, shape].")

kwargs_copy = deepcopy(kwargs)
skipna = kwargs_copy.pop("skipna", False)
call_kwargs = {"skipna": kwargs_copy.pop("skipna", False)}
if "out_chunks" in kwargs:
call_kwargs["out_chunks"] = kwargs_copy.pop("out_chunks")

# Pre-emptive segmentization. Same threshold as xESMF, but there's no strong analysis behind this choice
geoms = shapely.segmentize(polygon.geometry, 1)

if (
ds.cf["longitude"].ndim == 2
Expand All @@ -650,8 +658,8 @@ def spatial_mean(

ds = ds.update(create_bounds_rotated_pole(ds))

savg = xe.SpatialAverager(ds, polygon.geometry, **kwargs_copy)
ds_agg = savg(ds, keep_attrs=True, skipna=skipna)
savg = xe.SpatialAverager(ds, geoms, **kwargs_copy)
ds_agg = savg(ds, keep_attrs=True, **call_kwargs)
extra_coords = {
col: xr.DataArray(polygon[col], dims=("geom",))
for col in polygon.columns
Expand Down
15 changes: 11 additions & 4 deletions xscen/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def regrid_dataset(
Destination grid. The Dataset needs to have lat/lon coordinates.
Supports a 'mask' variable compatible with ESMF standards.
regridder_kwargs : dict
Arguments to send xe.Regridder(). If it contains `skipna`, that
one is passed to the regridder call directly.
Arguments to send xe.Regridder(). If it contains `skipna` or `out_chunks`, those
are passed to the regridder call directly.
intermediate_grids : dict
This argument is used to do a regridding in many steps, regridding to regular
grids before regridding to the final ds_grid.
Expand Down Expand Up @@ -123,15 +123,22 @@ def regrid_dataset(
):
kwargs["weights"] = weights_filename
kwargs["reuse_weights"] = True
skipna = regridder_kwargs.pop("skipna", False)

# Extract args that are to be given at call time.
# out_chunks is only valid for xesmf >= 0.8, so don't add it by default to the call_kwargs
call_kwargs = {"skipna": regridder_kwargs.pop("skipna", False)}
if "out_chunks" in regridder_kwargs:
call_kwargs["out_chunks"] = regridder_kwargs.pop("out_chunks")

regridder = _regridder(
ds_in=ds, ds_grid=ds_grid, filename=weights_filename, **regridder_kwargs
)

# The regridder (when fed Datasets) doesn't like if 'mask' is present.
if "mask" in ds:
ds = ds.drop_vars(["mask"])
out = regridder(ds, keep_attrs=True, skipna=skipna)

out = regridder(ds, keep_attrs=True, **call_kwargs)

# double-check that grid_mapping information is transferred
gridmap_out = any(
Expand Down

0 comments on commit fec1a0d

Please sign in to comment.