Skip to content

Commit

Permalink
Add error message when output has no chunks with parallel weights gen…
Browse files Browse the repository at this point in the history
…eration (#304)

* Add error msg when no chunks in par=T

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix for locstreamout

* add simple test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David Huard <[email protected]>
  • Loading branch information
3 people authored Sep 18, 2023
1 parent 198d0ee commit 35a8727
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ What's new

Bug fixes
~~~~~~~~~
* Raise a meaningful error messages when the output grid has no chunks with `parallel=True` (:issue:`299`, :pull:`304`). By `Pascal Bourgault <https://github.com/aulemahal>`_.
* Correct guess of output chunks for the :``SpatialAverager``.

0.8.1 (2023-09-05)
Expand Down
5 changes: 5 additions & 0 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,11 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs):
ds_out = ds_out.set_coords(['lon_b', 'lat_b'])
if 'lon_b' in ds_in.data_vars:
ds_in = ds_in.set_coords(['lon_b', 'lat_b'])
if not (set(self.out_horiz_dims) - {'dummy'}).issubset(ds_out.chunksizes.keys()):
raise ValueError(
'Using `parallel=True` requires the output grid to have chunks along all spatial dimensions. '
'If the dataset has no variables, consider adding an all-True spatial mask with appropriate chunks.'
)
# Drop everything in ds_out except mask or create mask if None. This is to prevent map_blocks loading unnecessary large data
if self.sequence_out:
ds_out_dims_drop = set(ds_out.variables).difference(ds_out.data_vars)
Expand Down
5 changes: 5 additions & 0 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,11 @@ def test_para_weight_gen():
assert all(regridder_locs.w.data.data == para_regridder_locs.w.data.data)


def test_para_weight_gen_errors():
with pytest.raises(ValueError, match='requires the output grid to have chunks'):
xe.Regridder(ds_in, ds_out, 'conservative', parallel=True)


def test_regrid_dataset():
# xarray.Dataset containing in-memory numpy array
regridder = xe.Regridder(ds_in, ds_out, 'conservative')
Expand Down

0 comments on commit 35a8727

Please sign in to comment.