Skip to content

Commit

Permalink
FIX: Fix rasters.sieve function with xr.apply_ufunc
Browse files Browse the repository at this point in the history
  • Loading branch information
remi-braun committed Dec 23, 2024
1 parent 41acbb1 commit d0168ff
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- FIX: Use `np.tan` in `rasters.slope`
- FIX: Allow str as paths in `ci.assert_files_equal`
- FIX: Better alignement between `rasters.read` function and `rasters.any_raster_to_xr_ds` decorator
- FIX: Fix `rasters.sieve` function with `xr.apply_ufunc`
- OPTIM: Compute the spatial index by default in `vectors.read` (set `vectors.read(..., compute_sindex=False)` if you don't want to compute them)
- CI: Rename CI folder and remove unnecessary intermediate folder

Expand Down
6 changes: 5 additions & 1 deletion ci/test_rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_crop(tmp_path, xda, xds, xda_dask, mask):

@s3_env
@dask_env
def test_sieve(tmp_path, xda, xds, xda_dask):
def test_sieve(tmp_path, raster_path, xda, xds, xda_dask):
"""Test sieve function"""
# DataArray
xda_sieved = os.path.join(tmp_path, "test_sieved_xda.tif")
Expand Down Expand Up @@ -350,6 +350,10 @@ def test_sieve(tmp_path, xda, xds, xda_dask):
ci.assert_raster_equal(xda_sieved, raster_sieved_path)
ci.assert_raster_equal(xds_sieved, raster_sieved_path)

# From path
sieve_xda_path = rasters.sieve(raster_path, sieve_thresh=20, connectivity=4)
np.testing.assert_array_equal(sieve_xda, sieve_xda_path)


@s3_env
@dask_env
Expand Down
9 changes: 5 additions & 4 deletions sertit/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,14 +1380,15 @@ def sieve(

assert connectivity in [4, 8]

# Use this trick to make the sieve work
mask = np.where(np.isnan(xds.data), 0, 1).astype(np.uint8)
data = xds.data.astype(np.uint8)
mask = xr.where(np.isnan(xds), 0, 1).astype(np.uint8).data
data = xds.astype(np.uint8).data

# Sieve
try:
sieved_arr = xr.apply_ufunc(
features.sieve, data, sieve_thresh, connectivity, mask
features.sieve,
data,
kwargs={"size": sieve_thresh, "connectivity": connectivity, "mask": mask},
)
except ValueError:
sieved_arr = features.sieve(
Expand Down

0 comments on commit d0168ff

Please sign in to comment.