Skip to content

Commit

Permalink
Fix global avg - out_chunks in xe wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
aulemahal committed Sep 25, 2023
1 parent ac1f7be commit 1f5e173
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 9 deletions.
25 changes: 25 additions & 0 deletions tests/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from xclim.testing.helpers import test_timeseries as timeseries

import xscen as xs
from xscen.testing import datablock_3d


class TestClimatologicalMean:
Expand Down Expand Up @@ -523,3 +524,27 @@ def test_errors(self):
xs.produce_horizon(
self.ds, indicators=self.yaml_file, periods=[["1950", "1990"]]
)


class TestSpatialMean:
# We test different longitude flavors : all < 0, crossing 0, all > 0
# the default global bbox changes because of subtleties in clisops
@pytest.mark.parametrize(
"method,exp", (["xesmf", 1.62032976], ["cos-lat", 1.63397460])
)
@pytest.mark.parametrize("lonstart", [-70, -30, 0])
def test_global(self, lonstart, method, exp):
ds = datablock_3d(
np.array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]]] * 3, "float"),
"tas",
"lon",
lonstart,
"lat",
15,
30,
30,
as_dataset=True,
)

avg = xs.aggregate.spatial_mean(ds, method=method, region="global")
np.testing.assert_allclose(avg.tas, exp)
16 changes: 11 additions & 5 deletions xscen/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def spatial_mean(
Can also be "global", for global averages. This is simply a shortcut for `{'name': 'global', 'method': 'bbox', 'lon_bnds' [-180, 180], 'lat_bnds': [-90, 90]}`.
kwargs : dict
Arguments to send to either mean(), interp() or SpatialAverager().
For SpatialAverager, one can give `skipna` here, to be passed to the averager call itself.
For SpatialAverager, one can give `skipna` or `out_chunks` here, to be passed to the averager call itself.
simplify_tolerance : float
Precision (in degree) used to simplify a shapefile before sending it to SpatialAverager().
The simpler the polygons, the faster the averaging, but it will lose some precision.
Expand Down Expand Up @@ -456,9 +456,13 @@ def spatial_mean(
region = {
"name": "global",
"method": "bbox",
"lon_bnds": [-180, 180],
"lat_bnds": [-90, 90],
"lat_bnds": [-90 + 1e-5, 90 - 1e-5],
}
# `spatial_subset` won't wrap coords on the bbox, we need to fit the system used on ds.
if ds.cf["longitude"].min() >= 0:
region["lon_bnds"] = [0, 360]
else:
region["lon_bnds"] = [-180, 180]

if (
(region is not None)
Expand Down Expand Up @@ -638,7 +642,9 @@ def spatial_mean(
raise ValueError("'method' should be one of [bbox, shape].")

kwargs_copy = deepcopy(kwargs)
skipna = kwargs_copy.pop("skipna", False)
call_kwargs = {"skipna": kwargs_copy.pop("skipna", False)}
if "out_chunks" in kwargs:
call_kwargs["out_chunks"] = kwargs_copy.pop("out_chunks")

# Pre-emptive segmentization. Same threshold as xESMF, but there's not strong analysis behind this choice
geoms = shapely.segmentize(polygon.geometry, 1)
Expand All @@ -653,7 +659,7 @@ def spatial_mean(
ds = ds.update(create_bounds_rotated_pole(ds))

savg = xe.SpatialAverager(ds, geoms, **kwargs_copy)
ds_agg = savg(ds, keep_attrs=True, skipna=skipna)
ds_agg = savg(ds, keep_attrs=True, **call_kwargs)
extra_coords = {
col: xr.DataArray(polygon[col], dims=("geom",))
for col in polygon.columns
Expand Down
15 changes: 11 additions & 4 deletions xscen/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def regrid_dataset(
Destination grid. The Dataset needs to have lat/lon coordinates.
Supports a 'mask' variable compatible with ESMF standards.
regridder_kwargs : dict
Arguments to send xe.Regridder(). If it contains `skipna`, that
one is passed to the regridder call directly.
Arguments to send xe.Regridder(). If it contains `skipna` or `out_chunks`, those
are passed to the regridder call directly.
intermediate_grids : dict
This argument is used to do a regridding in many steps, regridding to regular
grids before regridding to the final ds_grid.
Expand Down Expand Up @@ -123,15 +123,22 @@ def regrid_dataset(
):
kwargs["weights"] = weights_filename
kwargs["reuse_weights"] = True
skipna = regridder_kwargs.pop("skipna", False)

# Extract args that are to be given at call time.
# out_chunks is only valid for xesmf >= 0.8, so don't add it be default to the call_kwargs
call_kwargs = {"skipna": regridder_kwargs.pop("skipna", False)}
if "out_chunks" in regridder_kwargs:
call_kwargs["out_chunks"] = regridder_kwargs.pop("out_chunks")

regridder = _regridder(
ds_in=ds, ds_grid=ds_grid, filename=weights_filename, **regridder_kwargs
)

# The regridder (when fed Datasets) doesn't like if 'mask' is present.
if "mask" in ds:
ds = ds.drop_vars(["mask"])
out = regridder(ds, keep_attrs=True, skipna=skipna)

out = regridder(ds, keep_attrs=True, **call_kwargs)

# double-check that grid_mapping information is transferred
gridmap_out = any(
Expand Down

0 comments on commit 1f5e173

Please sign in to comment.