Skip to content

Commit

Permalink
feat: add a roi as a parameter in Raster
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn committed Dec 17, 2024
1 parent 04b170a commit 63ac935
Showing 1 changed file with 63 additions and 7 deletions.
70 changes: 63 additions & 7 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _load_rio(
transform: Affine | None = None,
shape: tuple[int, int] | None = None,
out_count: int | None = None,
roi: dict[str, int] | None = None,
**kwargs: Any,
) -> MArrayNum:
r"""
Expand All @@ -185,6 +186,7 @@ def _load_rio(
:param transform: Create a window from the given transform (to read only parts of the raster)
:param shape: Expected shape of the read ndarray. Must be given together with the `transform` argument.
:param out_count: Specify the count for a downsampled version (to be used with kwargs out_shape).
:param roi: Optional pixel-based region of interest (dict with keys: left, bottom, right, top).
:raises ValueError: If only one of ``transform`` and ``shape`` are given.
Expand All @@ -197,9 +199,19 @@ def _load_rio(
* window : to load a cropped version
* resampling : to set the resampling algorithm
"""
window = None

# If a roi is passed, set up the corresponding window
if roi is not None:
window = rio.windows.Window(
col_off=roi["left"],
row_off=dataset.height - roi["top"],
width=roi["right"] - roi["left"],
height=roi["top"] - roi["bottom"],
)

# If out_shape is passed, no need to account for transform and shape
if kwargs["out_shape"] is not None:
window = None
elif kwargs.get("out_shape") is not None:
# If multi-band raster, the out_shape needs to contain the count
if out_count is not None and out_count > 1:
kwargs["out_shape"] = (out_count, *kwargs["out_shape"])
Expand All @@ -213,8 +225,6 @@ def _load_rio(
window = rio.windows.Window(col_off, row_off, *shape[::-1])
elif sum(param is None for param in [shape, transform]) == 1:
raise ValueError("If 'shape' or 'transform' is provided, BOTH must be given.")
else:
window = None

if indexes is None:
if only_mask:
Expand Down Expand Up @@ -356,6 +366,7 @@ def __init__(
silent: bool = True,
downsample: Number = 1,
nodata: int | float | None = None,
roi: dict[str, int] | None = None,
) -> None:
"""
Instantiate a raster from a filename or rasterio dataset.
Expand All @@ -367,6 +378,7 @@ def __init__(
:param silent: Whether to parse metadata silently or with console output.
:param downsample: Downsample the array once loaded by a round factor. Default is no downsampling.
:param nodata: Nodata value to be used (overwrites the metadata). Default reads from metadata.
:param roi: Optional pixel-based region of interest (dict with keys: left, bottom, right, top).
"""
self._driver: str | None = None
self._name: str | None = None
Expand All @@ -377,6 +389,7 @@ def __init__(
self._transform: affine.Affine | None = None
self._crs: CRS | None = None
self._nodata: int | float | None = nodata
self._roi: dict[str, int] | None = roi
self._bands = bands
self._bands_loaded: int | tuple[int, ...] | None = None
self._masked = True
Expand All @@ -391,6 +404,14 @@ def __init__(
self._downsample: int | float = 1
self._area_or_point: Literal["Area", "Point"] | None = None

# Check that roi contains the needed keys
if roi is not None:
required_roi_keys = {"left", "bottom", "right", "top"}
if not required_roi_keys.issubset(roi.keys()):
raise ValueError(
"The roi parameter must contain the following keys : 'left', 'right', 'top' and 'bottom'"
)

# This is for Raster.from_array to work.
if isinstance(filename_or_dataset, dict):

Expand All @@ -412,7 +433,7 @@ def __init__(
self.crs: rio.crs.CRS = filename_or_dataset["crs"]

for key in filename_or_dataset:
if key in ["data", "transform", "crs", "nodata", "area_or_point", "tags"]:
if key in ["data", "transform", "crs", "nodata", "area_or_point", "tags", "roi"]:
continue
setattr(self, key, filename_or_dataset[key])
return
Expand All @@ -421,7 +442,17 @@ def __init__(
if isinstance(filename_or_dataset, Raster):
for key in filename_or_dataset.__dict__:
setattr(self, key, filename_or_dataset.__dict__[key])
return

# if a roi is passed, crop the raster
if roi is not None:
crop_extent = [
self.transform.c + roi["left"] * self.transform.a, # xmin
self.transform.f + (self.height - roi["bottom"]) * self.transform.e, # ymin
self.transform.c + roi["right"] * self.transform.a, # xmax
self.transform.f + (self.height - roi["top"]) * self.transform.e, # ymax
]
self.crop(crop_extent, inplace=True)

# Image is a file on disk.
elif isinstance(filename_or_dataset, (str, pathlib.Path, rio.io.DatasetReader, rio.io.MemoryFile)):
# ExitStack is used instead of "with rio.open(filename_or_dataset) as ds:".
Expand Down Expand Up @@ -472,7 +503,16 @@ def __init__(
# Downsampled image size
if not isinstance(downsample, (int, float)):
raise TypeError("downsample must be of type int or float.")
if downsample == 1:
if self._roi is not None:
new_transform = rio.transform.from_origin(
self.transform.c + self._roi["left"] * self.transform.a,
self.transform.f + (self.height - self._roi["top"]) * self.transform.e,
self.transform.a,
-self.transform.e,
)
self.transform = new_transform
out_shape = (self._roi["top"] - self._roi["bottom"], self._roi["right"] - self._roi["left"])
elif downsample == 1:
out_shape = (self.height, self.width)
else:
down_width = int(np.ceil(self.width / downsample))
Expand All @@ -495,6 +535,7 @@ def __init__(
masked=self._masked,
out_shape=out_shape,
out_count=count,
roi=self._roi,
) # type: ignore

# Probably don't want to use set_nodata that can update array, setting self._nodata is sufficient
Expand Down Expand Up @@ -791,6 +832,7 @@ def _load_only_mask(self, bands: int | list[int] | None = None, **kwargs: Any) -
shape=self.shape,
out_shape=self._out_shape,
out_count=out_count,
roi=self._roi,
**kwargs,
)

Expand Down Expand Up @@ -847,6 +889,7 @@ def load(self, bands: int | list[int] | None = None, **kwargs: Any) -> None:
shape=self.shape,
out_shape=self._out_shape,
out_count=self._out_count,
roi=self._roi,
**kwargs,
)

Expand All @@ -869,6 +912,7 @@ def from_array(
area_or_point: Literal["Area", "Point"] | None = None,
tags: dict[str, Any] = None,
cast_nodata: bool = True,
roi: dict[str, int] = None,
) -> RasterType:
"""Create a raster from a numpy array and the georeferencing information.
Expand All @@ -883,6 +927,7 @@ def from_array(
:param tags: Metadata stored in a dictionary.
:param cast_nodata: Automatically cast nodata value to the default nodata for the new array type if not
compatible. If False, will raise an error when incompatible.
:param roi: Optional dictionary specifying the region of interest (ROI) with keys: left, bottom, right, top.
:returns: Raster created from the provided array and georeferencing.
Expand All @@ -903,6 +948,17 @@ def from_array(
if cast_nodata:
nodata = _cast_nodata(data.dtype, nodata)

if roi is not None:
height = data.shape[-2]
new_transform = rio.transform.from_origin(
transform[2] + roi["left"] * transform[0],
transform[5] + (height - roi["top"]) * transform[4],
transform[0],
-transform[4],
)
transform = new_transform
data = data[..., height - roi["top"] : height - roi["bottom"], roi["left"] : roi["right"]]

# If the data was transformed into boolean, re-initialize as a Mask subclass
# Typing: we can specify this behaviour in @overload once we add the NumPy plugin of MyPy
if data.dtype == bool:
Expand Down

0 comments on commit 63ac935

Please sign in to comment.