Skip to content

Commit

Permalink
Fix bug for raster with nodata equal to 0 in stack_rasters (#609)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet authored Sep 30, 2024
1 parent 1289d3f commit d80304d
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 10 deletions.
12 changes: 8 additions & 4 deletions geoutils/raster/multiraster.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,26 +165,30 @@ def stack_rasters(
return_rio_bbox=True,
)

# Make a data list and add all of the reprojected rasters into it.
# Make a data list and add all the reprojected rasters into it.
data: list[NDArrayNum] = []

for raster in tqdm(rasters, disable=not progress):
# Check that data is loaded, otherwise temporarily load it
if not raster.is_loaded:
raster.load()

# Use the nodata of the reference raster whenever it is defined (0 is a valid nodata value, so truthiness
# cannot be used here), and only fall back to the default nodata of the reference dtype when it is None.
# The condition has to test the nodata value itself: a bare "if not None" is always true.
nodata = (
    reference_raster.nodata
    if reference_raster.nodata is not None
    else gu.raster.raster._default_nodata(reference_raster.data.dtype)
)
# Reproject to reference grid
reprojected_raster = raster.reproject(
bounds=dst_bounds,
res=reference_raster.res,
crs=reference_raster.crs,
dtype=reference_raster.data.dtype,
nodata=reference_raster.nodata,
nodata=nodata,
resampling=resampling_method,
silent=True,
)
reprojected_raster.set_nodata(nodata)
# If the georeferenced grid was already the same, reproject() will have returned self with a warning (silenced
# here), so copy the raster before modifying its nodata; otherwise the inputs of this function would be modified.
if reprojected_raster.georeferenced_grid_equal(raster):
reprojected_raster = reprojected_raster.copy()
reprojected_raster.set_nodata(nodata)

# Optionally calculate difference
if diff:
Expand Down
96 changes: 90 additions & 6 deletions tests/test_raster/test_multiraster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
import geoutils as gu
from geoutils import examples
from geoutils.raster import RasterType
from geoutils.raster.raster import _default_nodata


class StackMergeImages:
class RealImageStack:
"""
Test cases for stacking and merging images
Real test cases for stacking and merging images
Split an image with some overlap, then stack/merge it, and validate bounds and shape.
Param `cls` is used to set the type of the output, e.g. gu.Raster (default).
"""
Expand Down Expand Up @@ -67,24 +68,93 @@ def __init__(
)


class SyntheticImageStack:
    """
    Small synthetic set of overlapping rasters for stacking/merging tests.

    A 10x10 uint16 raster filled with ones (no masked cell) is created in EPSG:3857 with the requested nodata value
    and stored as `img`. It is then cropped into pieces stored as `img1`, `img2` and `img3`. Param `nodata` sets the
    nodata (and masked-array fill value) of the raster, and param `img2_value` is the value written into part of
    `img2`.
    """

    def __init__(self, nodata: int | float, img2_value: int | float):

        # Base array: all ones, with a mask that is False everywhere
        grid_shape = (10, 10)
        values = np.ones(grid_shape, dtype=np.uint16)
        no_mask = np.zeros(grid_shape, dtype=bool)
        masked_values = np.ma.masked_array(data=values, mask=no_mask, fill_value=nodata)

        # Georeferencing: 1000-unit pixels with the upper-left corner at (1_000_000, 1_000_000)
        transform = rio.transform.Affine(1000.0, 0.0, 1_000_000.0, 0.0, -1000.0, 1_000_000.0)
        crs = pyproj.CRS.from_string("EPSG:3857")

        img = gu.Raster.from_array(data=masked_values, transform=transform, crs=crs, nodata=nodata)
        self.img = img

        # Easting midpoint of the full image, snapped onto the pixel grid
        x_midpoint = np.mean([img.bounds.right, img.bounds.left])
        x_midpoint = x_midpoint - (x_midpoint - img.bounds.left) % img.res[0]

        # Width of the margin (3 pixels) kept on each side of the midpoint so that the halves overlap
        overlap = img.res[0] * 3

        # First image: from the left edge to slightly past the midpoint
        bounds_img1 = rio.coords.BoundingBox(
            left=img.bounds.left,
            bottom=img.bounds.bottom,
            right=x_midpoint + overlap,
            top=img.bounds.top,
        )
        self.img1 = img.copy()
        self.img1.crop(bounds_img1, inplace=True)

        # Second image: from slightly before the midpoint to the right edge
        bounds_img2 = rio.coords.BoundingBox(
            left=x_midpoint - overlap,
            bottom=img.bounds.bottom,
            right=img.bounds.right,
            top=img.bounds.top,
        )
        self.img2 = img.copy()
        self.img2.crop(bounds_img2, inplace=True)

        # Overwrite the [:5, :5] block of the second image with the user-defined value
        self.img2[:5, :5] = img2_value

        # Third image: a copy of the first image cropped to bounds that stop 2 pixels before the right edge
        bounds_img3 = rio.coords.BoundingBox(
            left=x_midpoint - overlap,
            bottom=img.bounds.bottom,
            right=img.bounds.right - img.res[0] * 2,
            top=img.bounds.top,
        )
        self.img3 = self.img1.copy()
        self.img3.crop(bounds_img3, inplace=True)


@pytest.fixture
def images_1d():  # type: ignore
    """Real image stack built from the "everest_landsat_b4" example."""
    return RealImageStack("everest_landsat_b4")


@pytest.fixture
def images_different_crs():  # type: ignore
    """Real image stack built from the "everest_landsat_b4" example, requested with `different_crs=4326`."""
    return RealImageStack("everest_landsat_b4", different_crs=4326)


@pytest.fixture
def sat_images():  # type: ignore
    """Real image stack built from the "everest_landsat_b4" example, using gu.SatelliteImage as output class."""
    return RealImageStack("everest_landsat_b4", cls=gu.SatelliteImage)


@pytest.fixture
def images_3d():  # type: ignore
    """Real image stack built from the "everest_landsat_rgb" example."""
    return RealImageStack("everest_landsat_rgb")


@pytest.fixture
def images_nodata_zero():  # type: ignore
    """Synthetic image stack whose nodata value is 0, with 65534 written into part of the second image."""
    synthetic_stack = SyntheticImageStack(nodata=0, img2_value=65534)
    return synthetic_stack


class TestMultiRaster:
Expand All @@ -95,6 +165,7 @@ class TestMultiRaster:
pytest.lazy_fixture("sat_images"),
pytest.lazy_fixture("images_different_crs"),
pytest.lazy_fixture("images_3d"),
pytest.lazy_fixture("images_nodata_zero"),
],
) # type: ignore
def test_stack_rasters(self, rasters) -> None: # type: ignore
Expand All @@ -105,6 +176,7 @@ def test_stack_rasters(self, rasters) -> None: # type: ignore
"ignore", category=UserWarning, message="New nodata value cells already exist in the data array.*"
)
warnings.filterwarnings("ignore", category=UserWarning, message="For reprojection, nodata must be set.*")
warnings.filterwarnings("ignore", category=UserWarning, message="Unmasked values equal to*")

# Merge the two overlapping DEMs and check that output bounds and shape is correct
if rasters.img1.count > 1:
Expand Down Expand Up @@ -139,6 +211,7 @@ def test_stack_rasters(self, rasters) -> None: # type: ignore
)
assert merged_bounds == stacked_img.bounds

nodata_ref = rasters.img1.nodata
# Check that reference works with input Raster
stacked_img = gu.raster.stack_rasters([rasters.img1, rasters.img2], reference=rasters.img, use_ref_bounds=True)
assert rasters.img.bounds == stacked_img.bounds
Expand Down Expand Up @@ -170,6 +243,16 @@ def test_stack_rasters(self, rasters) -> None: # type: ignore
stacked_img = gu.raster.stack_rasters([rasters.img1, rasters.img2], resampling_method="bilinear")
assert not np.array_equal(np.unique(stacked_img.data.compressed()), np.array([1, 5]))

# Check input nodata is not modified inplace (issue 609)
new_nodata_ref = rasters.img1.nodata
assert nodata_ref == new_nodata_ref

# Check nodata value output is consistent with reference input
if nodata_ref is not None:
assert stacked_img.nodata == nodata_ref
else:
assert stacked_img.nodata == _default_nodata(rasters.img1.dtype)

@pytest.mark.parametrize(
"rasters",
[
Expand All @@ -187,6 +270,7 @@ def test_merge_rasters(self, rasters) -> None: # type: ignore
"ignore", category=UserWarning, message="New nodata value cells already exist in the data array.*"
)
warnings.filterwarnings("ignore", category=UserWarning, message="For reprojection, nodata must be set.*")
warnings.filterwarnings("ignore", category=UserWarning, message="Unmasked values equal to*")

# Ignore warning already checked in test_stack_rasters
if rasters.img1.count > 1:
Expand Down

0 comments on commit d80304d

Please sign in to comment.