Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding statistics to Raster #638

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions doc/source/raster_class.md
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,31 @@ rast_reproj.to_pointcloud()
# Export to xarray data array
rast_reproj.to_xarray()
```

## Obtain Statistics
The `get_stats()` method allows to extract key statistical information from a raster in a dictionary.
Supported statistics are : mean, median, max, mean, sum, sum of squares, 90th percentile, nmad, rmse, std.
Callable functions are supported as well.

### Usage Examples:
- Get all statistics in a dict:
```{code-cell} ipython3
rast.get_stats()
```

- Get a single statistic (e.g., 'mean') as a float:
```{code-cell} ipython3
rast.get_stats("mean")
```

- Get multiple statistics in a dict:
```{code-cell} ipython3
rast.get_stats(["mean", "max", "rmse"])
```

- Using a custom callable statistic:
```{code-cell} ipython3
def custom_stat(data):
return np.nansum(data > 100) # Count the number of pixels above 100
rast.get_stats(custom_stat)
```
177 changes: 167 additions & 10 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import logging
import math
import pathlib
import warnings
Expand Down Expand Up @@ -69,6 +70,7 @@
decode_sensor_metadata,
parse_and_convert_metadata_from_filename,
)
from geoutils.stats import nmad
from geoutils.vector.vector import Vector

# If python38 or above, Literal is builtin. Otherwise, use typing_extensions
Expand Down Expand Up @@ -1870,6 +1872,157 @@ def set_mask(self, mask: NDArrayBool | Mask) -> None:
else:
self.data[mask_arr > 0] = np.ma.masked

def _statistics(self, band: int = 1) -> dict[str, np.floating[Any]]:
"""
Calculate common statistics for a specified band in the raster.

:param band: The index of the band for which to compute statistics. Default is 1.

:returns: A dictionary containing the calculated statistics for the selected band, including mean, median, max,
min, sum, sum of squares, 90th percentile, NMAD, RMSE, and standard deviation.
"""
if self.count == 1:
data = self.data
else:
data = self.data[band - 1]

# If data is a MaskedArray, use the compressed version (without masked values)
if isinstance(data, np.ma.MaskedArray):
data = data.compressed()

# Compute the statistics
stats_dict = {
"Mean": np.nanmean(data),
"Median": np.nanmedian(data),
"Max": np.nanmax(data),
"Min": np.nanmin(data),
"Sum": np.nansum(data),
"Sum of squares": np.nansum(np.square(data)),
"90th percentile": np.nanpercentile(data, 90),
"NMAD": nmad(data),
"RMSE": np.sqrt(np.nanmean(np.square(data - np.nanmean(data)))),
"Standard deviation": np.nanstd(data),
}
return stats_dict

@overload
def get_stats(
self,
stats_name: (
Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"]
| Callable[[NDArrayNum], np.floating[Any]]
),
band: int = 1,
) -> np.floating[Any]: ...

@overload
def get_stats(
self,
stats_name: (
list[
Literal[
"mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"
]
| Callable[[NDArrayNum], np.floating[Any]]
]
| None
) = None,
band: int = 1,
) -> dict[str, np.floating[Any]]: ...

def get_stats(
vschaffn marked this conversation as resolved.
Show resolved Hide resolved
self,
stats_name: (
Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"]
| Callable[[NDArrayNum], np.floating[Any]]
| list[
Literal[
"mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"
]
| Callable[[NDArrayNum], np.floating[Any]]
]
| None
) = None,
band: int = 1,
) -> np.floating[Any] | dict[str, np.floating[Any]]:
"""
Retrieve specified statistics or all available statistics for the raster data. Allows passing custom callables
to calculate custom stats.

:param stats_name: Name or list of names of the statistics to retrieve. If None, all statistics are returned.
Accepted names include:
- "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"
You can also use common aliases for these names (e.g., "average", "maximum", "minimum", etc.).
Custom callables can also be provided.
:param band: The index of the band for which to compute statistics. Default is 1.

:returns: The requested statistic or a dictionary of statistics if multiple or all are requested.
"""
if not self.is_loaded:
self.load()
stats_dict = self._statistics(band=band)
if stats_name is None:
return stats_dict

# Define the metric aliases and their actual names
stats_aliases = {
"mean": "Mean",
"average": "Mean",
"median": "Median",
"max": "Max",
"maximum": "Max",
"min": "Min",
"minimum": "Min",
"sum": "Sum",
"sumofsquares": "Sum of squares",
"sum2": "Sum of squares",
"percentile": "90th percentile",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kind to consider the user, but it does make the documentation a bit more complex. We'll see if we leave that much flexibility.

"90thpercentile": "90th percentile",
"90percentile": "90th percentile",
"percentile90": "90th percentile",
"nmad": "NMAD",
"rmse": "RMSE",
"std": "Standard deviation",
"stddev": "Standard deviation",
"standarddev": "Standard deviation",
"standarddeviation": "Standard deviation",
}
if isinstance(stats_name, list):
result = {}
for name in stats_name:
if callable(name):
result[name.__name__] = name(self.data[band] if self.count > 1 else self.data)
else:
result[name] = self._get_single_stat(stats_dict, stats_aliases, name)
return result
else:
if callable(stats_name):
return stats_name(self.data[band] if self.count > 1 else self.data)
else:
return self._get_single_stat(stats_dict, stats_aliases, stats_name)

@staticmethod
def _get_single_stat(
stats_dict: dict[str, np.floating[Any]], stats_aliases: dict[str, str], stat_name: str
) -> np.floating[Any]:
"""
Retrieve a single statistic based on a flexible name or alias.

:param stats_dict: The dictionary of available statistics.
:param stats_aliases: The dictionary of alias mappings to the actual stat names.
:param stat_name: The name or alias of the statistic to retrieve.

:returns: The requested statistic value, or None if the stat name is not recognized.
"""

normalized_name = stat_name.lower().replace(" ", "").replace("_", "").replace("-", "")
if normalized_name in stats_aliases:
actual_name = stats_aliases[normalized_name]
return stats_dict[actual_name]
else:
logging.warning("Statistic name '%s' is not recognized", stat_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or raising an error here? Or is the reasoning for warning instead to accept potential typos in the statistic list?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not raised an error here in order to obtain results if stats_name is a list and only one stat is not recognized. But it could be passed as a raise if needed

return np.float32(np.nan)

@overload
def info(self, stats: bool = False, *, verbose: Literal[True] = ...) -> None: ...

Expand Down Expand Up @@ -1904,24 +2057,28 @@ def info(self, stats: bool = False, verbose: bool = True) -> None | str:
]

if stats:
as_str.append("\nStatistics:\n")
if not self.is_loaded:
self.load()

if self.count == 1:
as_str.append(f"[MAXIMUM]: {np.nanmax(self.data):.2f}\n")
as_str.append(f"[MINIMUM]: {np.nanmin(self.data):.2f}\n")
as_str.append(f"[MEDIAN]: {np.ma.median(self.data):.2f}\n")
as_str.append(f"[MEAN]: {np.nanmean(self.data):.2f}\n")
as_str.append(f"[STD DEV]: {np.nanstd(self.data):.2f}\n")
statistics = self.get_stats()

# Determine the maximum length of the stat names for alignment
max_len = max(len(name) for name in statistics.keys())

# Format the stats with aligned names
for name, value in statistics.items():
as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n")
else:
for b in range(self.count):
# try to keep with rasterio convention.
as_str.append(f"Band {b + 1}:\n")
as_str.append(f"[MAXIMUM]: {np.nanmax(self.data[b, :, :]):.2f}\n")
as_str.append(f"[MINIMUM]: {np.nanmin(self.data[b, :, :]):.2f}\n")
as_str.append(f"[MEDIAN]: {np.ma.median(self.data[b, :, :]):.2f}\n")
as_str.append(f"[MEAN]: {np.nanmean(self.data[b, :, :]):.2f}\n")
as_str.append(f"[STD DEV]: {np.nanstd(self.data[b, :, :]):.2f}\n")
statistics = self.get_stats(band=b)
if isinstance(statistics, dict):
max_len = max(len(name) for name in statistics.keys())
for name, value in statistics.items():
as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n")

if verbose:
print("".join(as_str))
Expand Down
26 changes: 26 additions & 0 deletions geoutils/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
""" Statistical tools"""

from typing import Any

import numpy as np

from geoutils._typing import NDArrayNum


def nmad(data: NDArrayNum, nfact: float = 1.4826) -> np.floating[Any]:
"""
Calculate the normalized median absolute deviation (NMAD) of an array.
Default scaling factor is 1.4826 to scale the median absolute deviation (MAD) to the dispersion of a normal
distribution (see https://en.wikipedia.org/wiki/Median_absolute_deviation#Relation_to_standard_deviation, and
e.g. Höhle and Höhle (2009), http://dx.doi.org/10.1016/j.isprsjprs.2009.02.003)

:param data: Input array or raster
:param nfact: Normalization factor for the data

:returns nmad: (normalized) median absolute deviation of data.
"""
if isinstance(data, np.ma.masked_array):
data_arr = data.compressed()
else:
data_arr = np.asarray(data)
return nfact * np.nanmedian(np.abs(data_arr - np.nanmedian(data_arr)))
50 changes: 50 additions & 0 deletions tests/test_raster/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

from __future__ import annotations

import logging
import os
import pathlib
import re
import tempfile
import warnings
from cmath import isnan
from io import StringIO
from tempfile import TemporaryFile
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -1944,6 +1947,53 @@ def test_split_bands(self) -> None:
red_c.data.data.squeeze().astype("float32"), img.data.data[0, :, :].astype("float32"), equal_nan=True
)

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore
def test_stats(self, example: str, caplog) -> None:
raster = gu.Raster(example)

# Full stats
stats = raster.get_stats()
expected_stats = [
"Mean",
"Median",
"Max",
"Min",
"Sum",
"Sum of squares",
"90th percentile",
"NMAD",
"RMSE",
"Standard deviation",
]
for name in expected_stats:
assert name in stats
assert stats.get(name) is not None

# Single stat
stat = raster.get_stats(stats_name="Average")
assert isinstance(stat, np.floating)

def percentile_95(data: NDArrayNum) -> np.floating[Any]:
if isinstance(data, np.ma.MaskedArray):
data = data.compressed()
return np.nanpercentile(data, 95)

stat = raster.get_stats(stats_name=percentile_95)
assert isinstance(stat, np.floating)

# Selected stats and callable
stats_name = ["mean", "maximum", "std", "percentile_95"]
stats = raster.get_stats(stats_name=["mean", "maximum", "std", percentile_95])
for name in stats_name:
assert name in stats
assert stats.get(name) is not None

# non-existing stat
with caplog.at_level(logging.WARNING):
stat = raster.get_stats(stats_name="80 percentile")
assert isnan(stat)
assert "Statistic name '80 percentile' is not recognized" in caplog.text
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adebardo There is already a test with an inconsistent stats name and log return checked here 😃



class TestMask:
# Paths to example data
Expand Down
30 changes: 30 additions & 0 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Test functions for stats
"""

import scipy

from geoutils import Raster, examples
from geoutils.stats import nmad


class TestStats:
landsat_b4_path = examples.get_path("everest_landsat_b4")
landsat_raster = Raster(landsat_b4_path)

def test_nmad(self) -> None:
"""Test NMAD functionality runs on any type of input"""

# Check that the NMAD is computed the same with a masked array or NaN array, and is equal to scipy nmad
nmad_ma = nmad(self.landsat_raster.data)
nmad_array = nmad(self.landsat_raster.get_nanarray())
nmad_scipy = scipy.stats.median_abs_deviation(self.landsat_raster.data, axis=None, scale="normal")

assert nmad_ma == nmad_array
assert nmad_ma.round(2) == nmad_scipy.round(2)

# Check that the scaling factor works
nmad_1 = nmad(self.landsat_raster.data, nfact=1)
nmad_2 = nmad(self.landsat_raster.data, nfact=2)

assert nmad_1 * 2 == nmad_2