Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding statistics to Raster #638

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 145 additions & 10 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import logging
import math
import pathlib
import warnings
Expand Down Expand Up @@ -1870,6 +1871,135 @@ def set_mask(self, mask: NDArrayBool | Mask) -> None:
else:
self.data[mask_arr > 0] = np.ma.masked

def _statistics(self, band: int = 1) -> dict[str, np.floating[Any]]:
"""
Calculate common statistics for a specified band in the raster.

:param band: The index of the band for which to compute statistics. Default is 1.

:returns: A dictionary containing the calculated statistics for the selected band, including mean, median, max,
min, sum, sum of squares, 90th percentile, NMAD, RMSE, and standard deviation.
"""
if self.count == 1:
data = self.data
else:
data = self.data[band - 1]

# If data is a MaskedArray, use the compressed version (without masked values)
if isinstance(data, np.ma.MaskedArray):
data = data.compressed()

# Compute the statistics
stats_dict = {
"Mean": np.nanmean(data),
"Median": np.nanmedian(data),
"Max": np.nanmax(data),
"Min": np.nanmin(data),
"Sum": np.nansum(data),
"Sum of squares": np.nansum(np.square(data)),
"90th percentile": np.nanpercentile(data, 90),
"NMAD": self._nmad(),
vschaffn marked this conversation as resolved.
Show resolved Hide resolved
"RMSE": np.sqrt(np.nanmean(np.square(data - np.nanmean(data)))),
"Standard deviation": np.nanstd(data),
}
return stats_dict

def get_stats(
vschaffn marked this conversation as resolved.
Show resolved Hide resolved
self,
stats_name: (
str | Callable[[NDArrayNum], np.floating[Any]] | list[str | Callable[[NDArrayNum], np.floating[Any]]] | None
vschaffn marked this conversation as resolved.
Show resolved Hide resolved
) = None,
band: int = 1,
) -> np.floating[Any] | dict[str, np.floating[Any]]:
"""
Retrieve specified statistics or all available statistics for the raster data. Allows passing custom callables
to calculate custom stats.

:param stats_name: Name or list of names of the statistics to retrieve. If None, all statistics are returned.
:param band: The index of the band for which to compute statistics. Default is 1.

:returns: The requested statistic or a dictionary of statistics if multiple or all are requested.
"""
if not self.is_loaded:
self.load()
stats_dict = self._statistics(band=band)
if stats_name is None:
return stats_dict

# Define the metric aliases and their actual names
stats_aliases = {
"mean": "Mean",
"average": "Mean",
"median": "Median",
"max": "Max",
"maximum": "Max",
"min": "Min",
"minimum": "Min",
"sum": "Sum",
"sumofsquares": "Sum of squares",
"sum2": "Sum of squares",
"percentile": "90th percentile",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kind to consider the user, but it does make the documentation a bit more complex. We'll see if we leave that much flexibility.

"90thpercentile": "90th percentile",
"90percentile": "90th percentile",
"percentile90": "90th percentile",
"nmad": "NMAD",
"rmse": "RMSE",
"std": "Standard deviation",
"stddev": "Standard deviation",
"standarddev": "Standard deviation",
"standarddeviation": "Standard deviation",
}
if isinstance(stats_name, list):
result = {}
for name in stats_name:
if callable(name):
result[name.__name__] = name(self.data[band] if self.count > 1 else self.data)
else:
result[name] = self._get_single_stat(stats_dict, stats_aliases, name)
return result
else:
if callable(stats_name):
return stats_name(self.data[band] if self.count > 1 else self.data)
else:
return self._get_single_stat(stats_dict, stats_aliases, stats_name)

@staticmethod
def _get_single_stat(
stats_dict: dict[str, np.floating[Any]], stats_aliases: dict[str, str], stat_name: str
) -> np.floating[Any]:
"""
Retrieve a single statistic based on a flexible name or alias.

:param stats_dict: The dictionary of available statistics.
:param stats_aliases: The dictionary of alias mappings to the actual stat names.
:param stat_name: The name or alias of the statistic to retrieve.

:returns: The requested statistic value, or None if the stat name is not recognized.
"""

normalized_name = stat_name.lower().replace(" ", "").replace("_", "").replace("-", "")
if normalized_name in stats_aliases:
actual_name = stats_aliases[normalized_name]
return stats_dict[actual_name]
else:
logging.warning("Statistic name '%s' is not recognized", stat_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or raising an error here? Or is the reasoning for warning instead to accept potential typos in the statistic list?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not raised an error here in order to obtain results if stats_name is a list and only one stat is not recognized. But it could be passed as a raise if needed

return np.float32(np.nan)

def _nmad(self, nfact: float = 1.4826, band: int = 0) -> np.floating[Any]:
vschaffn marked this conversation as resolved.
Show resolved Hide resolved
"""
Calculate the normalized median absolute deviation (NMAD) of an array.
Default scaling factor is 1.4826 to scale the median absolute deviation (MAD) to the dispersion of a normal
distribution (see https://en.wikipedia.org/wiki/Median_absolute_deviation#Relation_to_standard_deviation, and
e.g. Höhle and Höhle (2009), http://dx.doi.org/10.1016/j.isprsjprs.2009.02.003)
"""
if self.count == 1:
data = self.data
else:
data = self.data[band]
if isinstance(data, np.ma.MaskedArray):
data = data.compressed()
return nfact * np.nanmedian(np.abs(data - np.nanmedian(data)))

@overload
def info(self, stats: bool = False, *, verbose: Literal[True] = ...) -> None: ...

Expand Down Expand Up @@ -1904,24 +2034,29 @@ def info(self, stats: bool = False, verbose: bool = True) -> None | str:
]

if stats:
as_str.append("\nStatistics:\n")
if not self.is_loaded:
self.load()

if self.count == 1:
as_str.append(f"[MAXIMUM]: {np.nanmax(self.data):.2f}\n")
as_str.append(f"[MINIMUM]: {np.nanmin(self.data):.2f}\n")
as_str.append(f"[MEDIAN]: {np.ma.median(self.data):.2f}\n")
as_str.append(f"[MEAN]: {np.nanmean(self.data):.2f}\n")
as_str.append(f"[STD DEV]: {np.nanstd(self.data):.2f}\n")
statistics = self.get_stats()
if isinstance(statistics, dict):

# Determine the maximum length of the stat names for alignment
max_len = max(len(name) for name in statistics.keys())

# Format the stats with aligned names
for name, value in statistics.items():
as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n")
else:
for b in range(self.count):
# try to keep with rasterio convention.
as_str.append(f"Band {b + 1}:\n")
as_str.append(f"[MAXIMUM]: {np.nanmax(self.data[b, :, :]):.2f}\n")
as_str.append(f"[MINIMUM]: {np.nanmin(self.data[b, :, :]):.2f}\n")
as_str.append(f"[MEDIAN]: {np.ma.median(self.data[b, :, :]):.2f}\n")
as_str.append(f"[MEAN]: {np.nanmean(self.data[b, :, :]):.2f}\n")
as_str.append(f"[STD DEV]: {np.nanstd(self.data[b, :, :]):.2f}\n")
statistics = self.get_stats(band=b)
if isinstance(statistics, dict):
max_len = max(len(name) for name in statistics.keys())
for name, value in statistics.items():
as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n")

if verbose:
print("".join(as_str))
Expand Down
50 changes: 50 additions & 0 deletions tests/test_raster/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

from __future__ import annotations

import logging
import os
import pathlib
import re
import tempfile
import warnings
from cmath import isnan
from io import StringIO
from tempfile import TemporaryFile
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -1944,6 +1947,53 @@ def test_split_bands(self) -> None:
red_c.data.data.squeeze().astype("float32"), img.data.data[0, :, :].astype("float32"), equal_nan=True
)

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore
def test_stats(self, example: str, caplog) -> None:
raster = gu.Raster(example)

# Full stats
stats = raster.get_stats()
expected_stats = [
"Mean",
"Median",
"Max",
"Min",
"Sum",
"Sum of squares",
"90th percentile",
"NMAD",
"RMSE",
"Standard deviation",
]
for name in expected_stats:
assert name in stats
assert stats.get(name) is not None

# Single stat
stat = raster.get_stats(stats_name="Average")
assert isinstance(stat, np.floating)

def percentile_95(data: NDArrayNum) -> np.floating[Any]:
if isinstance(data, np.ma.MaskedArray):
data = data.compressed()
return np.nanpercentile(data, 95)

stat = raster.get_stats(stats_name=percentile_95)
assert isinstance(stat, np.floating)

# Selected stats and callable
stats_name = ["mean", "maximum", "std", "percentile_95"]
stats = raster.get_stats(stats_name=["mean", "maximum", "std", percentile_95])
for name in stats_name:
assert name in stats
assert stats.get(name) is not None

# non-existing stat
with caplog.at_level(logging.WARNING):
stat = raster.get_stats(stats_name="80 percentile")
assert isnan(stat)
assert "Statistic name '80 percentile' is not recognized" in caplog.text
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adebardo There is already a test with an inconsistent stats name and log return checked here 😃



class TestMask:
# Paths to example data
Expand Down
Loading