diff --git a/doc/source/raster_class.md b/doc/source/raster_class.md index ae8b6475..30001c28 100644 --- a/doc/source/raster_class.md +++ b/doc/source/raster_class.md @@ -346,3 +346,31 @@ rast_reproj.to_pointcloud() # Export to xarray data array rast_reproj.to_xarray() ``` + +## Obtain Statistics +The `get_stats()` method allows to extract key statistical information from a raster in a dictionary. +Supported statistics are : mean, median, max, mean, sum, sum of squares, 90th percentile, nmad, rmse, std. +Callable functions are supported as well. + +### Usage Examples: +- Get all statistics in a dict: +```{code-cell} ipython3 +rast.get_stats() +``` + +- Get a single statistic (e.g., 'mean') as a float: +```{code-cell} ipython3 +rast.get_stats("mean") +``` + +- Get multiple statistics in a dict: +```{code-cell} ipython3 +rast.get_stats(["mean", "max", "rmse"]) +``` + +- Using a custom callable statistic: +```{code-cell} ipython3 +def custom_stat(data): + return np.nansum(data > 100) # Count the number of pixels above 100 +rast.get_stats(custom_stat) +``` diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index aaedfbf0..7edc2e85 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -4,6 +4,7 @@ from __future__ import annotations +import logging import math import pathlib import warnings @@ -69,6 +70,7 @@ decode_sensor_metadata, parse_and_convert_metadata_from_filename, ) +from geoutils.stats import nmad from geoutils.vector.vector import Vector # If python38 or above, Literal is builtin. Otherwise, use typing_extensions @@ -1870,6 +1872,157 @@ def set_mask(self, mask: NDArrayBool | Mask) -> None: else: self.data[mask_arr > 0] = np.ma.masked + def _statistics(self, band: int = 1) -> dict[str, np.floating[Any]]: + """ + Calculate common statistics for a specified band in the raster. + + :param band: The index of the band for which to compute statistics. Default is 1. + + :returns: A dictionary containing the calculated statistics for the selected band, including mean, median, max, + min, sum, sum of squares, 90th percentile, NMAD, RMSE, and standard deviation. + """ + if self.count == 1: + data = self.data + else: + data = self.data[band - 1] + + # If data is a MaskedArray, use the compressed version (without masked values) + if isinstance(data, np.ma.MaskedArray): + data = data.compressed() + + # Compute the statistics + stats_dict = { + "Mean": np.nanmean(data), + "Median": np.nanmedian(data), + "Max": np.nanmax(data), + "Min": np.nanmin(data), + "Sum": np.nansum(data), + "Sum of squares": np.nansum(np.square(data)), + "90th percentile": np.nanpercentile(data, 90), + "NMAD": nmad(data), + "RMSE": np.sqrt(np.nanmean(np.square(data - np.nanmean(data)))), + "Standard deviation": np.nanstd(data), + } + return stats_dict + + @overload + def get_stats( + self, + stats_name: ( + Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"] + | Callable[[NDArrayNum], np.floating[Any]] + ), + band: int = 1, + ) -> np.floating[Any]: ... + + @overload + def get_stats( + self, + stats_name: ( + list[ + Literal[ + "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std" + ] + | Callable[[NDArrayNum], np.floating[Any]] + ] + | None + ) = None, + band: int = 1, + ) -> dict[str, np.floating[Any]]: ... + + def get_stats( + self, + stats_name: ( + Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"] + | Callable[[NDArrayNum], np.floating[Any]] + | list[ + Literal[ + "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std" + ] + | Callable[[NDArrayNum], np.floating[Any]] + ] + | None + ) = None, + band: int = 1, + ) -> np.floating[Any] | dict[str, np.floating[Any]]: + """ + Retrieve specified statistics or all available statistics for the raster data. Allows passing custom callables + to calculate custom stats. + + :param stats_name: Name or list of names of the statistics to retrieve. If None, all statistics are returned. + Accepted names include: + - "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std" + You can also use common aliases for these names (e.g., "average", "maximum", "minimum", etc.). + Custom callables can also be provided. + :param band: The index of the band for which to compute statistics. Default is 1. + + :returns: The requested statistic or a dictionary of statistics if multiple or all are requested. + """ + if not self.is_loaded: + self.load() + stats_dict = self._statistics(band=band) + if stats_name is None: + return stats_dict + + # Define the metric aliases and their actual names + stats_aliases = { + "mean": "Mean", + "average": "Mean", + "median": "Median", + "max": "Max", + "maximum": "Max", + "min": "Min", + "minimum": "Min", + "sum": "Sum", + "sumofsquares": "Sum of squares", + "sum2": "Sum of squares", + "percentile": "90th percentile", + "90thpercentile": "90th percentile", + "90percentile": "90th percentile", + "percentile90": "90th percentile", + "nmad": "NMAD", + "rmse": "RMSE", + "std": "Standard deviation", + "stddev": "Standard deviation", + "standarddev": "Standard deviation", + "standarddeviation": "Standard deviation", + } + if isinstance(stats_name, list): + result = {} + for name in stats_name: + if callable(name): + result[name.__name__] = name(self.data[band] if self.count > 1 else self.data) + else: + result[name] = self._get_single_stat(stats_dict, stats_aliases, name) + return result + else: + if callable(stats_name): + return stats_name(self.data[band] if self.count > 1 else self.data) + else: + return self._get_single_stat(stats_dict, stats_aliases, stats_name) + + @staticmethod + def _get_single_stat( + stats_dict: dict[str, np.floating[Any]], stats_aliases: dict[str, str], stat_name: str + ) -> np.floating[Any]: + """ + Retrieve a single statistic based on a flexible name or alias. + + :param stats_dict: The dictionary of available statistics. + :param stats_aliases: The dictionary of alias mappings to the actual stat names. + :param stat_name: The name or alias of the statistic to retrieve. + + :returns: The requested statistic value, or None if the stat name is not recognized. + """ + + normalized_name = stat_name.lower().replace(" ", "").replace("_", "").replace("-", "") + if normalized_name in stats_aliases: + actual_name = stats_aliases[normalized_name] + return stats_dict[actual_name] + else: + logging.warning("Statistic name '%s' is not recognized", stat_name) + return np.float32(np.nan) + @overload def info(self, stats: bool = False, *, verbose: Literal[True] = ...) -> None: ... @@ -1904,24 +2057,28 @@ def info(self, stats: bool = False, verbose: bool = True) -> None | str: ] if stats: + as_str.append("\nStatistics:\n") if not self.is_loaded: self.load() if self.count == 1: - as_str.append(f"[MAXIMUM]: {np.nanmax(self.data):.2f}\n") - as_str.append(f"[MINIMUM]: {np.nanmin(self.data):.2f}\n") - as_str.append(f"[MEDIAN]: {np.ma.median(self.data):.2f}\n") - as_str.append(f"[MEAN]: {np.nanmean(self.data):.2f}\n") - as_str.append(f"[STD DEV]: {np.nanstd(self.data):.2f}\n") + statistics = self.get_stats() + + # Determine the maximum length of the stat names for alignment + max_len = max(len(name) for name in statistics.keys()) + + # Format the stats with aligned names + for name, value in statistics.items(): + as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n") else: for b in range(self.count): # try to keep with rasterio convention. as_str.append(f"Band {b + 1}:\n") - as_str.append(f"[MAXIMUM]: {np.nanmax(self.data[b, :, :]):.2f}\n") - as_str.append(f"[MINIMUM]: {np.nanmin(self.data[b, :, :]):.2f}\n") - as_str.append(f"[MEDIAN]: {np.ma.median(self.data[b, :, :]):.2f}\n") - as_str.append(f"[MEAN]: {np.nanmean(self.data[b, :, :]):.2f}\n") - as_str.append(f"[STD DEV]: {np.nanstd(self.data[b, :, :]):.2f}\n") + statistics = self.get_stats(band=b) + if isinstance(statistics, dict): + max_len = max(len(name) for name in statistics.keys()) + for name, value in statistics.items(): + as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n") if verbose: print("".join(as_str)) diff --git a/geoutils/stats.py b/geoutils/stats.py new file mode 100644 index 00000000..838c701c --- /dev/null +++ b/geoutils/stats.py @@ -0,0 +1,26 @@ +""" Statistical tools""" + +from typing import Any + +import numpy as np + +from geoutils._typing import NDArrayNum + + +def nmad(data: NDArrayNum, nfact: float = 1.4826) -> np.floating[Any]: + """ + Calculate the normalized median absolute deviation (NMAD) of an array. + Default scaling factor is 1.4826 to scale the median absolute deviation (MAD) to the dispersion of a normal + distribution (see https://en.wikipedia.org/wiki/Median_absolute_deviation#Relation_to_standard_deviation, and + e.g. Höhle and Höhle (2009), http://dx.doi.org/10.1016/j.isprsjprs.2009.02.003) + + :param data: Input array or raster + :param nfact: Normalization factor for the data + + :returns nmad: (normalized) median absolute deviation of data. + """ + if isinstance(data, np.ma.masked_array): + data_arr = data.compressed() + else: + data_arr = np.asarray(data) + return nfact * np.nanmedian(np.abs(data_arr - np.nanmedian(data_arr))) diff --git a/tests/test_raster/test_raster.py b/tests/test_raster/test_raster.py index aee3894c..a128b5e7 100644 --- a/tests/test_raster/test_raster.py +++ b/tests/test_raster/test_raster.py @@ -4,13 +4,16 @@ from __future__ import annotations +import logging import os import pathlib import re import tempfile import warnings +from cmath import isnan from io import StringIO from tempfile import TemporaryFile +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -1944,6 +1947,53 @@ def test_split_bands(self) -> None: red_c.data.data.squeeze().astype("float32"), img.data.data[0, :, :].astype("float32"), equal_nan=True ) + @pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore + def test_stats(self, example: str, caplog) -> None: + raster = gu.Raster(example) + + # Full stats + stats = raster.get_stats() + expected_stats = [ + "Mean", + "Median", + "Max", + "Min", + "Sum", + "Sum of squares", + "90th percentile", + "NMAD", + "RMSE", + "Standard deviation", + ] + for name in expected_stats: + assert name in stats + assert stats.get(name) is not None + + # Single stat + stat = raster.get_stats(stats_name="Average") + assert isinstance(stat, np.floating) + + def percentile_95(data: NDArrayNum) -> np.floating[Any]: + if isinstance(data, np.ma.MaskedArray): + data = data.compressed() + return np.nanpercentile(data, 95) + + stat = raster.get_stats(stats_name=percentile_95) + assert isinstance(stat, np.floating) + + # Selected stats and callable + stats_name = ["mean", "maximum", "std", "percentile_95"] + stats = raster.get_stats(stats_name=["mean", "maximum", "std", percentile_95]) + for name in stats_name: + assert name in stats + assert stats.get(name) is not None + + # non-existing stat + with caplog.at_level(logging.WARNING): + stat = raster.get_stats(stats_name="80 percentile") + assert isnan(stat) + assert "Statistic name '80 percentile' is not recognized" in caplog.text + class TestMask: # Paths to example data diff --git a/tests/test_stats.py b/tests/test_stats.py new file mode 100644 index 00000000..dbfb5118 --- /dev/null +++ b/tests/test_stats.py @@ -0,0 +1,30 @@ +""" +Test functions for stats +""" + +import scipy + +from geoutils import Raster, examples +from geoutils.stats import nmad + + +class TestStats: + landsat_b4_path = examples.get_path("everest_landsat_b4") + landsat_raster = Raster(landsat_b4_path) + + def test_nmad(self) -> None: + """Test NMAD functionality runs on any type of input""" + + # Check that the NMAD is computed the same with a masked array or NaN array, and is equal to scipy nmad + nmad_ma = nmad(self.landsat_raster.data) + nmad_array = nmad(self.landsat_raster.get_nanarray()) + nmad_scipy = scipy.stats.median_abs_deviation(self.landsat_raster.data, axis=None, scale="normal") + + assert nmad_ma == nmad_array + assert nmad_ma.round(2) == nmad_scipy.round(2) + + # Check that the scaling factor works + nmad_1 = nmad(self.landsat_raster.data, nfact=1) + nmad_2 = nmad(self.landsat_raster.data, nfact=2) + + assert nmad_1 * 2 == nmad_2