Skip to content

Commit

Permalink
Add tests to aggregate.py (#486)
Browse files Browse the repository at this point in the history
<!-- Please ensure the PR fulfills the following requirements! -->
<!-- If this is your first PR, make sure to add your details to the
AUTHORS.rst! -->
### Pull Request Checklist:
- [x] This PR addresses an already opened issue (for bug fixes /
features)
    - This PR fixes #xyz
- [x] (If applicable) Documentation has been added / updated (for bug
fixes / features).
- [x] (If applicable) Tests have been added.
- [x] This PR does not seem to break the templates.
- [ ] CHANGELOG.rst has been updated (with summary of main changes).
- [ ] Link to issue (:issue:`number`) and pull request (:pull:`number`)
has been added.

### What kind of change does this PR introduce?

* Mainly, this adds tests to the lines in `aggregate.py` that had yet to
be tested.
* Bugfix: `climatological_op` now correctly handles `kwargs`.
* Bugfix: `climatological_op` with `linregress` now correctly handles
`min_periods`
* Removed obsolete code now that we pin `xarray >=2023.11`.
* Bugfix: `datablock_3d` did not use the same attributes for the rotated
pole when creating the fake dataset.

### Does this PR introduce a breaking change?
* `method=='interp_centroid'` in `spatial_mean` has been removed without
backwards compatibility or staged removal, since it could produce vastly
erroneous results.
* Not really breaking since the function was not public, but
`regrid._get_grid_mapping` was moved to `spatial.py` and made public.


### Other information:
  • Loading branch information
RondeauG authored Nov 4, 2024
2 parents ceeddf8 + b8adf59 commit c5a3261
Show file tree
Hide file tree
Showing 9 changed files with 490 additions and 235 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@ Bug fixes
* ``xs.io.save_to_table`` now correctly handles the case where the input is a `DataArray` or a `Dataset` with a single variable. (:pull:`473`).
* Fixed a bug in ``xs.utils.change_units`` where the original dataset was also getting modified. (:pull:`482`).
* Fixed a bug in ``xs.compute_indicators`` where the `cat:variable` attribute was not correctly set. (:pull:`483`).
* Fixed a bug in ``xs.climatological_op`` where kwargs were not passed to the operation function. (:pull:`486`).
* Fixed a bug in ``xs.climatological_op`` where `min_periods` was not passed when the operation was `linregress`. (:pull:`486`).

Internal changes
^^^^^^^^^^^^^^^^
* Include CF convention for temperature differences and on scale (:pull:`428`, :issue:`428`).
* Bumped the version of `xclim` to 0.53.2. (:pull:`482`).
* More tests added. (:pull:`486`).
* Fixed a bug in ``xs.testing.datablock_3d`` where some attributes of the rotated pole got reversed half-way through the creation of the dataset. (:pull:`486`).
* The function ``xs.regrid._get_grid_mapping`` was moved to ``xs.spatial.get_grid_mapping`` and is now a public function. (:pull:`486`).

v0.10.0 (2024-09-30)
--------------------
Expand Down
9 changes: 1 addition & 8 deletions docs/notebooks/2_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1189,12 +1189,6 @@
"\n",
"- `method: cos-lat` will perform an average operation over the spatial dimensions, accounting for changes in grid cell area along the 'lat' coordinate.\n",
"\n",
"\n",
"- `method: interp_centroid` will perform an interpolation towards given coordinates or towards the centroid of a region.\n",
" - `kwargs` is used to sent arguments to `.interp()`, including `lon` and `lat`.\n",
" - `region` can alternatively be used to send a gridpoint, bbox, or shapefile and compute the centroid. This argument is a dictionary that follows the same requirements as the one for `xs.extract` described previously.\n",
"\n",
"\n",
"- `method: xesmf` will perform a call to *xESMF*'s [SpatialAverager](https://pangeo-xesmf.readthedocs.io/en/latest/notebooks/Spatial_Averaging.html). This method is the most precise, especially for irregular regions, but can be much slower.\n",
" - `kwargs` is used to sent arguments to `xesmf.SpatialAverager`.\n",
" - `region` is used to send a bbox or shapefile to the `SpatialAverager`. This argument is a dictionary that follows the same requirements as the one for `xs.extract` described previously.\n",
Expand All @@ -1215,8 +1209,7 @@
"for key, ds in ds_dict.items():\n",
" ds_savg = xs.spatial_mean(\n",
" ds=ds,\n",
" method=\"interp_centroid\",\n",
" kwargs={\"method\": \"linear\", \"lon\": -74.5, \"lat\": 47},\n",
" method=\"cos-lat\",\n",
" to_domain=\"aggregated\",\n",
" )\n",
"\n",
Expand Down
248 changes: 109 additions & 139 deletions src/xscen/aggregate.py

Large diffs are not rendered by default.

35 changes: 11 additions & 24 deletions src/xscen/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import cartopy.crs as ccrs
import cf_xarray as cfxr
import numpy as np
import xarray as xr
from xclim.core.units import convert_units_to

Expand All @@ -23,6 +22,7 @@
Regridder = "xesmf.Regridder"

from .config import parse_config
from .spatial import get_grid_mapping

__all__ = ["create_bounds_gridmapping", "create_mask", "regrid_dataset"]

Expand Down Expand Up @@ -164,7 +164,7 @@ def regrid_dataset( # noqa: C901
out = regridder(ds, keep_attrs=True, **call_kwargs)

# Double-check that grid_mapping information is transferred
gridmap_out = _get_grid_mapping(ds_grid)
gridmap_out = get_grid_mapping(ds_grid)
if gridmap_out:
# Regridder seems to seriously mess up the rotated dimensions
for d in out.lon.dims:
Expand All @@ -182,7 +182,7 @@ def regrid_dataset( # noqa: C901
if gridmap_out not in out:
out = out.assign_coords({gridmap_out: ds_grid[gridmap_out]})
else:
gridmap_in = _get_grid_mapping(ds)
gridmap_in = get_grid_mapping(ds)
# Remove the original grid_mapping attribute
for v in out.data_vars:
if "grid_mapping" in out[v].attrs:
Expand Down Expand Up @@ -349,8 +349,8 @@ def _regridder(
Regridder object
"""
if method.startswith("conservative"):
gridmap_in = _get_grid_mapping(ds_in)
gridmap_grid = _get_grid_mapping(ds_grid)
gridmap_in = get_grid_mapping(ds_in)
gridmap_grid = get_grid_mapping(ds_grid)

if (
ds_in.cf["longitude"].ndim == 2
Expand Down Expand Up @@ -386,8 +386,13 @@ def create_bounds_rotated_pole(ds: xr.Dataset) -> xr.Dataset:
return create_bounds_gridmapping(ds, "rotated_pole")


def create_bounds_gridmapping(ds: xr.Dataset, gridmap: str) -> xr.Dataset:
def create_bounds_gridmapping(ds: xr.Dataset, gridmap: str | None = None) -> xr.Dataset:
"""Create bounds for rotated pole datasets."""
if gridmap is None:
gridmap = get_grid_mapping(ds)
if gridmap == "":
raise ValueError("Grid mapping could not be inferred from the dataset.")

xname = ds.cf.axes["X"][0]
yname = ds.cf.axes["Y"][0]

Expand Down Expand Up @@ -452,24 +457,6 @@ def _get_opt_attr_as_float(da: xr.DataArray, attr: str) -> float | None:
return ds_bnds.transpose(*ds.lon.dims, "bounds")


def _get_grid_mapping(ds: xr.Dataset) -> str:
"""Get the grid_mapping attribute from the dataset."""
gridmap = [
ds[v].attrs["grid_mapping"]
for v in ds.data_vars
if "grid_mapping" in ds[v].attrs
]
gridmap += [c for c in ds.coords if ds[c].attrs.get("grid_mapping_name", None)]
gridmap = list(np.unique(gridmap))

if len(gridmap) > 1:
warnings.warn(
f"There are conflicting grid_mapping attributes in the dataset. Assuming {gridmap[0]}."
)

return gridmap[0] if gridmap else ""


def _generate_random_string(length: int):
characters = string.ascii_letters + string.digits

Expand Down
19 changes: 19 additions & 0 deletions src/xscen/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
__all__ = [
"creep_fill",
"creep_weights",
"get_grid_mapping",
"subset",
]

Expand Down Expand Up @@ -435,6 +436,24 @@ def _load_lon_lat(ds: xr.Dataset) -> xr.Dataset:
return ds


def get_grid_mapping(ds: xr.Dataset) -> str:
"""Get the grid_mapping attribute from the dataset."""
gridmap = [
ds[v].attrs["grid_mapping"]
for v in ds.data_vars
if "grid_mapping" in ds[v].attrs
]
gridmap += [c for c in ds.coords if ds[c].attrs.get("grid_mapping_name", None)]
gridmap = list(np.unique(gridmap))

if len(gridmap) > 1:
warnings.warn(
f"There are conflicting grid_mapping attributes in the dataset. Assuming {gridmap[0]}."
)

return gridmap[0] if gridmap else ""


def _estimate_grid_resolution(ds: xr.Dataset) -> tuple[float, float]:
# Since this is to compute a buffer, we take the maximum difference as an approximation.
# Estimate the grid resolution
Expand Down
2 changes: 1 addition & 1 deletion src/xscen/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def datablock_3d(
PC = ccrs.PlateCarree()
if x == "rlon": # rotated pole
GM = ccrs.RotatedPole(
pole_longitude=42.5, pole_latitude=83.0, central_rotated_longitude=0.0
pole_longitude=83.0, pole_latitude=42.5, central_rotated_longitude=0.0
)
da.attrs["grid_mapping"] = "rotated_pole"
else:
Expand Down
Loading

0 comments on commit c5a3261

Please sign in to comment.