From faeb20366c0b7d1af24c4797cf0da1652944dcaf Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Mon, 4 Nov 2024 14:52:19 +0100 Subject: [PATCH] can take the min/max in channels aggregation (#147) --- docs/api/aggregation.md | 9 +++++++ sopa/aggregation/__init__.py | 2 +- sopa/aggregation/aggregation.py | 6 +++-- sopa/aggregation/channels.py | 47 ++++++++++++++++++++++++++++----- tests/test_aggregation.py | 22 +++++++++++---- 5 files changed, 71 insertions(+), 15 deletions(-) diff --git a/docs/api/aggregation.md b/docs/api/aggregation.md index 23ac5b8..48a1531 100644 --- a/docs/api/aggregation.md +++ b/docs/api/aggregation.md @@ -1,3 +1,12 @@ # Aggregation +!!! tips "Recommendation" + We recommend using the `sopa.aggregate` function below, which is a wrapper for all types of aggregation. Internally, it uses `aggregate_channels`, `count_transcripts`, and/or `aggregate_bins`, which are also documented below if needed. + ::: sopa.aggregate + +::: sopa.aggregation.aggregate_channels + +::: sopa.aggregation.count_transcripts + +::: sopa.aggregation.aggregate_bins diff --git a/sopa/aggregation/__init__.py b/sopa/aggregation/__init__.py index 0c22736..ecaa069 100644 --- a/sopa/aggregation/__init__.py +++ b/sopa/aggregation/__init__.py @@ -1,5 +1,5 @@ from .bins import aggregate_bins -from .channels import average_channels +from .channels import average_channels, aggregate_channels from .transcripts import count_transcripts from .aggregation import aggregate, Aggregator from .overlay import overlay_segmentation diff --git a/sopa/aggregation/aggregation.py b/sopa/aggregation/aggregation.py index acd9701..70485e0 100644 --- a/sopa/aggregation/aggregation.py +++ b/sopa/aggregation/aggregation.py @@ -17,7 +17,9 @@ get_spatial_element, get_spatial_image, ) -from . import aggregate_bins, average_channels, count_transcripts +from . import aggregate_bins +from . import aggregate_channels as _aggregate_channels +from . import count_transcripts log = logging.getLogger(__name__) @@ -157,7 +159,7 @@ def compute_table( self.filter_cells(self.table.X.sum(axis=1) < min_transcripts) if aggregate_channels: - mean_intensities = average_channels( + mean_intensities = _aggregate_channels( self.sdata, image_key=self.image_key, shapes_key=self.shapes_key, diff --git a/sopa/aggregation/channels.py b/sopa/aggregation/channels.py index eda6c25..f139046 100644 --- a/sopa/aggregation/channels.py +++ b/sopa/aggregation/channels.py @@ -5,6 +5,7 @@ import dask import geopandas as gpd import numpy as np +import numpy.ma as ma import shapely from dask.diagnostics import ProgressBar from shapely.geometry import Polygon, box @@ -16,35 +17,47 @@ log = logging.getLogger(__name__) +AVAILABLE_MODES = ["average", "min", "max"] + def average_channels( + sdata: SpatialData, image_key: str = None, shapes_key: str = None, expand_radius_ratio: float = 0 +) -> np.ndarray: + log.warning("average_channels is deprecated, use `aggregate_channels` instead") + return aggregate_channels(sdata, image_key, shapes_key, expand_radius_ratio, mode="average") + + +def aggregate_channels( sdata: SpatialData, image_key: str = None, shapes_key: str = None, expand_radius_ratio: float = 0, + mode: str = "average", ) -> np.ndarray: - """Average channel intensities per cell. + """Aggregate the channel intensities per cell (either `"average"`, or take the `"min"` / `"max"`). Args: sdata: A `SpatialData` object image_key: Key of `sdata` containing the image. If only one `images` element, this does not have to be provided. shapes_key: Key of `sdata` containing the cell boundaries. If only one `shapes` element, this does not have to be provided. expand_radius_ratio: Cells polygons will be expanded by `expand_radius_ratio * mean_radius`. This help better aggregate boundary stainings. + mode: Aggregation mode. One of `"average"`, `"min"`, `"max"`. By default, average the intensity inside the cell mask. Returns: A numpy `ndarray` of shape `(n_cells, n_channels)` """ + assert mode in AVAILABLE_MODES, f"Invalid {mode=}. Available modes are {AVAILABLE_MODES}" + image = get_spatial_image(sdata, image_key) geo_df = get_boundaries(sdata, key=shapes_key) geo_df = to_intrinsic(sdata, geo_df, image) geo_df = expand_radius(geo_df, expand_radius_ratio) - log.info(f"Averaging channels intensity over {len(geo_df)} cells with expansion {expand_radius_ratio=}") - return _average_channels_aligned(image, geo_df) + return _aggregate_channels_aligned(image, geo_df, mode) -def _average_channels_aligned(image: DataArray, geo_df: gpd.GeoDataFrame | list[Polygon]) -> np.ndarray: +def _aggregate_channels_aligned(image: DataArray, geo_df: gpd.GeoDataFrame | list[Polygon], mode: str) -> np.ndarray: """Average channel intensities per cell. The image and cells have to be aligned, i.e. be on the same coordinate system. Args: @@ -54,11 +67,17 @@ def _average_channels_aligned(image: DataArray, geo_df: gpd.GeoDataFrame | list[ Returns: A numpy `ndarray` of shape `(n_cells, n_channels)` """ + log.info(f"Aggregating channels intensity over {len(geo_df)} cells with {mode=}") + cells = geo_df if isinstance(geo_df, list) else list(geo_df.geometry) tree = shapely.STRtree(cells) - intensities = np.zeros((len(cells), len(image.coords["c"]))) + n_channels = len(image.coords["c"]) areas = np.zeros(len(cells)) + if mode == "min": + aggregation = np.full((len(cells), n_channels), fill_value=np.inf) + else: + aggregation = np.zeros((len(cells), n_channels)) chunk_sizes = image.data.chunks offsets_y = np.cumsum(np.pad(chunk_sizes[1], (1, 0), "constant")) @@ -86,9 +105,20 @@ def _average_chunk_inside_cells(chunk, iy, ix): mask = rasterize(cell, sub_image.shape[1:], bounds) - intensities[index] += np.sum(sub_image * mask, axis=(1, 2)) areas[index] += np.sum(mask) + if mode == "min": + masked_image = ma.masked_array(sub_image, 1 - np.repeat(mask[None], n_channels, axis=0)) + aggregation[index] = np.minimum(aggregation[index], masked_image.min(axis=(1, 2))) + elif mode in ["average", "max"]: + func = np.sum if mode == "average" else np.max + values = func(sub_image * mask, axis=(1, 2)) + + if mode == "average": + aggregation[index] += values + else: + aggregation[index] = np.maximum(aggregation[index], values) + with ProgressBar(): tasks = [ dask.delayed(_average_chunk_inside_cells)(chunk, iy, ix) @@ -97,4 +127,7 @@ def _average_chunk_inside_cells(chunk, iy, ix): ] dask.compute(tasks) - return intensities / areas[:, None].clip(1) + if mode == "average": + return aggregation / areas[:, None].clip(1) + else: + return aggregation diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index ac1ef00..0d4a6c5 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -6,14 +6,14 @@ import xarray as xr from shapely.geometry import Polygon, box -from sopa.aggregation.channels import _average_channels_aligned +from sopa.aggregation.channels import _aggregate_channels_aligned from sopa.aggregation.transcripts import _count_transcripts_aligned dask.config.set({"dataframe.query-planning": False}) import dask.dataframe as dd # noqa -def test_average_channels_aligned(): +def test_aggregate_channels_aligned(): image = np.random.randint(1, 10, size=(3, 8, 16)) arr = da.from_array(image, chunks=(1, 8, 8)) xarr = xr.DataArray(arr, dims=["c", "y", "x"]) @@ -24,11 +24,23 @@ def test_average_channels_aligned(): # One cell is on the first block, one is overlapping on both blocks, and one is on the last block cells = [box(x, y, x + cell_size - 1, y + cell_size - 1) for x, y in cell_start] - means = _average_channels_aligned(xarr, cells) + mean_intensities = _aggregate_channels_aligned(xarr, cells, "average") + min_intensities = _aggregate_channels_aligned(xarr, cells, "min") + max_intensities = _aggregate_channels_aligned(xarr, cells, "max") - true_means = np.stack([image[:, y : y + cell_size, x : x + cell_size].mean(axis=(1, 2)) for x, y in cell_start]) + true_mean_intensities = np.stack( + [image[:, y : y + cell_size, x : x + cell_size].mean(axis=(1, 2)) for x, y in cell_start] + ) + true_min_intensities = np.stack( + [image[:, y : y + cell_size, x : x + cell_size].min(axis=(1, 2)) for x, y in cell_start] + ) + true_max_intensities = np.stack( + [image[:, y : y + cell_size, x : x + cell_size].max(axis=(1, 2)) for x, y in cell_start] + ) - assert (means == true_means).all() + assert (mean_intensities == true_mean_intensities).all() + assert (min_intensities == true_min_intensities).all() + assert (max_intensities == true_max_intensities).all() def test_count_transcripts():