Skip to content

Commit

Permalink
Subscript-crop Raster or Vector by bracket call [], consistency…
Browse files Browse the repository at this point in the history
… and bug fixes (#335)

* Silence UserWarning for default shapefile bounds in create_mask

* Fix syntax with shapely deprecation

* Get consistent projected bounds functions, fix crop functionality, make projtools independent of georaster and geovector, and add tests

* Mirror Raster crop in Vector

* Fix syntax

* Linting

* Add bracket subscript Raster method based on crop

* Make polygonize robust to geopandas dtypes and improve tests

* Linting

* Linting

* Silence warnings

* Try to trace source of error in align_bounds

* See what is now passed in intersection

* Fix package upgrade bug by widening the condition for void intersection

* Add prints to figure out issue

* Wtf is going on

* Testing in more detail...

* Check for a tuple of nans (not numpy)

* Use math isnan

* Linting

* Incremental commit on comments

* Fix np.count_nonzero with Raster object

* Linting
  • Loading branch information
rhugonnet authored Jan 20, 2023
1 parent a43b57f commit 7e7ad86
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 88 deletions.
73 changes: 48 additions & 25 deletions geoutils/georaster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from __future__ import annotations

import math
import os
import warnings
from collections import abc
Expand Down Expand Up @@ -32,6 +33,7 @@
import geoutils.geovector as gv
from geoutils._typing import AnyNumber, ArrayLike, DTypeLike
from geoutils.geovector import Vector
from geoutils.projtools import _get_bounds_projected

# If python38 or above, Literal is builtin. Otherwise, use typing_extensions
try:
Expand Down Expand Up @@ -525,6 +527,10 @@ def __str__(self) -> str:
"""Provide string of information about Raster."""
return self.info()

def __getitem__(self, value: Raster | Vector | list[float] | tuple[float, ...]) -> Raster:
    """
    Subset the Raster object with bracket syntax, e.g. ``raster[other]``.

    Thin wrapper that forwards ``value`` to :meth:`crop` as ``cropGeom`` with ``inplace=False``;
    every other crop argument keeps its default.

    :param value: Geometry to crop to, passed unchanged to ``crop``: a Raster, a Vector, or a list/tuple of
        coordinates.

    :returns: Whatever ``crop`` returns for ``inplace=False`` (presumably a new cropped Raster, leaving self
        untouched -- confirm against ``crop``).
    """
    # inplace=False is passed explicitly because crop() defaults to inplace=True.
    return self.crop(cropGeom=value, inplace=False)

def __eq__(self, other: object) -> bool:
"""Check if a Raster masked array's data (including masked values), mask, fill_value and dtype are equal,
as well as the Raster's nodata, and georeferencing."""
Expand Down Expand Up @@ -1314,7 +1320,9 @@ def crop(
inplace: bool = True,
) -> RasterType | None:
"""
Crop the Raster to a given extent.
Crop the Raster to a given extent, or bounds of a raster or vector.
Reprojection is done on the fly if georeferenced objects have different projections.
:param cropGeom: Geometry to crop raster to, as either a Raster object, a Vector object, or a list of
coordinates. If cropGeom is a Raster, crop() will crop to the boundary of the raster as returned by
Expand All @@ -1331,8 +1339,10 @@ def crop(
"match_extent",
"match_pixel",
], "mode must be one of 'match_pixel', 'match_extent'"

if isinstance(cropGeom, (Raster, Vector)):
xmin, ymin, xmax, ymax = cropGeom.bounds
# For another Vector or Raster, we reproject the bounding box in the same CRS as self
xmin, ymin, xmax, ymax = cropGeom.get_bounds_projected(out_crs=self.crs)
elif isinstance(cropGeom, (list, tuple)):
xmin, ymin, xmax, ymax = cropGeom
else:
Expand Down Expand Up @@ -1375,13 +1385,13 @@ def crop(
def reproject(
self: RasterType,
dst_ref: RasterType | rio.io.Dataset | str | None = None,
dst_crs: CRS | str | None = None,
dst_crs: CRS | str | int | None = None,
dst_size: tuple[int, int] | None = None,
dst_bounds: dict[str, float] | rio.coords.BoundingBox | None = None,
dst_res: float | abc.Iterable[float] | None = None,
dst_nodata: int | float | list[int] | list[float] | None = None,
src_nodata: int | float | list[int] | list[float] | None = None,
dtype: np.dtype | None = None,
dst_dtype: np.dtype | None = None,
resampling: Resampling | str = Resampling.bilinear,
silent: bool = False,
n_threads: int = 0,
Expand All @@ -1399,13 +1409,15 @@ def reproject(
:param dst_ref: a reference raster. If set will use the attributes of this
raster for the output grid. Can be provided as Raster/rasterio data set or as path to the file.
:param crs: Specify the Coordinate Reference System to reproject to. If dst_ref not set, defaults to self.crs.
:param dst_crs: Specify the Coordinate Reference System or EPSG to reproject to. If dst_ref not set,
defaults to self.crs.
:param dst_size: Raster size to write to (x, y). Do not use with dst_res.
:param dst_bounds: a BoundingBox object or a dictionary containing\
left, bottom, right, top bounds in the source CRS.
:param dst_bounds: a BoundingBox object or a dictionary containing left, bottom, right, top bounds in the
source CRS.
:param dst_res: Pixel size in units of target CRS. Either 1 value or (xres, yres). Do not use with dst_size.
:param dst_nodata: nodata value of the destination. If set to None, will use the same as source, \
:param dst_nodata: nodata value of the destination. If set to None, will use the same as source,
and if source is None, will use GDAL's default.
:param dst_dtype: Set data type of output.
:param src_nodata: nodata value of the source. If set to None, will read from the metadata.
:param resampling: A rasterio Resampling method
:param silent: If True, will not print warning statements
Expand Down Expand Up @@ -1452,9 +1464,9 @@ def reproject(
dst_crs = CRS.from_user_input(dst_crs)

# Set output dtype
if dtype is None:
if dst_dtype is None:
# Warning: this will not work for multiple bands with different dtypes
dtype = self.dtypes[0]
dst_dtype = self.dtypes[0]

# Set source nodata if provided
if src_nodata is None:
Expand All @@ -1465,7 +1477,7 @@ def reproject(
if dst_nodata is None:
dst_nodata = self.nodata
if dst_nodata is None:
dst_nodata = _default_nodata(dtype)
dst_nodata = _default_nodata(dst_dtype)
# if dst_nodata is already being used, raise a warning.
# TODO: for uint8, if all values are used, apply rio.warp to mask to identify invalid values
if not self.is_loaded:
Expand Down Expand Up @@ -1538,7 +1550,7 @@ def reproject(
# Set output shape (Note: dst_size is (ncol, nrow))
if dst_size is not None:
dst_shape = (self.count, dst_size[1], dst_size[0])
dst_data = np.ones(dst_shape, dtype=dtype)
dst_data = np.ones(dst_shape, dtype=dst_dtype)
reproj_kwargs.update({"destination": dst_data})
else:
dst_shape = (self.count, self.height, self.width)
Expand All @@ -1559,7 +1571,7 @@ def reproject(

# Specify the output bounds and shape, let rasterio handle the rest
reproj_kwargs.update({"dst_transform": dst_transform})
dst_data = np.ones((dst_size[1], dst_size[0]), dtype=dtype)
dst_data = np.ones((dst_size[1], dst_size[0]), dtype=dst_dtype)
reproj_kwargs.update({"destination": dst_data})

# Check that reprojection is actually needed
Expand Down Expand Up @@ -1617,7 +1629,7 @@ def reproject(
dst_data = np.array(dst_data)

# Enforce output type
dst_data = np.ma.masked_array(dst_data.astype(dtype), fill_value=dst_nodata)
dst_data = np.ma.masked_array(dst_data.astype(dst_dtype), fill_value=dst_nodata)

if dst_nodata is not None:
dst_data.mask = dst_data == dst_nodata
Expand Down Expand Up @@ -1775,25 +1787,22 @@ def to_xarray(self, name: str | None = None) -> rioxarray.DataArray:

return xr

def get_bounds_projected(self, out_crs: CRS, densify_pts_max: int = 5000) -> rio.coords.BoundingBox:
def get_bounds_projected(self, out_crs: CRS, densify_pts: int = 5000) -> rio.coords.BoundingBox:
"""
Return self's bounds in the given CRS.
:param out_crs: Output CRS
:param densify_pts_max: Maximum points to be added between image corners to account for non linear edges.
Reduce if time computation is really critical (ms) or increase if extent is \
not accurate enough.
:param densify_pts: Maximum points to be added between image corners to account for non linear edges.
Reduce if time computation is really critical (ms) or increase if extent is not accurate enough.
"""
# Max points to be added between image corners to account for non linear edges
# rasterio's default is a bit low for very large images
# instead, use image dimensions, capped at densify_pts (5000 by default)
densify_pts = min(max(self.width, self.height), densify_pts_max)
densify_pts = min(max(self.width, self.height), densify_pts)

# Calculate new bounds
left, bottom, right, top = self.bounds
new_bounds = rio.warp.transform_bounds(self.crs, out_crs, left, bottom, right, top, densify_pts)
new_bounds = rio.coords.BoundingBox(*new_bounds)
new_bounds = _get_bounds_projected(self.bounds, in_crs=self.crs, out_crs=out_crs, densify_pts=densify_pts)

return new_bounds

Expand Down Expand Up @@ -1822,8 +1831,8 @@ def intersection(self, rst: str | Raster, match_ref: bool = True) -> tuple[float
# Calculate intersection of bounding boxes
intersection = projtools.merge_bounds([self.bounds, rst_bounds_sameproj], merging_algorithm="intersection")

# check that intersection is not void, otherwise return 0 everywhere
if intersection == ():
# Check that intersection is not void (changed to NaN instead of an empty tuple at the end of 2022)
if intersection == () or all(math.isnan(i) for i in intersection):
warnings.warn("Intersection is void")
return (0.0, 0.0, 0.0, 0.0)

Expand Down Expand Up @@ -2476,9 +2485,23 @@ def polygonize(

raise ValueError("in_value must be a number, a tuple or a sequence")

# GeoPandas.from_features() only supports certain dtypes, we find the best common dtype to optimize memory usage
# TODO: this should be a function independent of polygonize, reused in several places
gpd_dtypes = ["uint8", "uint16", "int16", "int32", "float32"]
list_common_dtype_index = []
for gpd_type in gpd_dtypes:
polygonize_dtype = np.find_common_type([gpd_type, self.dtypes[0]], [])
if str(polygonize_dtype) in gpd_dtypes:
list_common_dtype_index.append(gpd_dtypes.index(gpd_type))
if len(list_common_dtype_index) == 0:
final_dtype = "float32"
else:
final_dtype_index = min(list_common_dtype_index)
final_dtype = gpd_dtypes[final_dtype_index]

results = (
{"properties": {"raster_value": v}, "geometry": s}
for i, (s, v) in enumerate(shapes(self.data, mask=bool_msk, transform=self.transform))
for i, (s, v) in enumerate(shapes(self.data.astype(final_dtype), mask=bool_msk, transform=self.transform))
)

gdf = gpd.GeoDataFrame.from_features(list(results))
Expand Down
Loading

0 comments on commit 7e7ad86

Please sign in to comment.