Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Raster.from_xarray() to create raster from a xr.DataArray #521

Merged
merged 10 commits into from
Mar 19, 2024
97 changes: 81 additions & 16 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,13 +969,17 @@ def __setitem__(self, index: Mask | NDArrayBool | Any, assign: NDArrayNum | Numb
self._data[:, ind] = assign # type: ignore
return None

def raster_equal(self, other: RasterType) -> bool:
def raster_equal(self, other: RasterType, strict_masked: bool = True, warn_failure_reason: bool = False) -> bool:
"""
Check if two rasters are equal.

This means that the following are equal:
- The raster's masked array's data (including masked values), mask, fill_value and dtype,
- The raster's transform, crs and nodata values.

:param other: Other raster.
:param strict_masked: Whether to check if masked cells (in .data.mask) have the same value (in .data.data).
:param warn_failure_reason: Whether to warn for the reason of failure if the check does not pass.
"""

# If the mask is just "False", it is equivalent to being equal to an array of False
Expand All @@ -991,8 +995,10 @@ def raster_equal(self, other: RasterType) -> bool:

if not isinstance(other, Raster): # TODO: Possibly add equals to SatelliteImage?
raise NotImplementedError("Equality with other object than Raster not supported by raster_equal.")
return all(
[

if strict_masked:
names = ["data.data", "data.mask", "data.fill_value", "dtype", "transform", "crs", "nodata"]
equalities = [
np.array_equal(self.data.data, other.data.data, equal_nan=True),
np.array_equal(self_mask, other_mask),
self.data.fill_value == other.data.fill_value,
Expand All @@ -1001,7 +1007,26 @@ def raster_equal(self, other: RasterType) -> bool:
self.crs == other.crs,
self.nodata == other.nodata,
]
)
else:
names = ["data", "data.fill_value", "dtype", "transform", "crs", "nodata"]
equalities = [
np.ma.allequal(self.data, other.data),
self.data.fill_value == other.data.fill_value,
self.data.dtype == other.data.dtype,
self.transform == other.transform,
self.crs == other.crs,
self.nodata == other.nodata,
]

complete_equality = all(equalities)

if not complete_equality and warn_failure_reason:
where_fail = np.nonzero(~np.array(equalities))[0]
warnings.warn(
category=UserWarning, message=f"Equality failed for: {', '.join([names[w] for w in where_fail])}."
)

return complete_equality

def _overloading_check(
self: RasterType, other: RasterType | NDArrayNum | Number
Expand Down Expand Up @@ -1336,18 +1361,24 @@ def __ge__(self: RasterType, other: RasterType | NDArrayNum | Number) -> RasterT
return out_mask

@overload
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[False] = False) -> Raster:
def astype(
self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[False] = False
) -> RasterType:
...

@overload
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[True]) -> None:
def astype(self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: Literal[True]) -> None:
...

@overload
def astype(self, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: bool = False) -> Raster | None:
def astype(
self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, *, inplace: bool = False
) -> RasterType | None:
...

def astype(self, dtype: DTypeLike, convert_nodata: bool = True, inplace: bool = False) -> Raster | None:
def astype(
self: RasterType, dtype: DTypeLike, convert_nodata: bool = True, inplace: bool = False
) -> RasterType | None:
"""
Convert data type of the raster.

Expand Down Expand Up @@ -1523,6 +1554,7 @@ def set_nodata(

# Update the nodata value
self._nodata = new_nodata
self.data.fill_value = new_nodata
Copy link
Member Author

@rhugonnet rhugonnet Mar 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This little addition was necessary to make raster_equal (which checks that the fill_values of the two rasters' masked arrays are equal) happy when set_nodata is called!


@property
def data(self) -> MArrayNum:
Expand Down Expand Up @@ -2629,22 +2661,55 @@ def save(

dst.gcps = (rio_gcps, gcps_crs)

@classmethod
def from_xarray(cls: type[RasterType], ds: xr.DataArray, dtype: DTypeLike | None = None) -> RasterType:
    """
    Build a raster from the contents of a xarray.DataArray.

    The xarray.DataArray is loaded in memory by this conversion. To avoid this behaviour, use the
    functions of the Xarray accessor directly.

    :param ds: Data array.
    :param dtype: Data type to cast the array to (no cast is done if None).

    :return: Raster.
    """

    # Georeferencing attributes are read through the "rio" accessor of the data array
    out_crs = ds.rio.crs
    out_transform = ds.rio.transform(recalc=True)
    out_nodata = ds.rio.nodata

    # TODO: Add tags and area_or_point with PR #509
    new_raster = cls.from_array(data=ds.data, transform=out_transform, crs=out_crs, nodata=out_nodata)

    # Only cast when a dtype was explicitly requested
    return new_raster if dtype is None else new_raster.astype(dtype)

def to_xarray(self, name: str | None = None) -> xr.DataArray:
"""
Convert raster to a xarray.DataArray.

This method uses rioxarray to generate a DataArray with associated
geo-referencing information.
This converts integer-type rasters into float32.

See the documentation of rioxarray and xarray for more information on
the methods and attributes of the resulting DataArray.
:param name: Name attribute for the data array.

:param name: Name attribute for the DataArray.

:returns: xarray DataArray
:returns: Data array.
"""

ds = rioxarray.open_rasterio(self.to_rio_dataset())
# If type was integer, cast to float to be able to save nodata values in the xarray data array
if np.issubdtype(self.dtypes[0], np.integer):
# Nodata conversion is not needed in this direction (integer towards float), we can maintain the original
updated_raster = self.astype(np.float32, convert_nodata=False)
else:
updated_raster = self

ds = rioxarray.open_rasterio(updated_raster.to_rio_dataset(), masked=True)
# When reading as masked, the nodata is not written to the dataset so we do it manually
ds.rio.set_nodata(self.nodata)

if name is not None:
ds.name = name

Expand Down
52 changes: 48 additions & 4 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def test_to_rio_dataset(self, example: str):
def test_to_xarray(self, example: str):
"""Test the export to a xarray dataset"""

# Open raster and export to rio dataset
# Open raster and export to xarray dataset
rst = gu.Raster(example)
ds = rst.to_xarray()

Expand Down Expand Up @@ -391,9 +391,32 @@ def test_to_xarray(self, example: str):

# Check that the arrays are equal in NaN type
if rst.count > 1:
assert np.array_equal(rst.data.data, ds.data)
assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True)
else:
assert np.array_equal(rst.data.data, ds.data.squeeze())
assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True)

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore
def test_from_xarray(self, example: str):
"""Test raster creation from a xarray dataset, not fully reversible with to_xarray due to float conversion"""

# Open raster, export it to a xarray dataset, then import that dataset back as a raster
rst = gu.Raster(example)
ds = rst.to_xarray()
rst2 = gu.Raster.from_xarray(ds=ds)

# Exporting to a xarray dataset casts integer-type rasters to float32, so the original dtype is lost
# Check that the output equals the input converted to float32 (not fully reversible)
# (strict_masked=False: the values stored under masked cells are not compared)
assert rst.astype("float32", convert_nodata=False).raster_equal(rst2, strict_masked=False)

# Test with the dtype argument to convert back to the original raster even if it is integer-type
if np.issubdtype(rst.dtypes[0], np.integer):
# Set an existing nodata value, because all of our integer-type example datasets currently have "None"
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="New nodata value cells already exist.*")
rst.set_nodata(new_nodata=255)
ds = rst.to_xarray()
rst3 = gu.Raster.from_xarray(ds=ds, dtype=rst.dtypes[0])
assert rst3.raster_equal(rst, strict_masked=False)

@pytest.mark.parametrize("nodata_init", [None, "type_default"]) # type: ignore
@pytest.mark.parametrize(
Expand Down Expand Up @@ -2189,6 +2212,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# By default, the array should have been updated
if old_nodata is not None:
Expand Down Expand Up @@ -2227,6 +2251,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# By default, the array should have been updated similarly for the old nodata
if old_nodata is not None:
Expand Down Expand Up @@ -2269,6 +2294,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# Now, the array should not have been updated, so the entire array should be unchanged except for the pixel
assert np.array_equal(r.data.data[~mask_pixel_artificially_set], r_copy.data.data[~mask_pixel_artificially_set])
Expand Down Expand Up @@ -2297,6 +2323,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# The array should have been updated
if old_nodata is not None:
Expand All @@ -2323,6 +2350,7 @@ def test_set_nodata(self, example: str) -> None:

# The nodata value should have been set in the metadata
assert r.nodata == new_nodata
assert r.data.fill_value == new_nodata

# The array should not have been updated except for the pixel
assert np.array_equal(r.data.data[~mask_pixel_artificially_set], r_copy.data.data[~mask_pixel_artificially_set])
Expand Down Expand Up @@ -3204,7 +3232,7 @@ def test_reproject(self, mask: gu.Mask) -> None:
match="Reprojecting a mask with a resampling method other than 'nearest', "
"the boolean array will be converted to float during interpolation.",
):
mask.reproject(resampling="bilinear")
mask.reproject(res=50, resampling="bilinear", force_source_nodata=2)
rhugonnet marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore
def test_crop(self, mask: gu.Mask) -> None:
Expand Down Expand Up @@ -3437,6 +3465,22 @@ def test_raster_equal(self) -> None:
r2.set_nodata(34)
assert not r1.raster_equal(r2)

# Change value of a masked cell
r2 = r1.copy()
r2.data[0, 0] = np.ma.masked
r2.data.data[0, 0] = 0
r3 = r2.copy()
r3.data.data[0, 0] = 10
assert not r2.raster_equal(r3)
assert r2.raster_equal(r3, strict_masked=False)

# Check that a warning with useful information is raised when the equality check fails
with pytest.warns(UserWarning, match="Equality failed for: data.data."):
assert not r2.raster_equal(r3, warn_failure_reason=True)

# But no warning is raised for an equality
assert r2.raster_equal(r3, strict_masked=False, warn_failure_reason=True)

def test_equal_georeferenced_grid(self) -> None:
"""
Test that equal for shape, crs and transform work as expected
Expand Down
Loading