Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better spatial mean #260

Merged
merged 4 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ History

v0.8.0 (unreleased)
-------------------
Contributors to this version: Gabriel Rondeau-Genesse (:user:`RondeauG`).
Contributors to this version: Gabriel Rondeau-Genesse (:user:`RondeauG`), Pascal Bourgault (:user:`aulemahal`).

Announcements
^^^^^^^^^^^^^
Expand All @@ -13,6 +13,8 @@ Announcements
New features and enhancements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
* Added the ability to search for simulations that reach a given warming level. (:pull:`251`).
* ``xs.spatial_mean`` now accepts the ``region="global"`` keyword to perform a global average (:issue:`94`, :pull:`260`).
* ``xs.spatial_mean`` with ``method='xESMF'`` will also automatically segmentize polygons (down to a 1° resolution) to ensure a correct average (:pull:`260`).

Breaking changes
^^^^^^^^^^^^^^^^
Expand All @@ -34,6 +36,7 @@ Internal changes
* Reduced the size of the files in /docs/notebooks/samples and changed the Notebooks and tests accordingly. (:issue:`247`, :pull:`248`).
* Added a new `xscen.testing` module with the `datablock_3d` function previously located in `/tests/conftest.py`. (:pull:`248`).
* New function `xscen.testing.fake_data` to generate fake data for testing. (:pull:`248`).
* xESMF 0.8 Regridder and SpatialAverager argument ``out_chunks`` is now accepted by ``xs.regrid_dataset`` and ``xs.spatial_mean``. (:pull:`260`).

v0.7.1 (2023-08-23)
-------------------
Expand Down
25 changes: 25 additions & 0 deletions tests/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from xclim.testing.helpers import test_timeseries as timeseries

import xscen as xs
from xscen.testing import datablock_3d


class TestClimatologicalMean:
Expand Down Expand Up @@ -523,3 +524,27 @@ def test_errors(self):
xs.produce_horizon(
self.ds, indicators=self.yaml_file, periods=[["1950", "1990"]]
)


class TestSpatialMean:
    """Tests for ``xs.aggregate.spatial_mean`` using ``region="global"``."""

    # Three longitude flavors are exercised: all < 0, crossing 0, all > 0.
    # The default global bbox changes between them because of subtleties in clisops.
    @pytest.mark.parametrize(
        "method,exp", (["xesmf", 1.62032976], ["cos-lat", 1.63397460])
    )
    @pytest.mark.parametrize("lonstart", [-70, -30, 0])
    def test_global(self, lonstart, method, exp):
        """Compare the global average of ``tas`` against the reference value of each method."""
        # Three identical 3x3 slices whose cell value is (row index + column index).
        values = np.array(
            [[[row + col for col in range(3)] for row in range(3)] for _ in range(3)],
            "float",
        )
        # NOTE(review): the positional arguments after the variable name are presumably
        # (x name, x start, y name, y start, x step, y step) — confirm against
        # xscen.testing.datablock_3d.
        ds = datablock_3d(
            values, "tas", "lon", lonstart, "lat", 15, 30, 30, as_dataset=True
        )

        result = xs.aggregate.spatial_mean(ds, method=method, region="global")
        np.testing.assert_allclose(result.tas, exp)
52 changes: 30 additions & 22 deletions xscen/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def spatial_mean(
*,
spatial_subset: bool = None,
call_clisops: bool = False,
region: dict = None,
region: Union[dict, str] = None,
kwargs: dict = None,
simplify_tolerance: float = None,
to_domain: str = None,
Expand All @@ -394,13 +394,14 @@ def spatial_mean(
spatial_subset : bool
If True, xscen.spatial.subset will be called prior to the other operations. This requires the 'region' argument.
If None, this will automatically become True if 'region' is provided and the subsetting method is either 'cos-lat' or 'mean'.
region : dict
region : dict or str
Description of the region and the subsetting method (required fields listed in the Notes).
If method=='interp_centroid', this is used to find the region's centroid.
If method=='xesmf', the bounding box or shapefile is given to SpatialAverager.
Can also be "global", for global averages. This is simply a shortcut for `{'name': 'global', 'method': 'bbox', 'lon_bnds': [-180, 180], 'lat_bnds': [-90, 90]}` (the longitude bounds become [0, 360] when all longitudes of the dataset are non-negative).
kwargs : dict
Arguments to send to either mean(), interp() or SpatialAverager().
For SpatialAverager, one can give `skipna` here, to be passed to the averager call itself.
For SpatialAverager, one can give `skipna` or `out_chunks` here, to be passed to the averager call itself.
simplify_tolerance : float
Precision (in degree) used to simplify a shapefile before sending it to SpatialAverager().
The simpler the polygons, the faster the averaging, but it will lose some precision.
Expand Down Expand Up @@ -451,6 +452,18 @@ def spatial_mean(
)
spatial_subset = call_clisops

if region == "global":
region = {
"name": "global",
"method": "bbox",
"lat_bnds": [-90 + 1e-5, 90 - 1e-5],
}
# `spatial_subset` won't wrap coords on the bbox, we need to fit the system used on ds.
if ds.cf["longitude"].min() >= 0:
region["lon_bnds"] = [0, 360]
else:
region["lon_bnds"] = [-180, 180]

if (
(region is not None)
and (region["method"] in region)
Expand Down Expand Up @@ -587,26 +600,14 @@ def spatial_mean(

# Uses xesmf.SpatialAverager
elif method == "xesmf":
logger.warning(
"A bug has been found with xesmf.SpatialAverager that appears to impact big regions. "
"Until this is fixed, make sure that the computation is right or use multiple smaller regions."
)
# If the region is a bounding box, call shapely and geopandas to transform it into an input compatible with xesmf
if region["method"] == "bbox":
lon_point_list = [
polygon_geom = shapely.box(
region["lon_bnds"][0],
region["lon_bnds"][0],
region["lon_bnds"][1],
region["lon_bnds"][1],
]
lat_point_list = [
region["lat_bnds"][0],
region["lon_bnds"][1],
region["lat_bnds"][1],
region["lat_bnds"][1],
region["lat_bnds"][0],
]

polygon_geom = Polygon(zip(lon_point_list, lat_point_list))
)
polygon = gpd.GeoDataFrame(index=[0], geometry=[polygon_geom])

# Prepare the History field
Expand All @@ -619,8 +620,10 @@ def spatial_mean(
elif region["method"] == "shape":
if not isinstance(region["shape"], gpd.GeoDataFrame):
polygon = gpd.read_file(region["shape"])
name = Path(region["shape"]).name
else:
polygon = region["shape"]
name = f"{len(polygon)} polygons"

# Simplify the geometries to a given tolerance, if needed.
# The simpler the polygons, the faster the averaging, but it will lose some precision.
Expand All @@ -632,14 +635,19 @@ def spatial_mean(
# Prepare the History field
new_history = (
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"xesmf.SpatialAverager over {Path(region['shape']).name} - xESMF v{xe.__version__}"
f"xesmf.SpatialAverager over {name} - xESMF v{xe.__version__}"
)

else:
raise ValueError("'method' should be one of [bbox, shape].")

kwargs_copy = deepcopy(kwargs)
skipna = kwargs_copy.pop("skipna", False)
call_kwargs = {"skipna": kwargs_copy.pop("skipna", False)}
if "out_chunks" in kwargs:
call_kwargs["out_chunks"] = kwargs_copy.pop("out_chunks")

# Pre-emptive segmentization. Same threshold as xESMF, but there's no strong analysis behind this choice.
geoms = shapely.segmentize(polygon.geometry, 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I right in thinking that this will be skipped if we provided a shapefile with a resolution finer than 1°?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't do anything to the shape indeed.

geoms == shapely.segmentize(geoms, 1) if no geom has any segment longer than 1°.


if (
ds.cf["longitude"].ndim == 2
Expand All @@ -650,8 +658,8 @@ def spatial_mean(

ds = ds.update(create_bounds_rotated_pole(ds))

savg = xe.SpatialAverager(ds, polygon.geometry, **kwargs_copy)
ds_agg = savg(ds, keep_attrs=True, skipna=skipna)
savg = xe.SpatialAverager(ds, geoms, **kwargs_copy)
ds_agg = savg(ds, keep_attrs=True, **call_kwargs)
extra_coords = {
col: xr.DataArray(polygon[col], dims=("geom",))
for col in polygon.columns
Expand Down
15 changes: 11 additions & 4 deletions xscen/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def regrid_dataset(
Destination grid. The Dataset needs to have lat/lon coordinates.
Supports a 'mask' variable compatible with ESMF standards.
regridder_kwargs : dict
Arguments to send xe.Regridder(). If it contains `skipna`, that
one is passed to the regridder call directly.
Arguments to send xe.Regridder(). If it contains `skipna` or `out_chunks`, those
are passed to the regridder call directly.
intermediate_grids : dict
This argument is used to do a regridding in many steps, regridding to regular
grids before regridding to the final ds_grid.
Expand Down Expand Up @@ -123,15 +123,22 @@ def regrid_dataset(
):
kwargs["weights"] = weights_filename
kwargs["reuse_weights"] = True
skipna = regridder_kwargs.pop("skipna", False)

# Extract args that are to be given at call time.
# out_chunks is only valid for xesmf >= 0.8, so don't add it by default to the call_kwargs
call_kwargs = {"skipna": regridder_kwargs.pop("skipna", False)}
if "out_chunks" in regridder_kwargs:
call_kwargs["out_chunks"] = regridder_kwargs.pop("out_chunks")

regridder = _regridder(
ds_in=ds, ds_grid=ds_grid, filename=weights_filename, **regridder_kwargs
)

# The regridder (when fed Datasets) doesn't like if 'mask' is present.
if "mask" in ds:
ds = ds.drop_vars(["mask"])
out = regridder(ds, keep_attrs=True, skipna=skipna)

out = regridder(ds, keep_attrs=True, **call_kwargs)

# double-check that grid_mapping information is transferred
gridmap_out = any(
Expand Down