From 2311ee6113006429e4a196276a1a2e7aeb5ab0ad Mon Sep 17 00:00:00 2001 From: vschaffn Date: Tue, 3 Dec 2024 18:45:31 +0100 Subject: [PATCH] fix: add callable type in get_stats --- geoutils/raster/raster.py | 27 +++++++++++++++++++++------ tests/test_raster/test_raster.py | 7 ++++++- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index 4f799c86..4eb9c2e7 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -1871,7 +1871,7 @@ def set_mask(self, mask: NDArrayBool | Mask) -> None: else: self.data[mask_arr > 0] = np.ma.masked - def _statistics(self, band: int = 1) -> dict[str, float]: + def _statistics(self, band: int = 1) -> dict[str, np.floating[Any]]: """ Calculate common statistics for a specified band in the raster. @@ -1885,6 +1885,10 @@ def _statistics(self, band: int = 1) -> dict[str, float]: else: data = self.data[band - 1] + # If data is a MaskedArray, use the compressed version (without masked values) + if isinstance(data, np.ma.MaskedArray): + data = data.compressed() + # Compute the statistics stats_dict = { "Mean": np.nanmean(data), @@ -1901,8 +1905,12 @@ def _statistics(self, band: int = 1) -> dict[str, float]: return stats_dict def get_stats( - self, stats_name: str | list[str | Callable[[NDArrayNum], float]] | None = None, band: int = 1 - ) -> float | dict[str, float]: + self, + stats_name: ( + str | Callable[[NDArrayNum], np.floating[Any]] | list[str | Callable[[NDArrayNum], np.floating[Any]]] | None + ) = None, + band: int = 1, + ) -> np.floating[Any] | dict[str, np.floating[Any]]: """ Retrieve specified statistics or all available statistics for the raster data. Allows passing custom callables to calculate custom stats. @@ -1944,10 +1952,15 @@ def get_stats( result[name] = self._get_single_stat(stats_dict, stats_aliases, name) return result else: - return self._get_single_stat(stats_dict, stats_aliases, stats_name) + if callable(stats_name): + return stats_name(self.data[band] if self.count > 1 else self.data) + else: + return self._get_single_stat(stats_dict, stats_aliases, stats_name) @staticmethod - def _get_single_stat(stats_dict: dict[str, float], stats_aliases: dict[str, str], stat_name: str) -> float: + def _get_single_stat( + stats_dict: dict[str, np.floating[Any]], stats_aliases: dict[str, str], stat_name: str + ) -> np.floating[Any]: """ Retrieve a single statistic based on a flexible name or alias. @@ -1964,7 +1977,7 @@ def _get_single_stat(stats_dict: dict[str, float], stats_aliases: dict[str, str] return stats_dict[actual_name] else: logging.warning("Statistic name '%s' is not recognized", stat_name) - return np.nan + return np.floating(np.nan) def _nmad(self, nfact: float = 1.4826, band: int = 0) -> np.floating[Any]: """ @@ -1977,6 +1990,8 @@ def _nmad(self, nfact: float = 1.4826, band: int = 0) -> np.floating[Any]: data = self.data else: data = self.data[band] + if isinstance(data, np.ma.MaskedArray): + data = data.compressed() return nfact * np.nanmedian(np.abs(data - np.nanmedian(data))) @overload diff --git a/tests/test_raster/test_raster.py b/tests/test_raster/test_raster.py index 5c49e55f..1d4f96e9 100644 --- a/tests/test_raster/test_raster.py +++ b/tests/test_raster/test_raster.py @@ -1971,10 +1971,15 @@ def test_stats(self, example: str) -> None: stat = raster.get_stats(stats_name="Average") assert isinstance(stat, np.floating) - # Selected stats and callable def percentile_95(data: NDArrayNum) -> np.floating[Any]: + if isinstance(data, np.ma.MaskedArray): + data = data.compressed() return np.nanpercentile(data, 95) + stat = raster.get_stats(stats_name=percentile_95) + assert isinstance(stat, np.floating) + + # Selected stats and callable stats_name = ["mean", "maximum", "std", "percentile_95"] stats = raster.get_stats(stats_name=["mean", "maximum", "std", percentile_95]) for name in stats_name: