From d0168ffc0817289b53f039f4aa8dd87c7a7d2473 Mon Sep 17 00:00:00 2001 From: BRAUN REMI Date: Mon, 23 Dec 2024 14:49:30 +0100 Subject: [PATCH] FIX: Fix `rasters.sieve` function with `xr.apply_ufunc` --- CHANGES.md | 1 + ci/test_rasters.py | 6 +++++- sertit/rasters.py | 9 +++++---- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index b556ddc..df0c46b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,6 +10,7 @@ - FIX: Use `np.tan` in `rasters.slope` - FIX: Allow str as paths in `ci.assert_files_equal` - FIX: Better alignement between `rasters.read` function and `rasters.any_raster_to_xr_ds` decorator +- FIX: Fix `rasters.sieve` function with `xr.apply_ufunc` - OPTIM: Compute the spatial index by default in `vectors.read` (set `vectors.read(..., compute_sindex=False)` if you don't want to compute them) - CI: Rename CI folder and remove unnecessary intermediate folder diff --git a/ci/test_rasters.py b/ci/test_rasters.py index af6d078..d09addd 100644 --- a/ci/test_rasters.py +++ b/ci/test_rasters.py @@ -316,7 +316,7 @@ def test_crop(tmp_path, xda, xds, xda_dask, mask): @s3_env @dask_env -def test_sieve(tmp_path, xda, xds, xda_dask): +def test_sieve(tmp_path, raster_path, xda, xds, xda_dask): """Test sieve function""" # DataArray xda_sieved = os.path.join(tmp_path, "test_sieved_xda.tif") @@ -350,6 +350,10 @@ def test_sieve(tmp_path, xda, xds, xda_dask): ci.assert_raster_equal(xda_sieved, raster_sieved_path) ci.assert_raster_equal(xds_sieved, raster_sieved_path) + # From path + sieve_xda_path = rasters.sieve(raster_path, sieve_thresh=20, connectivity=4) + np.testing.assert_array_equal(sieve_xda, sieve_xda_path) + @s3_env @dask_env diff --git a/sertit/rasters.py b/sertit/rasters.py index 8e1ec6b..09482e3 100644 --- a/sertit/rasters.py +++ b/sertit/rasters.py @@ -1380,14 +1380,15 @@ def sieve( assert connectivity in [4, 8] - # Use this trick to make the sieve work - mask = np.where(np.isnan(xds.data), 0, 1).astype(np.uint8) - data = xds.data.astype(np.uint8) + mask = xr.where(np.isnan(xds), 0, 1).astype(np.uint8).data + data = xds.astype(np.uint8).data # Sieve try: sieved_arr = xr.apply_ufunc( - features.sieve, data, sieve_thresh, connectivity, mask + features.sieve, + data, + kwargs={"size": sieve_thresh, "connectivity": connectivity, "mask": mask}, ) except ValueError: sieved_arr = features.sieve(