diff --git a/.github/workflows/mamba.yml b/.github/workflows/mamba.yml
deleted file mode 100644
index dd81e64..0000000
--- a/.github/workflows/mamba.yml
+++ /dev/null
@@ -1,39 +0,0 @@
-name: Micromamba
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    branches:
-      - main
-
-jobs:
-
-  build:
-    name: Micromamba test on (${{ matrix.python-version }}, ${{ matrix.os }})
-    runs-on: ${{ matrix.os }}
-    strategy:
-      fail-fast: false
-      matrix:
-        os: ['ubuntu-latest']
-        python-version: ['3.10']
-    env:
-      MPLBACKEND: Agg  # https://github.com/orgs/community/discussions/26434
-    steps:
-      - uses: actions/checkout@v3
-      - uses: mamba-org/setup-micromamba@v1
-        with:
-          environment-file: environment.yml
-          cache-environment: true
-          init-shell: bash
-      - name: Micromamba info
-        shell: bash -el {0}
-        run: |
-          micromamba info
-      - name: Install dev dependencies
-        run: pip install .[dev]
-        shell: bash -el {0}
-      - name: Run pytest
-        run: pytest
-        shell: micromamba-shell {0}
diff --git a/demo/cams_co2_dataset_demo.ipynb b/demo/cams_co2_dataset_demo.ipynb
index b1394a6..238b3da 100644
--- a/demo/cams_co2_dataset_demo.ipynb
+++ b/demo/cams_co2_dataset_demo.ipynb
@@ -142,7 +142,6 @@
     "    spatial_bounds=bbox_demo,\n",
     "    variable_names=[\"co2_concentration\"],\n",
     "    resolution=1.0,\n",
-    "    regrid_method=\"flox\",\n",
     ")"
    ]
   },
diff --git a/demo/era5-land_dataset_demo.ipynb b/demo/era5-land_dataset_demo.ipynb
index 357b015..2091064 100644
--- a/demo/era5-land_dataset_demo.ipynb
+++ b/demo/era5-land_dataset_demo.ipynb
@@ -149,7 +149,6 @@
     "    spatial_bounds=bbox_demo,\n",
     "    variable_names=[\"air_temperature\", \"dewpoint_temperature\"],\n",
     "    resolution=1.0,\n",
-    "    regrid_method=\"flox\",\n",
     ")"
    ]
   },
diff --git a/demo/era5_dataset_demo.ipynb b/demo/era5_dataset_demo.ipynb
index a31f621..2d1ccc5 100644
--- a/demo/era5_dataset_demo.ipynb
+++ b/demo/era5_dataset_demo.ipynb
@@ -143,7 +143,6 @@
     "    spatial_bounds=bbox_demo,\n",
     "    variable_names=[\"eastward_component_of_wind\"],\n",
     "    resolution=1.0,\n",
-    "    regrid_method=\"flox\",\n",
     ")"
    ]
   },
diff --git a/demo/eth_dataset_demo.ipynb b/demo/eth_dataset_demo.ipynb
index 3865491..b6a6049 100644
--- a/demo/eth_dataset_demo.ipynb
+++ b/demo/eth_dataset_demo.ipynb
@@ -126,7 +126,6 @@
     "    spatial_bounds=bbox_demo,\n",
     "    variable_names=[\"height_of_vegetation\"],\n",
     "    resolution=0.05,\n",
-    "    regrid_method=\"flox\",\n",
     ")"
    ]
   },
diff --git a/demo/land_cover_dataset_demo.ipynb b/demo/land_cover_dataset_demo.ipynb
index 3862943..bc0090e 100644
--- a/demo/land_cover_dataset_demo.ipynb
+++ b/demo/land_cover_dataset_demo.ipynb
@@ -145,7 +145,6 @@
     "    spatial_bounds=bbox_demo,\n",
     "    variable_names=[\"land_cover\"],\n",
     "    resolution=1.0,\n",
-    "    regrid_method=\"most_common\",\n",
     ")"
    ]
   },
@@ -1115,7 +1114,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.0"
+   "version": "3.10.12"
   },
  "orig_nbformat": 4
 },
diff --git a/demo/prism_dem_demo.ipynb b/demo/prism_dem_demo.ipynb
index c2f1f44..eeb4f7b 100644
--- a/demo/prism_dem_demo.ipynb
+++ b/demo/prism_dem_demo.ipynb
@@ -139,7 +139,6 @@
     "    spatial_bounds=bbox_demo,\n",
     "    variable_names=[\"elevation\"],\n",
     "    resolution=0.01,\n",
-    "    regrid_method=\"flox\",\n",
     ")"
    ]
   },
diff --git a/environment.yml b/environment.yml
deleted file mode 100644
index 77c4945..0000000
--- a/environment.yml
+++ /dev/null
@@ -1,20 +0,0 @@
----
-name: zampy
-channels:
-  - conda-forge
-dependencies:
-  - python==3.10
-  - xESMF
-  - requests
-  - netcdf4
-  - numpy
-  - pandas
-  - matplotlib
-  - xarray
-  - rioxarray  # required for TIFF files
files - - tqdm - - dask[diagnostics] - - pint - - cf_xarray # required to auto-pint CF compliant datasets - - pint-xarray - - flox diff --git a/example_recipe.yml b/example_recipe.yml index 97c8dce..455aefa 100644 --- a/example_recipe.yml +++ b/example_recipe.yml @@ -15,7 +15,6 @@ convert: standard: ALMA-PLUMBER2 frequency: 1H # outputs at 1 hour frequency. Pandas-like freq-keyword. resolution: 0.01 # output resolution in degrees. - conversion-method: "flox" # Either flox or xesmf. xesmf requires conda + linux. additional_variables: # Possible future addition saturation_vapor_pressure: diff --git a/pyproject.toml b/pyproject.toml index 08bd5e2..cd3dd07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,9 +62,8 @@ dependencies = [ "pint", "cf_xarray", # required to auto-pint CF compliant datasets. "pint-xarray", - "flox", # For better groupby methods. "cdsapi", - "xarray-regrid", # for land cover data regridding + "xarray-regrid", # for regridding ] dynamic = ["version"] @@ -120,16 +119,6 @@ features = ["docs"] build = ["mkdocs build"] serve = ["mkdocs serve"] -# [tool.hatch.envs.conda] -# type = "conda" -# python = "3.10" -# command = "micromamba" -# environment-file = "environment.yml" -# extra-dependencies = ["pytest", "pytest-cov"] - -# [tool.hatch.envs.conda.scripts] -# test = ["pytest ./tests/",] - [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/src/zampy/__init__.py b/src/zampy/__init__.py index 98164ae..e5cc82f 100644 --- a/src/zampy/__init__.py +++ b/src/zampy/__init__.py @@ -1,6 +1,5 @@ """zampy.""" from zampy import datasets -from zampy import utils __author__ = "Bart Schilperoort" @@ -8,4 +7,4 @@ __version__ = "0.1.0" -__all__ = ["datasets", "utils"] +__all__ = ["datasets"] diff --git a/src/zampy/datasets/dataset_protocol.py b/src/zampy/datasets/dataset_protocol.py index 097abdc..ce3a0cc 100644 --- a/src/zampy/datasets/dataset_protocol.py +++ b/src/zampy/datasets/dataset_protocol.py @@ -130,7 +130,6 @@ def load( time_bounds: TimeBounds, spatial_bounds: SpatialBounds, resolution: float, - regrid_method: str, variable_names: list[str], ) -> xr.Dataset: """Get the dataset as an xarray Dataset. @@ -142,9 +141,6 @@ def load( loaded. resolution: The desired resolution of the loaded data. The ingested data will be regridded to match this resolution. - regrid_method: Which routines to use to resample. Either "flox" (default) or - "esmf". Of these two, esmf is the more robust and accurate regridding - method, however it can be difficult to install. variable_names: Which variables should be loaded. diff --git a/src/zampy/datasets/ecmwf_dataset.py b/src/zampy/datasets/ecmwf_dataset.py index 2254ae5..87f39bc 100644 --- a/src/zampy/datasets/ecmwf_dataset.py +++ b/src/zampy/datasets/ecmwf_dataset.py @@ -2,6 +2,7 @@ from pathlib import Path import xarray as xr +import xarray_regrid # noqa: F401 from zampy.datasets import cds_utils from zampy.datasets import converter from zampy.datasets import validation @@ -10,7 +11,7 @@ from zampy.datasets.dataset_protocol import Variable from zampy.datasets.dataset_protocol import copy_properties_file from zampy.datasets.dataset_protocol import write_properties_file -from zampy.utils import regrid +from zampy.datasets.utils import make_grid ## Ignore missing class/method docstrings: they are implemented in the Dataset class. 
@@ -111,7 +112,6 @@ def load(
         time_bounds: TimeBounds,
         spatial_bounds: SpatialBounds,
         resolution: float,
-        regrid_method: str,
         variable_names: list[str],
     ) -> xr.Dataset:
         files: list[Path] = []
@@ -121,7 +121,11 @@
 
         ds = xr.open_mfdataset(files, chunks={"latitude": 200, "longitude": 200})
         ds = ds.sel(time=slice(time_bounds.start, time_bounds.end))
-        ds = regrid.regrid_data(ds, spatial_bounds, resolution, regrid_method)
+
+        grid = xarray_regrid.create_regridding_dataset(
+            make_grid(spatial_bounds, resolution)
+        )
+        ds = ds.regrid.linear(grid)
 
         return ds
 
diff --git a/src/zampy/datasets/eth_canopy_height.py b/src/zampy/datasets/eth_canopy_height.py
index a1cb518..c186de3 100644
--- a/src/zampy/datasets/eth_canopy_height.py
+++ b/src/zampy/datasets/eth_canopy_height.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 import numpy as np
 import xarray as xr
+import xarray_regrid
 from zampy.datasets import converter
 from zampy.datasets import utils
 from zampy.datasets import validation
@@ -13,7 +14,6 @@
 from zampy.datasets.dataset_protocol import write_properties_file
 from zampy.reference.variables import VARIABLE_REFERENCE_LOOKUP
 from zampy.reference.variables import unit_registry
-from zampy.utils import regrid
 
 
 VALID_NAME_FILE = (
@@ -126,7 +126,6 @@ def load(
         time_bounds: TimeBounds,
         spatial_bounds: SpatialBounds,
         resolution: float,
-        regrid_method: str,
         variable_names: list[str],
     ) -> xr.Dataset:
         files: list[Path] = []
@@ -137,7 +136,11 @@
 
         ds = xr.open_mfdataset(files, chunks={"latitude": 2000, "longitude": 2000})
         ds = ds.sel(time=slice(time_bounds.start, time_bounds.end))
-        ds = regrid.regrid_data(ds, spatial_bounds, resolution, regrid_method)
+
+        grid = xarray_regrid.create_regridding_dataset(
+            utils.make_grid(spatial_bounds, resolution)
+        )
+        ds = ds.regrid.linear(grid)
 
         return ds
 
     def convert(
diff --git a/src/zampy/datasets/fapar_lai.py b/src/zampy/datasets/fapar_lai.py
index 0dbd629..545bfb6 100644
--- a/src/zampy/datasets/fapar_lai.py
+++ b/src/zampy/datasets/fapar_lai.py
@@ -13,6 +13,7 @@
 from tqdm import tqdm
 from zampy.datasets import cds_utils
 from zampy.datasets import converter
+from zampy.datasets import utils
 from zampy.datasets import validation
 from zampy.datasets.dataset_protocol import SpatialBounds
 from zampy.datasets.dataset_protocol import TimeBounds
@@ -141,7 +142,6 @@ def load(
         time_bounds: TimeBounds,
         spatial_bounds: SpatialBounds,
         resolution: float,
-        regrid_method: str,  # should be deprecated.
         variable_names: list[str],
     ) -> xr.Dataset:
         files = list((ingest_dir / self.name).glob("*.nc"))
@@ -149,8 +149,10 @@
 
         ds = xr.open_mfdataset(files, parallel=True)
         ds = ds.sel(time=slice(time_bounds.start, time_bounds.end))
-        target_dataset = create_regridding_ds(spatial_bounds, resolution)
-        ds = ds.regrid.linear(target_dataset)
+        grid = xarray_regrid.create_regridding_dataset(
+            utils.make_grid(spatial_bounds, resolution)
+        )
+        ds = ds.regrid.linear(grid)
 
         return ds
 
@@ -175,29 +177,6 @@ def convert(
         # Will be removed, see issue #43.
 
         return True
 
 
-def create_regridding_ds(
-    spatial_bounds: SpatialBounds, resolution: float
-) -> xr.Dataset:
-    """Create dataset to use with xarray-regrid regridding.
-
-    Args:
-        spatial_bounds: Spatial bounds of the new dataset.
-        resolution: Latitude and longitude resolution of the new dataset.
-
-    Returns:
-        The dataset ready to be used in regridding.
- """ - new_grid = xarray_regrid.Grid( - north=spatial_bounds.north, - east=spatial_bounds.east, - south=spatial_bounds.south, - west=spatial_bounds.west, - resolution_lat=resolution, - resolution_lon=resolution, - ) - return xarray_regrid.create_regridding_dataset(new_grid) - - def get_year_month_pairs(time_bounds: TimeBounds) -> list[tuple[int, int]]: """Get the year and month pairs covering the input time bounds.""" start = pd.to_datetime(time_bounds.start) diff --git a/src/zampy/datasets/land_cover.py b/src/zampy/datasets/land_cover.py index 66e8f13..0ae9b78 100644 --- a/src/zampy/datasets/land_cover.py +++ b/src/zampy/datasets/land_cover.py @@ -8,6 +8,7 @@ import xarray_regrid from zampy.datasets import cds_utils from zampy.datasets import converter +from zampy.datasets import utils from zampy.datasets import validation from zampy.datasets.dataset_protocol import SpatialBounds from zampy.datasets.dataset_protocol import TimeBounds @@ -121,7 +122,6 @@ def load( time_bounds: TimeBounds, spatial_bounds: SpatialBounds, resolution: float, - regrid_method: str, # Unused in land-cover dataset variable_names: list[str], ) -> xr.Dataset: files: list[Path] = [] @@ -137,19 +137,13 @@ def load( ds = xr.open_mfdataset(files, chunks={"latitude": 200, "longitude": 200}) ds = ds.sel(time=slice(time_bounds.start, time_bounds.end)) - new_grid = xarray_regrid.Grid( - north=spatial_bounds.north, - east=spatial_bounds.east, - south=spatial_bounds.south, - west=spatial_bounds.west, - resolution_lat=resolution, - resolution_lon=resolution, - ) - target_dataset = xarray_regrid.create_regridding_dataset(new_grid) - ds_regrid = ds.regrid.most_common(target_dataset, time_dim="time", max_mem=1e9) + grid = xarray_regrid.create_regridding_dataset( + utils.make_grid(spatial_bounds, resolution) + ) + ds = ds.regrid.most_common(grid, time_dim="time", max_mem=1e9) - return ds_regrid + return ds def convert( self, diff --git a/src/zampy/datasets/prism_dem.py b/src/zampy/datasets/prism_dem.py index 55ca59d..73440bb 100644 --- a/src/zampy/datasets/prism_dem.py +++ b/src/zampy/datasets/prism_dem.py @@ -5,6 +5,7 @@ from typing import Literal import numpy as np import xarray as xr +import xarray_regrid from rasterio.io import MemoryFile from zampy.datasets import converter from zampy.datasets import utils @@ -16,7 +17,6 @@ from zampy.datasets.dataset_protocol import write_properties_file from zampy.reference.variables import VARIABLE_REFERENCE_LOOKUP from zampy.reference.variables import unit_registry -from zampy.utils import regrid VALID_NAME_FILES = [ @@ -127,7 +127,6 @@ def load( time_bounds: TimeBounds, # Unused in PrismDEM spatial_bounds: SpatialBounds, resolution: float, - regrid_method: str, variable_names: list[str], ) -> xr.Dataset: for var in variable_names: @@ -145,7 +144,11 @@ def preproc(ds: xr.Dataset) -> xr.Dataset: return ds.isel(latitude=slice(None, -1), longitude=slice(None, -1)) ds = xr.open_mfdataset(files, preprocess=preproc) - ds = regrid.regrid_data(ds, spatial_bounds, resolution, regrid_method) + + grid = xarray_regrid.create_regridding_dataset( + utils.make_grid(spatial_bounds, resolution) + ) + ds = ds.regrid.linear(grid) return ds diff --git a/src/zampy/datasets/utils.py b/src/zampy/datasets/utils.py index ee15680..69489a2 100644 --- a/src/zampy/datasets/utils.py +++ b/src/zampy/datasets/utils.py @@ -2,7 +2,9 @@ import urllib.request from pathlib import Path import requests +import xarray_regrid from tqdm import tqdm +from zampy.datasets.dataset_protocol import SpatialBounds class 
@@ -52,3 +54,15 @@ def get_file_size(fpath: Path) -> int:
         return 0
     else:
         return fpath.stat().st_size
+
+
+def make_grid(spatial_bounds: SpatialBounds, resolution: float) -> xarray_regrid.Grid:
+    """Make a regridding grid for passing to xarray-regrid."""
+    return xarray_regrid.Grid(
+        north=spatial_bounds.north,
+        east=spatial_bounds.east,
+        south=spatial_bounds.south,
+        west=spatial_bounds.west,
+        resolution_lat=resolution,
+        resolution_lon=resolution,
+    )
diff --git a/src/zampy/recipe.py b/src/zampy/recipe.py
index 0a06c50..5cd26ac 100644
--- a/src/zampy/recipe.py
+++ b/src/zampy/recipe.py
@@ -113,7 +113,6 @@ def run(self) -> None:
                     spatial_bounds=self.spatialbounds,
                     variable_names=variables,
                     resolution=self.resolution,
-                    regrid_method="flox",
                 )
 
                 ds = converter.convert(ds, dataset, convention=self.convention)
diff --git a/src/zampy/utils/regrid.py b/src/zampy/utils/regrid.py
deleted file mode 100644
index fdfaec0..0000000
--- a/src/zampy/utils/regrid.py
+++ /dev/null
@@ -1,222 +0,0 @@
-"""Zampy regridding functions."""
-import numpy as np
-import pandas as pd
-import xarray as xr
-from flox import xarray as floxarray
-from zampy.datasets.dataset_protocol import SpatialBounds
-
-
-def assert_xesmf_available() -> None:
-    """Util that attempts to load the optional module xesmf."""
-    try:
-        import xesmf as _  # noqa: F401 (unused import)
-
-    except ImportError as e:
-        raise ImportError(
-            "Could not import the `xesmf` module.\nPlease install this"
-            " before continuing, with either `pip` or `conda`."
-        ) from e
-
-
-def infer_resolution(dataset: xr.Dataset) -> tuple[float, float]:
-    """Infer the resolution of a dataset's latitude and longitude coordinates.
-
-    Args:
-        dataset: dataset with latitude and longitude coordinates.
-
-    Returns:
-        The latitude and longitude resolution.
-    """
-    resolution_lat = np.median(
-        np.diff(
-            dataset["latitude"].to_numpy(),
-            n=1,
-        )
-    )
-    resolution_lon = np.median(
-        np.diff(
-            dataset["longitude"].to_numpy(),
-            n=1,
-        )
-    )
-
-    return (resolution_lat, resolution_lon)
-
-
-def _groupby_regrid(
-    data: xr.Dataset,
-    spatial_bounds: SpatialBounds,
-    resolution: float,
-) -> xr.Dataset:
-    """Coarsen a dataset using xrarray's groupby method.
-
-    Args:
-        data: Input dataset.
-        spatial_bounds: Spatial bounds of the new grid.
-        resolution: Resolution of the new grid.
-
-    Returns:
-        Regridded input dataset
-    """
-    # Determine the minumum number of datapoints per group. Simulates xesmf's na_thres.
-    na_thres = 0.10
-    data_resolution = infer_resolution(data)
-    n_points = (resolution / data_resolution[0]) * (resolution / data_resolution[1])
-    min_points = int(n_points * (1 - na_thres))
-
-    # Create bins to group by. Offset by 0.5*resolution, so bins are centered.
-    lat_bins = pd.interval_range(
-        start=spatial_bounds.south - 0.5 * resolution,
-        end=spatial_bounds.north + 0.5 * resolution,
-        periods=np.round((spatial_bounds.north - spatial_bounds.south) / resolution)
-        + 1,
-        closed="left",  # Closed "both" is not implemented (yet) in Flox
-    )
-    lon_bins = pd.interval_range(
-        spatial_bounds.west - 0.5 * resolution,
-        spatial_bounds.east + 0.5 * resolution,
-        periods=np.round((spatial_bounds.east - spatial_bounds.west) / resolution) + 1,
-        closed="left",  # Closed "both" is not implemented (yet) in Flox
-    )
-
-    # Group by the bins and reduce w/ mean
-    ds_out = floxarray.xarray_reduce(
-        data,
-        "latitude",
-        "longitude",
-        func="mean",
-        expected_groups=(lat_bins, lon_bins),
-        isbin=True,
-        skipna=True,
-        min_count=min_points,
-    )
-
-    # Convert *_bins dimensions back to latitude and longitude
-    ds_out["latitude"] = (
-        "latitude_bins",
-        [v.mid for v in ds_out["latitude_bins"].values],
-    )
-    ds_out["longitude"] = (
-        "longitude_bins",
-        [v.mid for v in ds_out["longitude_bins"].values],
-    )
-    ds_out = ds_out.set_coords(["latitude", "longitude"])
-    ds_out = ds_out.swap_dims(
-        {"latitude_bins": "latitude", "longitude_bins": "longitude"}
-    )
-    ds_out = ds_out.drop_vars(["latitude_bins", "longitude_bins"])
-    if "time" in ds_out.dims:
-        return ds_out.transpose("time", "latitude", "longitude", ...)
-    else:
-        return ds_out.transpose("latitude", "longitude", ...)
-
-
-def _interp_regrid(
-    data: xr.Dataset,
-    spatial_bounds: SpatialBounds,
-    resolution: float,
-) -> xr.Dataset:
-    """Refine a dataset using xrarray's interp method.
-
-    Args:
-        data: Input dataset.
-        spatial_bounds: Spatial bounds of the new grid.
-        resolution: Resolution of the new grid.
-
-    Returns:
-        Regridded input dataset
-    """
-    lat_coords = np.arange(
-        spatial_bounds.south, spatial_bounds.north + resolution, resolution
-    )
-    lon_coords = np.arange(
-        spatial_bounds.west, spatial_bounds.east + resolution, resolution
-    )
-
-    return data.interp(
-        coords={
-            "latitude": lat_coords,
-            "longitude": lon_coords,
-        },
-        method="linear",
-    )
-
-
-def flox_regrid(
-    data: xr.Dataset,
-    spatial_bounds: SpatialBounds,
-    resolution: float,
-) -> xr.Dataset:
-    """Regrid a dataset to a new grid, using xarray + flox methods.
-
-    Data will be regridded using groupby and/or linear interpolation, depending on the
-    ratio between the old and new resolution.
-
-    This regridding method is a rough approximation that will work fine in most cases,
-    but can struggle in some areas:
-      - There is no weighted averaging performed: areas near the poles will
-        incorrectly contribute more to the total average.
-      - To achieve a conservative regrid-like method, data of a resolution close to
-        the new resolution is first interpolated (to a finer resolution), and then
-        regridded to the intended resolution. This will lead to small
-        inconsistencies with NaN thresholds.
-
-    For a more robust method use the xesmf regridding option, however this does require
-    installation through conda/mamba, and is not available on Windows.
-
-    Args:
-        data: Input dataset.
-        spatial_bounds: Spatial bounds of the new grid.
-        resolution: Resolution of the new grid.
-
-    Returns:
-        Regridded input dataset
-    """
-    data_resolution = infer_resolution(data)
-    old_resolution = min(data_resolution)
-
-    # # Use Nyquist-like criterion to avoid aliasing.
-    # # At a 4x courser resolution, no issues: just groupy-regrid.
-    if resolution >= 4 * old_resolution:
-        return _groupby_regrid(data, spatial_bounds, resolution)
-
-    # At a 4x finer resolution: no issues: just interpolate.
-    if resolution <= 0.25 * old_resolution:
-        return _interp_regrid(data, spatial_bounds, resolution)
-
-    # Otherwise we first regrid to a finer grid, and then reduce:
-    else:
-        ds_in_interp = _interp_regrid(data, spatial_bounds, old_resolution / 4)
-        return _groupby_regrid(ds_in_interp, spatial_bounds, resolution)
-
-
-def regrid_data(
-    data: xr.Dataset,
-    spatial_bounds: SpatialBounds,
-    resolution: float,
-    method: str = "flox",
-) -> xr.Dataset:
-    """Regrid a dataset to a new grid.
-
-    Args:
-        data: Input dataset.
-        spatial_bounds: Spatial bounds of the new grid.
-        resolution: Resolution of the new grid.
-        method: Which routines to use to resample. Either "flox" (default) or "esmf".
-            Of these two, esmf is the more robust and accurate regridding method,
-            however it can be difficult to install.
-
-    Returns:
-        Regridded input dataset
-    """
-    if method == "flox":
-        return flox_regrid(data, spatial_bounds, resolution)
-
-    elif method == "esmf":
-        assert_xesmf_available()
-        from zampy.utils.xesmf_regrid import xesfm_regrid
-
-        return xesfm_regrid(data, spatial_bounds, resolution)
-
-    else:
-        raise ValueError(f"Unknown regridding method '{method}'")
diff --git a/src/zampy/utils/xesmf_regrid.py b/src/zampy/utils/xesmf_regrid.py
deleted file mode 100644
index 6daeffc..0000000
--- a/src/zampy/utils/xesmf_regrid.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""xesmf specific regridding implementation."""
-import numpy as np
-import xarray as xr
-import xesmf
-from zampy.datasets.dataset_protocol import SpatialBounds
-from zampy.utils import regrid
-
-
-def create_new_grid(spatial_bounds: SpatialBounds, resolution: float) -> xr.Dataset:
-    """Create a dataset describing the new grid."""
-    return xr.Dataset(
-        {
-            "latitude": (
-                ["latitude"],
-                np.arange(
-                    spatial_bounds.south,
-                    spatial_bounds.north + 0.9 * resolution,
-                    resolution,
-                ),
-            ),
-            "longitude": (
-                ["longitude"],
-                np.arange(
-                    spatial_bounds.west,
-                    spatial_bounds.east + 0.9 * resolution,
-                    resolution,
-                ),
-            ),
-        }
-    )
-
-
-def xesfm_regrid(
-    data: xr.Dataset, spatial_bounds: SpatialBounds, resolution: float
-) -> xr.Dataset:
-    """Regrid a dataset to a new grid, using the xESMF library.
-
-    Args:
-        data: Input dataset.
-        spatial_bounds: Spatial bounds of the new grid.
-        resolution: Resolution of the new grid.
-
-    Returns:
-        Regridded input dataset
-    """
-    data_resolution = regrid.infer_resolution(data)
-    old_resolution = min(data_resolution)
-    regrid_method = "bilinear" if resolution < old_resolution else "conservative"
-
-    ds_grid = create_new_grid(spatial_bounds, resolution)
-
-    regridder = xesmf.Regridder(
-        data, ds_grid, method=regrid_method, unmapped_to_nan=True
-    )
-    return regridder(
-        # Chunks need to span all of the horizontal (lat/lon) dimensions for xesmf:
-        # https://xesmf.readthedocs.io/en/latest/notebooks/Dask.html#Invalid-chunk-sizes-to-avoid
-        data.chunk({"latitude": -1, "longitude": -1}),
-        keep_attrs=True,
-        skipna=True,
-        na_thres=0.1,  # max 10% NaN values
-    )
diff --git a/tests/test_datasets/test_cams.py b/tests/test_datasets/test_cams.py
index f53e010..5afaba7 100644
--- a/tests/test_datasets/test_cams.py
+++ b/tests/test_datasets/test_cams.py
@@ -113,7 +113,6 @@ def test_load(self):
             spatial_bounds=bbox,
             variable_names=variable,
             resolution=1.0,
-            regrid_method="flox",
         )
 
         # we assert the regridded coordinates
diff --git a/tests/test_datasets/test_era5.py b/tests/test_datasets/test_era5.py
index 721c6cc..3b6cc1c 100644
--- a/tests/test_datasets/test_era5.py
+++ b/tests/test_datasets/test_era5.py
@@ -127,7 +127,6 @@ def test_load(self):
             spatial_bounds=bbox,
             variable_names=variable,
             resolution=1.0,
-            regrid_method="flox",
         )
 
         # we assert the regridded coordinates
diff --git a/tests/test_datasets/test_era5_land.py b/tests/test_datasets/test_era5_land.py
index 6be14ad..8305697 100644
--- a/tests/test_datasets/test_era5_land.py
+++ b/tests/test_datasets/test_era5_land.py
@@ -127,7 +127,6 @@ def test_load(self):
             spatial_bounds=bbox,
             variable_names=variable,
             resolution=1.0,
-            regrid_method="flox",
         )
 
         # we assert the regridded coordinates
diff --git a/tests/test_datasets/test_eth_canopy_height.py b/tests/test_datasets/test_eth_canopy_height.py
index 0beca68..70112a1 100644
--- a/tests/test_datasets/test_eth_canopy_height.py
+++ b/tests/test_datasets/test_eth_canopy_height.py
@@ -87,7 +87,6 @@ def test_load(self, dummy_dir):
             spatial_bounds=bbox,
             variable_names=variable,
             resolution=1.0,
-            regrid_method="flox",
         )
 
         # we assert the regridded coordinates
diff --git a/tests/test_datasets/test_fapar_lai.py b/tests/test_datasets/test_fapar_lai.py
index bbb0020..6d9e568 100644
--- a/tests/test_datasets/test_fapar_lai.py
+++ b/tests/test_datasets/test_fapar_lai.py
@@ -108,7 +108,6 @@ def test_load(self):
             spatial_bounds=bbox,
             variable_names=variable,
             resolution=1.0,
-            regrid_method="flox",
         )
 
         # we assert the regridded coordinates
diff --git a/tests/test_datasets/test_land_cover.py b/tests/test_datasets/test_land_cover.py
index 8b4e8c8..de986c1 100644
--- a/tests/test_datasets/test_land_cover.py
+++ b/tests/test_datasets/test_land_cover.py
@@ -106,7 +106,6 @@ def test_load(self, dummy_dir):
             spatial_bounds=bbox,
             variable_names=variable,
             resolution=1.0,
-            regrid_method="most_common",
         )
 
         # we assert the regridded coordinates
diff --git a/tests/test_datasets/test_prism_dem.py b/tests/test_datasets/test_prism_dem.py
index 7fff305..5c7d1e1 100644
--- a/tests/test_datasets/test_prism_dem.py
+++ b/tests/test_datasets/test_prism_dem.py
@@ -85,7 +85,6 @@ def test_load(self, dummy_dir):
             spatial_bounds=bbox,
             variable_names=variable,
             resolution=0.25,
-            regrid_method="flox",
         )
 
         # we assert the regridded coordinates
diff --git a/tests/test_regrid.py b/tests/test_regrid.py
deleted file mode 100644
index 0094347..0000000
--- a/tests/test_regrid.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""Unit test for regridding."""
-
-import numpy as np
-import pytest
-from test_datasets import data_folder
-from zampy.datasets.dataset_protocol import SpatialBounds
-from zampy.datasets.eth_canopy_height import parse_tiff_file
-from zampy.utils import regrid
-
-
-path_dummy_data = data_folder / "eth-canopy-height"
-XESMF_INSTALLED = True
-# Since xesmf is only supported via conda, we need these checks to support
-# tests cases running with and without conda environment in CD/CI
-try:
-    import xesmf as _  # noqa: F401 (unused import)
-except ImportError:
-    XESMF_INSTALLED = False
-
-# ruff: noqa: B018
-
-
-@pytest.mark.skipif(XESMF_INSTALLED, reason="xesmf is installed")
-def assert_xesmf_available() -> None:
-    """Test assert_xesmf_available function with exception case."""
-    with pytest.raises(ImportError, match="Could not import the `xesmf`"):
-        regrid.assert_xesmf_available()
-
-
-@pytest.fixture
-def dummy_dataset():
-    ds = parse_tiff_file(
-        path_dummy_data / "ETH_GlobalCanopyHeight_10m_2020_N51E003_Map.tif",
-    )
-    return ds
-
-
-def test_infer_resolution(dummy_dataset):
-    """Test resolution inferring function."""
-    (resolution_lat, resolution_lon) = regrid.infer_resolution(dummy_dataset)
-    expected_lat = 0.01
-    expected_lon = 0.01
-
-    assert resolution_lat == pytest.approx(expected_lat, 0.001)
-    assert resolution_lon == pytest.approx(expected_lon, 0.001)
-
-
-def test_groupby_regrid(dummy_dataset):
-    """Test groupby regrid function."""
-    bbox = SpatialBounds(54, 6, 51, 3)
-    ds = regrid._groupby_regrid(data=dummy_dataset, spatial_bounds=bbox, resolution=1.0)
-    expected_lat = [51.0, 52.0, 53.0, 54.0]
-    expected_lon = [3.0, 4.0, 5.0, 6.0]
-
-    np.testing.assert_allclose(ds.latitude.values, expected_lat)
-    np.testing.assert_allclose(ds.longitude.values, expected_lon)
-
-
-def test_interp_regrid(dummy_dataset):
-    """Test interp regrid function."""
-    bbox = SpatialBounds(51.1, 3.1, 51, 3)
-    ds = regrid._interp_regrid(data=dummy_dataset, spatial_bounds=bbox, resolution=0.05)
-    expected_lat = [51.0, 51.05, 51.1]
-    expected_lon = [3.0, 3.05, 3.1]
-
-    np.testing.assert_allclose(ds.latitude.values, expected_lat)
-    np.testing.assert_allclose(ds.longitude.values, expected_lon)
-
-
-def test_flox_regrid_coarser(dummy_dataset):
-    """Test flox regridding at a coarser resolution
-
-    Note that the native resolution is about lat/lon=0.01.
-    """
-    bbox = SpatialBounds(54, 6, 51, 3)
-    ds = regrid.flox_regrid(data=dummy_dataset, spatial_bounds=bbox, resolution=1.0)
-    expected_lat = [51.0, 52.0, 53.0, 54.0]
-    expected_lon = [3.0, 4.0, 5.0, 6.0]
-
-    np.testing.assert_allclose(ds.latitude.values, expected_lat)
-    np.testing.assert_allclose(ds.longitude.values, expected_lon)
-
-
-def test_flox_regrid_finer(dummy_dataset):
-    """Test flox regridding at a finer resolution
-
-    Note that the native resolution is about lat/lon=0.01.
-    """
-    bbox = SpatialBounds(51.02, 3.02, 51, 3)
-    ds = regrid.flox_regrid(data=dummy_dataset, spatial_bounds=bbox, resolution=0.002)
-    # only compare first 5 index
-    expected_lat = [51.0, 51.002, 51.004, 51.006, 51.008]
-    expected_lon = [3.0, 3.002, 3.004, 3.006, 3.008]
-
-    np.testing.assert_allclose(ds.latitude.values[:5], expected_lat)
-    np.testing.assert_allclose(ds.longitude.values[:5], expected_lon)
-
-
-def test_flox_regrid_close(dummy_dataset):
-    """Test flox regridding at a resolution close to native.
-
-    Note that the native resolution is about lat/lon=0.01.
- """ - bbox = SpatialBounds(51.1, 3.1, 51, 3) - ds = regrid.flox_regrid(data=dummy_dataset, spatial_bounds=bbox, resolution=0.02) - expected_lat = [51.0, 51.02, 51.04, 51.06, 51.08, 51.1] - expected_lon = [3.0, 3.02, 3.04, 3.06, 3.08, 3.1] - - np.testing.assert_allclose(ds.latitude.values, expected_lat) - np.testing.assert_allclose(ds.longitude.values, expected_lon) - - -def test_regrid_data_flox(dummy_dataset): - bbox = SpatialBounds(54, 6, 51, 3) - ds = regrid.regrid_data( - data=dummy_dataset, spatial_bounds=bbox, resolution=1.0, method="flox" - ) - expected_lat = [51.0, 52.0, 53.0, 54.0] - expected_lon = [3.0, 4.0, 5.0, 6.0] - - np.testing.assert_allclose(ds.latitude.values, expected_lat) - np.testing.assert_allclose(ds.longitude.values, expected_lon) - - -def test_regrid_data_unknown_method(dummy_dataset): - bbox = SpatialBounds(54, 6, 51, 3) - with pytest.raises(ValueError): - regrid.regrid_data( - data=dummy_dataset, - spatial_bounds=bbox, - resolution=1.0, - method="fake_method", - ) diff --git a/tests/test_xesmf_regrid.py b/tests/test_xesmf_regrid.py deleted file mode 100644 index 569340e..0000000 --- a/tests/test_xesmf_regrid.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Unit test for regridding functions with xesmf.""" - -from pathlib import Path -import numpy as np -import pytest -from zampy.datasets.dataset_protocol import SpatialBounds -from zampy.datasets.eth_canopy_height import parse_tiff_file - - -XESMF_INSTALLED = True -try: - import xesmf as _ # noqa: F401 (unused import) - from zampy.utils import xesmf_regrid -except ImportError: - XESMF_INSTALLED = False - - -@pytest.mark.skipif(not XESMF_INSTALLED, reason="xesmf is not installed") -def test_create_new_grid(): - bbox = SpatialBounds(54, 6, 51, 3) - ds = xesmf_regrid.create_new_grid(spatial_bounds=bbox, resolution=1.0) - expected_lat = [51.0, 52.0, 53.0, 54.0] - expected_lon = [3.0, 4.0, 5.0, 6.0] - - np.testing.assert_allclose(ds.latitude.values, expected_lat) - np.testing.assert_allclose(ds.longitude.values, expected_lon) - - -@pytest.mark.skipif(not XESMF_INSTALLED, reason="xesmf is not installed") -def test_xesfm_regrid(): - # load dummy dataset - path_dummy_data = ( - Path(__file__).resolve().parent / "test_data" / "eth-canopy-height" - ) - dummy_dataset = parse_tiff_file( - path_dummy_data / "ETH_GlobalCanopyHeight_10m_2020_N51E003_Map.tif", - ) - - bbox = SpatialBounds(54, 6, 51, 3) - ds = xesmf_regrid.xesfm_regrid( - data=dummy_dataset, spatial_bounds=bbox, resolution=1.0 - ) - - expected_lat = [51.0, 52.0, 53.0, 54.0] - expected_lon = [3.0, 4.0, 5.0, 6.0] - - np.testing.assert_allclose(ds.latitude.values, expected_lat) - np.testing.assert_allclose(ds.longitude.values, expected_lon)