From 35a872743d93205239bd00e84e54e85f5c393f0e Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Mon, 18 Sep 2023 14:35:47 -0400 Subject: [PATCH] Add error message when output has no chunks with parallel weights generation (#304) * Add error msg when no chunks in par=T * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix for locstreamout * add simple test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: David Huard --- CHANGES.rst | 1 + xesmf/frontend.py | 5 +++++ xesmf/tests/test_frontend.py | 5 +++++ 3 files changed, 11 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index a2ed8dba..fa1169a8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,7 @@ What's new Bug fixes ~~~~~~~~~ +* Raise a meaningful error message when the output grid has no chunks with `parallel=True` (:issue:`299`, :pull:`304`). By `Pascal Bourgault `_. * Correct guess of output chunks for the :``SpatialAverager``. 0.8.1 (2023-09-05) diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 2b79f59f..267ed3d9 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -963,6 +963,11 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs): ds_out = ds_out.set_coords(['lon_b', 'lat_b']) if 'lon_b' in ds_in.data_vars: ds_in = ds_in.set_coords(['lon_b', 'lat_b']) + if not (set(self.out_horiz_dims) - {'dummy'}).issubset(ds_out.chunksizes.keys()): + raise ValueError( + 'Using `parallel=True` requires the output grid to have chunks along all spatial dimensions. ' + 'If the dataset has no variables, consider adding an all-True spatial mask with appropriate chunks.' + ) # Drop everything in ds_out except mask or create mask if None. 
This is to prevent map_blocks loading unnecessary large data if self.sequence_out: ds_out_dims_drop = set(ds_out.variables).difference(ds_out.data_vars) diff --git a/xesmf/tests/test_frontend.py b/xesmf/tests/test_frontend.py index 0246574b..0d0b9d81 100644 --- a/xesmf/tests/test_frontend.py +++ b/xesmf/tests/test_frontend.py @@ -649,6 +649,11 @@ def test_para_weight_gen(): assert all(regridder_locs.w.data.data == para_regridder_locs.w.data.data) +def test_para_weight_gen_errors(): + with pytest.raises(ValueError, match='requires the output grid to have chunks'): + xe.Regridder(ds_in, ds_out, 'conservative', parallel=True) + + def test_regrid_dataset(): # xarray.Dataset containing in-memory numpy array regridder = xe.Regridder(ds_in, ds_out, 'conservative')