diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml
old mode 100755
new mode 100644
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
old mode 100755
new mode 100644
diff --git a/.gitignore b/.gitignore
old mode 100755
new mode 100644
diff --git a/.readthedocs.yml b/.readthedocs.yml
old mode 100755
new mode 100644
diff --git a/MANIFEST.in b/MANIFEST.in
old mode 100755
new mode 100644
diff --git a/README.md b/README.md
old mode 100755
new mode 100644
index 9fa86ff..cfd3f54
--- a/README.md
+++ b/README.md
@@ -2,6 +2,7 @@
 [![PyPI version](https://badge.fury.io/py/bento-tools.svg)](https://badge.fury.io/py/bento-tools)
 [![codecov](https://codecov.io/gh/ckmah/bento-tools/branch/master/graph/badge.svg?token=XVHDKNDCDT)](https://codecov.io/gh/ckmah/bento-tools)
 [![Documentation Status](https://readthedocs.org/projects/bento-tools/badge/?version=latest)](https://bento-tools.readthedocs.io/en/latest/?badge=latest)
+[![Downloads](https://static.pepy.tech/badge/bento-tools)](https://pepy.tech/project/bento-tools)
 ![PyPI - Downloads](https://img.shields.io/pypi/dm/bento-tools)
 [![GitHub stars](https://badgen.net/github/stars/ckmah/bento-tools)](https://GitHub.com/Naereen/ckmah/bento-tools)
 
diff --git a/bento/__init__.py b/bento/__init__.py
old mode 100755
new mode 100644
index 9e134f6..7d9cd4f
--- a/bento/__init__.py
+++ b/bento/__init__.py
@@ -1,10 +1,7 @@
-from ._settings import settings
-
-from . import datasets as ds
 from . import io
 from . import plotting as pl
 from . import tools as tl
 from . import _utils as ut
 from . import geometry as geo
+from . import query as qy
 from .plotting import _colors as colors
-from ._utils import sync
diff --git a/bento/_constants.py b/bento/_constants.py
index cdc5186..1be7869 100644
--- a/bento/_constants.py
+++ b/bento/_constants.py
@@ -1,18 +1,3 @@
 PATTERN_COLORS = ["#17becf", "#1f77b4", "#7f7f7f", "#ff7f0e", "#d62728"]
 PATTERN_NAMES = ["cell_edge", "cytoplasmic", "none", "nuclear", "nuclear_edge"]
 PATTERN_PROBS = [f"{p}_p" for p in PATTERN_NAMES]
-PATTERN_FEATURES = [
-    "cell_inner_proximity",
-    "nucleus_inner_proximity",
-    "nucleus_outer_proximity",
-    "cell_inner_asymmetry",
-    "nucleus_inner_asymmetry",
-    "nucleus_outer_asymmetry",
-    "l_max",
-    "l_max_gradient",
-    "l_min_gradient",
-    "l_monotony",
-    "l_half_radius",
-    "point_dispersion_norm",
-    "nucleus_dispersion_norm",
-]
diff --git a/bento/_settings.py b/bento/_settings.py
deleted file mode 100644
index 3aee303..0000000
--- a/bento/_settings.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import logging
-from rich.logging import RichHandler
-
-
-class Settings:
-    """
-    Settings class for Bento.
-
-    Parameters
-    ----------
-    verbosity : int
-        Verbosity level for logging. Default is 0. See the logging module for more information (https://docs.python.org/3/howto/logging.html#logging-levels).
-    log : logging.Logger
-        Logger object for Bento.
-    """
-
-    def __init__(self, verbosity):
-        self._verbosity = verbosity
-        self._log = Logger(verbosity)
-
-    @property
-    def verbosity(self):
-        return self._verbosity
-
-    @verbosity.setter
-    def verbosity(self, value):
-        self._verbosity = value
-        self._log.setLevel(value)
-
-    @property
-    def log(self):
-        return self._log
-
-
-class Logger:
-    def __init__(self, verbosity):
-
-        FORMAT = "%(message)s"
-        logging.basicConfig(
-            level=verbosity, format=FORMAT, datefmt="[%X]", handlers=[RichHandler(markup=True)]
-        )
-        self._logger = logging.getLogger("rich")
-
-    def debug(self, text):
-        self._logger.debug(text)
-
-    def info(self, text):
-        self._logger.info(text)
-
-    def warn(self, text):
-        self._logger.warning(text)
-
-
-    def start(self, text):
-        """
-        Alias for self.info(). Start logging a method.
-
-        Parameters
-        ----------
-        text : str
-            Text to log.
-        """
-        self.info(f"[bold]{text}[/]")
-
-
-    def step(self, text):
-        """
-        Alias for self.info(). Step logging.
-
-        Parameters
-        ----------
-        text : str
-            Text to log.
-        """
-        self._logger.info(text)
-
-
-    def end(self, text):
-        """
-        End logging.
-
-        Parameters
-        ----------
-        text : str
-            Text to log.
-        """
-        self._logger.info(f"[bold]{text}[/]")
-
-    def setLevel(self, value):
-        """
-        Set the verbosity level of the logger.
-
-        Parameters
-        ----------
-        value : int
-            Verbosity level for logging.
-        """
-        self._logger.setLevel(value)
-
-
-settings = Settings(verbosity="WARNING")
diff --git a/bento/_utils.py b/bento/_utils.py
index 4f14e97..e69de29 100644
--- a/bento/_utils.py
+++ b/bento/_utils.py
@@ -1,329 +0,0 @@
-import inspect
-import warnings
-import geopandas as gpd
-import pandas as pd
-import seaborn as sns
-from anndata import AnnData
-from functools import wraps
-from typing import Iterable
-from shapely import wkt
-
-from ._settings import settings
-
-
-def get_default_args(func):
-    signature = inspect.signature(func)
-    return {
-        k: v.default
-        for k, v in signature.parameters.items()
-        if v.default is not inspect.Parameter.empty
-    }
-
-
-def track(func):
-    """
-    Track changes in AnnData object after applying function.
-
-    1. First remembers a shallow list of AnnData attributes by listing keys from obs, var, etc.
-    2. Perform arbitrary task
-    3. List attributes again, perform simple diff between list of old and new attributes
-    4. 
Print to user added and removed keys - - Parameters - ---------- - func : function - """ - - @wraps(func) - def wrapper(*args, **kwds): - kwargs = get_default_args(func) - kwargs.update(kwds) - - if type(args[0]) == AnnData: - adata = args[0] - else: - adata = args[1] - - old_attr = list_attributes(adata) - - if kwargs["copy"]: - out_adata = func(*args, **kwds) - new_attr = list_attributes(out_adata) - else: - func(*args, **kwds) - new_attr = list_attributes(adata) - - # Print differences between new and old adata - out = "" - out += "AnnData object modified:" - - if old_attr["n_obs"] != new_attr["n_obs"]: - out += f"\nn_obs: {old_attr['n_obs']} -> {new_attr['n_obs']}" - - if old_attr["n_vars"] != new_attr["n_vars"]: - out += f"\nn_vars: {old_attr['n_vars']} -> {new_attr['n_vars']}" - - modified = False - for attr in old_attr.keys(): - if attr == "n_obs" or attr == "n_vars": - continue - - removed = list(old_attr[attr] - new_attr[attr]) - added = list(new_attr[attr] - old_attr[attr]) - - if len(removed) > 0 or len(added) > 0: - modified = True - out += f"\n {attr}:" - if len(removed) > 0: - out += f"\n - {', '.join(removed)}" - if len(added) > 0: - out += f"\n + {', '.join(added)}" - - if modified: - settings.log.info(out) - - return out_adata if kwargs["copy"] else None - - return wrapper - - -def list_attributes(adata): - """Traverse AnnData object attributes and list keys. - - Parameters - ---------- - adata : AnnData - AnnData object - - Returns - ------- - dict - Dictionary of keys for each AnnData attribute. - """ - found_attr = dict(n_obs=adata.n_obs, n_vars=adata.n_vars) - for attr in [ - "obs", - "var", - "uns", - "obsm", - "varm", - "layers", - "obsp", - "varp", - ]: - keys = set(getattr(adata, attr).keys()) - found_attr[attr] = keys - - return found_attr - - -def pheno_to_color(pheno, palette): - """ - Maps list of categorical labels to a color palette. - Input values are first sorted alphanumerically least to greatest before mapping to colors. - This ensures consistent colors regardless of input value order. - - Parameters - ---------- - pheno : pd.Series - Categorical labels to map - palette: None, string, or sequence, optional - Name of palette or None to return current palette. - If a sequence, input colors are used but possibly cycled and desaturated. - Taken from sns.color_palette() documentation. - - Returns - ------- - dict - Mapping of label to color in RGBA - tuples - List of converted colors for each sample, formatted as RGBA tuples. 
- - """ - if isinstance(palette, str): - palette = sns.color_palette(palette) - - values = list(set(pheno)) - values.sort() - palette = sns.color_palette(palette, n_colors=len(values)) - study2color = dict(zip(values, palette)) - sample_colors = [study2color[v] for v in pheno] - return study2color, sample_colors - - -def sync(data, copy=False): - """ - Sync existing point sets and associated metadata with data.obs_names and data.var_names - - Parameters - ---------- - data : AnnData - Spatial formatted AnnData object - copy : bool, optional - """ - adata = data.copy() if copy else data - - if "point_sets" not in adata.uns.keys(): - adata.uns["point_sets"] = dict(points=[]) - - # Iterate over point sets - for point_key in adata.uns["point_sets"]: - points = adata.uns[point_key] - - # Subset for cells - cells = adata.obs_names.tolist() - in_cells = points["cell"].isin(cells) - - # Subset for genes - in_genes = [True] * points.shape[0] - if "gene" in points.columns: - genes = adata.var_names.tolist() - in_genes = points["gene"].isin(genes) - - # Combine boolean masks - valid_mask = (in_cells & in_genes).values - - # Sync points using mask - points = points.loc[valid_mask] - - # Remove unused categories for categorical columns - for col in points.columns: - if points[col].dtype == "category": - points[col].cat.remove_unused_categories(inplace=True) - - adata.uns[point_key] = points - - # Sync point metadata using mask - for metadata_key in adata.uns["point_sets"][point_key]: - if metadata_key not in adata.uns: - warnings.warn( - f"Skipping: metadata {metadata_key} not found in adata.uns" - ) - continue - - metadata = adata.uns[metadata_key] - # Slice DataFrame if not empty - if isinstance(metadata, pd.DataFrame) and not metadata.empty: - adata.uns[metadata_key] = metadata.loc[valid_mask, :] - - # Slice Iterable if not empty - elif isinstance(metadata, list) and any(metadata): - adata.uns[metadata_key] = [ - m for i, m in enumerate(metadata) if valid_mask[i] - ] - elif isinstance(metadata, Iterable) and metadata.shape[0] > 0: - adata.uns[metadata_key] = adata.uns[metadata_key][valid_mask] - else: - warnings.warn(f"Metadata {metadata_key} is not a DataFrame or Iterable") - - return adata if copy else None - - -def _register_points(data, point_key, metadata_keys): - required_cols = ["x", "y", "cell"] - - if point_key not in data.uns.keys(): - raise ValueError(f"Key {point_key} not found in data.uns") - - points = data.uns[point_key] - - if not all([col in points.columns for col in required_cols]): - raise ValueError( - f"Point DataFrame must have columns {', '.join(required_cols)}" - ) - - # Check for valid cells - cells = data.obs_names.tolist() - if not points["cell"].isin(cells).all(): - raise ValueError("Invalid cells in point DataFrame") - - # Initialize/add to point registry - if "point_sets" not in data.uns.keys(): - data.uns["point_sets"] = dict() - - if point_key not in data.uns["point_sets"].keys(): - data.uns["point_sets"][point_key] = [] - - if len(metadata_keys) < 0: - return - - # Register metadata - for key in metadata_keys: - # Check for valid metadata - if key not in data.uns.keys(): - raise ValueError(f"Key {key} not found in data.uns") - - n_points = data.uns[point_key].shape[0] - metadata_len = data.uns[key].shape[0] - if metadata_len != n_points: - raise ValueError( - f"Metadata {key} must have same length as points {point_key}" - ) - - # Add metadata key to registry - if key not in data.uns["point_sets"][point_key]: - data.uns["point_sets"][point_key].append(key) - - -def 
register_points(point_key: str, metadata_keys: list): - """Decorator function to register points to the current `AnnData` object. - This keeps track of point sets and keeps them in sync with `AnnData` object. - - Parameters - ---------- - point_key : str - Key where points are stored in `data.uns` - metadata_keys : list - Keys where point metadata are stored in `data.uns` - """ - - def decorator(func): - @wraps(func) - def wrapper(*args, **kwds): - kwargs = get_default_args(func) - kwargs.update(kwds) - - if kwargs["copy"]: - data = func(*args, **kwds) - else: - func(*args, **kwds) - data = args[0] - - # Check for required columns - _register_points(data, point_key, metadata_keys) - return data - - return wrapper - - return decorator - - -def sc_format(data, copy=False): - """ - Convert data.obs GeoPandas columns to string for compatibility with scanpy. - """ - adata = data.copy() if copy else data - - shape_names = data.obs.columns.str.endswith("_shape") - - for col in data.obs.columns[shape_names]: - adata.obs[col] = adata.obs[col].astype(str) - - return adata if copy else None - - -def geo_format(data, copy=False): - """ - Convert data.obs scanpy columns to GeoPandas compatible types. - """ - adata = data.copy() if copy else data - - shape_names = adata.obs.columns[adata.obs.columns.str.endswith("_shape")] - - adata.obs[shape_names] = adata.obs[shape_names].apply( - lambda col: gpd.GeoSeries( - col.astype(str).apply(lambda val: wkt.loads(val) if val != "None" else None) - ) - ) - - return adata if copy else None diff --git a/bento/datasets/__init__.py b/bento/datasets/__init__.py deleted file mode 100755 index 55082a3..0000000 --- a/bento/datasets/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from ._datasets import ( - get_dataset_info, - load_dataset, - sample_data, -) diff --git a/bento/datasets/_datasets.py b/bento/datasets/_datasets.py deleted file mode 100755 index b098a60..0000000 --- a/bento/datasets/_datasets.py +++ /dev/null @@ -1,115 +0,0 @@ -import os - -import pandas as pd - -pkg_resources = None - -from ..io import read_h5ad - - -def get_dataset_info(): - """Return DataFrame with info about builtin datasets. - - Returns - ------- - DataFrame - Info about builtin datasets indexed by dataset name. - """ - global pkg_resources - if pkg_resources is None: - import pkg_resources - - stream = pkg_resources.resource_stream(__name__, "datasets.csv") - return pd.read_csv(stream, index_col=0) - - -def load_dataset(name, cache=True, data_home="~/bento-data", **kws): - """Load a builtin dataset. - - Parameters - ---------- - name : str - Name of the dataset to load. - cache : bool (default: True) - If True, try to load from local cache first, download as needed. - data_home : str (default: "~/bento-data") - Path to directory where datasets are stored. - **kws - Keyword arguments passed to :func:`bento.io.read_h5ad`. - """ - datainfo = get_dataset_info() - - # Check if dataset name exists - if name not in datainfo.index: - raise KeyError( - f"No builtin dataset named '{name}'. Use :func:`bento.ds.get_dataset_info` to list info about available datasets." 
- ) - - # Sanitize user path - data_home = os.path.expanduser(data_home) - - # Make data folder if it doesn't exist - if not os.path.exists(data_home): - os.makedirs(data_home) - - # Try to load from local cache first, download as needed - url = datainfo.loc[name, "url"] - cache_path = os.path.join(data_home, os.path.basename(url)) - if cache: - if not os.path.exists(cache_path): - _download(url, cache_path) - else: - _download(url, cache_path) - - adata = read_h5ad(cache_path, **kws) - - return adata - - -# Taken from https://github.com/theislab/scanpy/blob/master/scanpy/readwrite.py -def _download(url, path): - try: - import ipywidgets - from tqdm.auto import tqdm - except ImportError: - from tqdm import tqdm - - from urllib.request import urlopen, Request - from pathlib import Path - - blocksize = 1024 * 8 - blocknum = 0 - - path = Path(path) - - try: - with urlopen(Request(url, headers={"User-agent": "bento"})) as resp: - total = resp.info().get("content-length", None) - with tqdm( - unit="B", - unit_scale=True, - miniters=1, - unit_divisor=1024, - total=total if total is None else int(total), - ) as t, path.open("wb") as f: - block = resp.read(blocksize) - while block: - f.write(block) - blocknum += 1 - t.update(len(block)) - block = resp.read(blocksize) - - except (KeyboardInterrupt, Exception): - # Make sure file doesn’t exist half-downloaded - if path.is_file(): - path.unlink() - raise - - -def sample_data(): - global pkg_resources - if pkg_resources is None: - import pkg_resources - - stream = pkg_resources.resource_stream(__name__, "merfish_sample.h5ad") - return read_h5ad(stream) diff --git a/bento/geometry/__init__.py b/bento/geometry/__init__.py index c0bc43a..e85c9ca 100644 --- a/bento/geometry/__init__.py +++ b/bento/geometry/__init__.py @@ -1,9 +1,10 @@ from ._geometry import ( - count_points, - crop, + sjoin_points, + sjoin_shapes, get_points, - get_points_metadata, get_shape, - rename_shapes, - sindex_points, -) + get_points_metadata, + get_shape_metadata, + set_points_metadata, + set_shape_metadata, +) \ No newline at end of file diff --git a/bento/geometry/_geometry.py b/bento/geometry/_geometry.py index 783d000..faad45e 100644 --- a/bento/geometry/_geometry.py +++ b/bento/geometry/_geometry.py @@ -1,310 +1,400 @@ -import re -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union import geopandas as gpd +import numpy as np import pandas as pd -from anndata import AnnData -from scipy.sparse import coo_matrix -from shapely import wkt +import dask.dataframe as dd +from spatialdata._core.spatialdata import SpatialData +from spatialdata.models import PointsModel, ShapesModel from shapely.geometry import Polygon -from tqdm.auto import tqdm -from .._utils import sync - -def count_points( - data: AnnData, shape_names: List[str], copy: bool = False -) -> Optional[AnnData]: - """Count points in shapes and add as layers to `data`. Expects points to already be indexed to shapes. +def sjoin_points( + sdata: SpatialData, + points_key: str, + shape_keys: List[str], +): + """Index points to shapes and add as columns to `data.points[points_key]`. Only supports 2D points for now. 
Parameters ---------- - data : AnnData - Spatial formatted AnnData object - shape_names : str, list + sdata : SpatialData + Spatial formatted SpatialData object + points_key : str + Key for points DataFrame in `sdata.points` + shape_keys : str, list List of shape names to index points to - copy : bool, optional - Whether to return a copy the AnnData object. Default False. Returns ------- - AnnData - .layers: Updated layers with count of points in each shape + sdata : SpatialData + .points[points_key]: Updated points DataFrame with string index for each shape """ - adata = data.copy() if copy else data - if isinstance(shape_names, str): - shape_names = [shape_names] + if isinstance(shape_keys, str): + shape_keys = [shape_keys] - points = get_points(data, asgeo=True) + # Grab all shape GeoDataFrames to index points to + query_shapes = {} + for shape in shape_keys: + query_shapes[shape] = gpd.GeoDataFrame(geometry=sdata.shapes[shape].geometry) - if shape_names[0].endswith("_shape"): - shape_prefixes = [ - "_".join(shp_name.split("_shape")[:-1]) for shp_name in shape_names - ] - else: - shape_prefixes = shape_names + # Grab points as GeoDataFrame + points = get_points(sdata, points_key, astype="geopandas", sync=False) - shape_counts = points.groupby(["cell", "gene"], observed=True)[shape_prefixes].sum() + # Index points to shapes + for shape_key, shape in query_shapes.items(): + shape = query_shapes[shape_key] + shape.index.name = None + shape.index = shape.index.astype(str) - for shape in shape_counts.columns: - pos_counts = shape_counts[shape] - pos_counts = pos_counts[pos_counts > 0] - values = pos_counts + points = points.sjoin(shape, how="left", predicate="intersects") + points = points[~points.index.duplicated(keep="last")] + points.loc[points["index_right"].isna(), "index_right"] = "" + points.rename(columns={"index_right": shape_key}, inplace=True) - row = adata.obs_names.get_indexer(pos_counts.index.get_level_values("cell")) - col = adata.var_names.get_indexer(pos_counts.index.get_level_values("gene")) - adata.layers[f"{shape}"] = coo_matrix((values, (row, col))) + set_points_metadata(sdata, points_key, points[shape_key]) + + return sdata - return adata if copy else None - -def sindex_points( - data: AnnData, points_key: str, shape_names: List[str], copy: bool = False -) -> Optional[AnnData]: - """Index points to shapes and add as columns to `data.uns[points_key]`. +def sjoin_shapes( + sdata: SpatialData, + instance_key: str, + shape_keys: List[str] +): + """Adds polygon indexes to sdata.shapes[instance_key][shape_key] for point feature analysis Parameters ---------- - data : AnnData - Spatial formatted AnnData object - points_key : str - Key for points DataFrame in `data.uns` - shape_names : str, list - List of shape names to index points to - copy : bool, optional - Whether to return a copy the AnnData object. Default False. + sdata : SpatialData + Spatially formatted SpatialData + instance_key : str + Key for the shape that will be used as the instance for all indexing. Usually the cell shape. + shape_keys : str or list of str + Names of the shapes to add. 
+ Returns ------- - AnnData - .uns[points_key]: Updated points DataFrame with boolean column for each shape + sdata : SpatialData + .shapes[cell_shape_key][shape_key] """ - adata = data.copy() if copy else data - - if isinstance(shape_names, str): - shape_names = [shape_names] - - points = get_points(data, points_key, asgeo=True).sort_values("cell") - points = points.drop( - columns=shape_names, errors="ignore" - ) # Drop columns to overwrite - points_grouped = points.groupby("cell", observed=True) - cells = list(points_grouped.groups.keys()) - point_sindex = [] - - # Iterate over cells and index points to shapes - for cell in tqdm(cells, leave=False): - pt_group = points_grouped.get_group(cell) - - # Get shapes to index in current cell - cur_shapes = gpd.GeoDataFrame(geometry=data.obs.loc[cell, shape_names].T) - cur_sindex = ( - pt_group.reset_index() - .sjoin(cur_shapes, how="left", op="intersects") - .drop_duplicates(subset="index", keep="first") - .sort_index() - .reset_index()["index_right"] - .astype(str) - ) - point_sindex.append(cur_sindex) - # TODO: concat is hella slow - point_sindex = ( - pd.concat(point_sindex, ignore_index=True).str.get_dummies() == 1 - ).fillna(False) - point_sindex.columns = [col.replace("_shape", "") for col in point_sindex.columns] + # Cast to list if not already + if isinstance(shape_keys, str): + shape_keys = [shape_keys] - # Add new columns to points - points[point_sindex.columns] = point_sindex.values - adata.uns[points_key] = points + # Check if shapes are already indexed to instance_key shape + shape_keys = ( + set(shape_keys) - set(sdata.shapes[instance_key].columns) - set(instance_key) + ) - return adata if copy else None + if len(shape_keys) == 0: + return sdata + parent_shape = sdata.shapes[instance_key] -def crop( - data: AnnData, - xlims: Tuple[int], - ylims: Tuple[int], - copy: bool = True, -) -> Optional[AnnData]: - """Returns a view of data within specified coordinates. + # sjoin shapes to instance_key shape + for shape_key in shape_keys: + child_shape = gpd.GeoDataFrame(geometry=sdata.shapes[shape_key]["geometry"]) + parent_shape = parent_shape.sjoin(child_shape, how="left", predicate="contains") + parent_shape = parent_shape[~parent_shape.index.duplicated(keep="last")] + parent_shape.loc[parent_shape["index_right"].isna(), "index_right"] = "" + parent_shape = parent_shape.astype({"index_right": "category"}) - Parameters - ---------- - data : AnnData - Spatial formatted AnnData object - xlims : list, optional - Upper and lower x limits, by default None - ylims : list, optional - Upper and lower y limits, by default None - copy : bool, optional - Whether to return a copy the AnnData object. Default True. 
+ # save shape index as column in instance_key shape + parent_shape.rename(columns={"index_right": shape_key}, inplace=True) + set_shape_metadata(sdata, shape_key=instance_key, metadata=parent_shape[shape_key]) - Returns - ------- - AnnData - AnnData object with data cropped to specified coordinates - """ - adata = data.copy() if copy else data + # Add instance_key shape index to shape + parent_shape.index.name = "parent_index" + instance_index = parent_shape.reset_index().set_index(shape_key)["parent_index"] + instance_index.name = instance_key + instance_index.index.name = None + instance_index = instance_index[instance_index.index != ""] - if len(xlims) < 1 and len(xlims) > 2: - return ValueError("Invalid xlims") + set_shape_metadata(sdata, shape_key=shape_key, metadata=instance_index) - if len(ylims) < 1 and len(ylims) > 2: - return ValueError("Invalid ylims") + return sdata - xmin, xmax = xlims[0], xlims[1] - ymin, ymax = ylims[0], ylims[1] - box = Polygon([[xmin, ymin], [xmin, ymax], [xmax, ymax], [xmax, ymin]]) - in_crop = get_shape(data, "cell_shape").within(box) +def get_points( + sdata: SpatialData, + points_key: str = "transcripts", + astype: str = "pandas", + sync: bool = True, +) -> Union[pd.DataFrame, dd.DataFrame, gpd.GeoDataFrame]: + """Get points DataFrame synced to AnnData object. - adata = data[in_crop, :] - sync(adata, copy=False) + Parameters + ---------- + data : SpatialData + Spatial formatted SpatialData object + key : str, optional + Key for `data.points` to use, by default "transcripts" + astype : str, optional + Whether to return a 'pandas' DataFrame, 'dask' DataFrame, or 'geopandas' GeoDataFrame, by default "pandas" - return adata if copy else None + Returns + ------- + DataFrame or GeoDataFrame + Returns `data.points[key]` as a `[Geo]DataFrame` or 'Dask DataFrame' + """ + if points_key not in sdata.points.keys(): + raise ValueError(f"Points key {points_key} not found in sdata.points") + if astype not in ["pandas", "dask", "geopandas"]: + raise ValueError( + f"astype must be one of ['dask', 'pandas', 'geopandas'], not {astype}" + ) -def get_shape(adata: AnnData, shape_name: str) -> gpd.GeoSeries: - """Get a GeoSeries of Polygon objects from an AnnData object. + points = sdata.points[points_key] + + # Sync points to instance_key + if sync: + _check_points_sync(sdata, points_key) + instance_key = points.attrs["spatialdata_attrs"]["instance_key"] + + point_index = sdata.points[points_key][instance_key] + valid_points = point_index != "" + points = points[valid_points] + + if astype == "pandas": + return points.compute() + elif astype == "dask": + return points + elif astype == "geopandas": + points = points.compute() + return gpd.GeoDataFrame( + points, geometry=gpd.points_from_xy(points.x, points.y), copy=True + ) + +def get_shape(sdata: SpatialData, shape_key: str, sync: bool = True) -> gpd.GeoSeries: + """Get a GeoSeries of Polygon objects from an SpatialData object. Parameters ---------- - adata : AnnData - Spatial formatted AnnData object - shape_name : str - Name of shape column in adata.obs + sdata : SpatialData + Spatial formatted SpatialData object + shape_key : str + Name of shape column in sdata.shapes + sync : bool + Whether to retrieve shapes synced to cell shape. Default True. 
Returns ------- GeoSeries GeoSeries of Polygon objects """ - if shape_name not in adata.obs.columns: - raise ValueError(f"Shape {shape_name} not found in adata.obs.") - - if adata.obs[shape_name].astype(str).str.startswith("POLYGON").any(): - return gpd.GeoSeries( - adata.obs[shape_name] - .astype(str) - .apply(lambda val: wkt.loads(val) if val != "None" else None) - ) + instance_key = sdata.table.uns["spatialdata_attrs"]["instance_key"] - else: - return gpd.GeoSeries(adata.obs[shape_name]) + # Make sure shape exists in sdata.shapes + if shape_key not in sdata.shapes.keys(): + raise ValueError(f"Shape {shape_key} not found in sdata.shapes") + if sync and shape_key != instance_key: + _check_shape_sync(sdata, shape_key, instance_key) + shape_index = sdata.shapes[shape_key][instance_key] + valid_shapes = shape_index != "" + return sdata.shapes[shape_key][valid_shapes].geometry -def rename_shapes( - data: AnnData, - mapping: Dict[str, str], - points_key: Optional[Union[List[str], None]] = None, - points_encoding: Union[List[Literal["label", "onehot"]], None] = None, - copy: bool = False, -) -> Optional[AnnData]: - """Rename shape columns in adata.obs and points columns in adata.uns. + return sdata.shapes[shape_key].geometry + +def get_points_metadata( + sdata: SpatialData, + metadata_keys: Union[str, List[str]], + points_key: str = "transcripts", + astype="pandas", +): + """Get points metadata. Parameters ---------- - adata : AnnData - Spatial formatted AnnData object - mapping : Dict[str, str] - Mapping of old shape names to new shape names - points_key : list of str, optional - List of keys for points DataFrame in `adata.uns`, by default None - points_encoding : {"label", "onehot"}, optional - Encoding type for each specified points - copy : bool, optional - Whether to return a copy of the AnnData object. Default False. 
+ sdata : SpatialData + Spatial formatted SpatialData object + metadata_keys : str or list of str + Key(s) for `sdata.points[points_key][key]` to use + points_key : str, optional + Key for `sdata.points` to use, by default "transcripts" + astype : str, optional + Whether to return a 'pandas' Series or 'dask' DataFrame, by default "pandas" Returns ------- - AnnData - .obs: Updated shape column names - .uns[points_key]: Updated points shape(s) columns according to encoding type + pd.DataFrame or dd.DataFrame + Returns `sdata.points[points_key][metadata_keys]` as a `pd.DataFrame` or `dd.DataFrame` """ - adata = data.copy() if copy else data - adata.obs.rename(columns=mapping, inplace=True) - - # Map point columns - if points_key: - # Get mapping for points column names - prefix_map = { - _get_shape_prefix(shape_name): _get_shape_prefix(new_name) - for shape_name, new_name in mapping.items() - } - # Get mapping for label encoding - label_map = { - re.sub(r"\D", "", shape_name): re.sub(r"\D", "", new_name) - for shape_name, new_name in prefix_map.items() - } - - for p_key, p_encoding in zip(points_key, points_encoding): - if p_encoding == "label": - # Point column name with label encoding - col = re.sub(r"\d", "", list(prefix_map.keys())[0]) - adata.uns[p_key][col] = adata.uns[p_key][col].astype(str).map(label_map) - - elif p_encoding == "onehot": - # Remap column names - adata.uns[p_key].rename(columns=prefix_map, inplace=True) - - return adata if copy else None - - -def _get_shape_prefix(shape_name): - """Get prefix of shape name.""" - return "_".join(shape_name.split("_")[:-1]) - - -def get_points( - data: AnnData, key: str = "points", asgeo: bool = False -) -> Union[pd.DataFrame, gpd.GeoDataFrame]: - """Get points DataFrame synced to AnnData object. + if points_key not in sdata.points.keys(): + raise ValueError(f"Points key {points_key} not found in sdata.points") + if astype not in ["pandas", "dask"]: + raise ValueError( + f"astype must be one of ['dask', 'pandas'], not {astype}" + ) + if isinstance(metadata_keys, str): + metadata_keys = [metadata_keys] + for key in metadata_keys: + if key not in sdata.points[points_key].columns: + raise ValueError(f"Metadata key {key} not found in sdata.points[{points_key}]") + + metadata = sdata.points[points_key][metadata_keys] + + if astype == "pandas": + return metadata.compute() + elif astype == "dask": + return metadata + +def get_shape_metadata( + sdata: SpatialData, + metadata_keys: Union[str, List[str]], + shape_key: str = "transcripts", +): + """Get shape metadata. 
Parameters ---------- - data : AnnData - Spatial formatted AnnData object - key : str, optional - Key for `data.uns` to use, by default "points" - asgeo : bool, optional - Cast as GeoDataFrame using columns x and y for geometry, by default False + sdata : SpatialData + Spatial formatted SpatialData object + metadata_keys : str or list of str + Key(s) for `sdata.shapes[shape_key][key]` to use + shape_key : str + Key for `sdata.shapes` to use, by default "transcripts" Returns ------- - DataFrame or GeoDataFrame - Returns `data.uns[key]` as a `[Geo]DataFrame` + pd.Dataframe + Returns `sdata.shapes[shape_key][metadata_keys]` as a `pd.DataFrame` """ - points = sync(data, copy=True).uns[key] + if shape_key not in sdata.shapes.keys(): + raise ValueError(f"Shape key {shape_key} not found in sdata.shapes") + if isinstance(metadata_keys, str): + metadata_keys = [metadata_keys] + for key in metadata_keys: + if key not in sdata.shapes[shape_key].columns: + raise ValueError(f"Metadata key {key} not found in sdata.shapes[{shape_key}]") + + return sdata.shapes[shape_key][metadata_keys] + +def set_points_metadata( + sdata: SpatialData, + points_key: str, + metadata: Union[List, pd.Series, pd.DataFrame], + column_names: Optional[Union[str, List[str]]] = None, +): + """Write metadata in SpatialData points element as column(s). Aligns metadata index to shape index. - # Cast to GeoDataFrame - if asgeo: - points = gpd.GeoDataFrame( - points, geometry=gpd.points_from_xy(points.x, points.y) - ) + Parameters + ---------- + sdata : SpatialData + Spatial formatted SpatialData object + points_key : str + Name of element in sdata.points + metadata : pd.Series, pd.DataFrame + Metadata to set for points. Index must be a (sub)set of points index. + column_names : str or list of str, optional + Name of column(s) to set. If None, use metadata column name(s), by default None + """ + if points_key not in sdata.points.keys(): + raise ValueError(f"{points_key} not found in sdata.points") + + if isinstance(metadata, list): + metadata = pd.Series(metadata, index=sdata.points[points_key].index) + + if isinstance(metadata, pd.Series): + metadata = pd.DataFrame(metadata) + + if column_names is not None: + if isinstance(column_names, str): + column_names = [column_names] + for i in range(len(column_names)): + metadata = metadata.rename(columns={metadata.columns[i]: column_names[i]}) + + sdata.points[points_key] = sdata.points[points_key].reset_index(drop=True) + for name, series in metadata.iteritems(): + series = series.fillna("") + metadata_series = dd.from_pandas(series, npartitions=sdata.points[points_key].npartitions).reset_index(drop=True) + sdata.points[points_key][name] = metadata_series + +def set_shape_metadata( + sdata: SpatialData, + shape_key: str, + metadata: Union[List, pd.Series, pd.DataFrame], + column_names: Optional[Union[str, List[str]]] = None, +): + """Write metadata in SpatialData shapes element as column(s). Aligns metadata index to shape index. + + Parameters + ---------- + sdata : SpatialData + Spatial formatted SpatialData object + shape_key : str + Name of element in sdata.shapes + metadata : pd.Series, pd.DataFrame + Metadata to set for shape. Index must be a (sub)set of shape index. + column_names : str or list of str, optional + Name of column(s) to set. 
If None, use metadata column name(s), by default None + """ + if shape_key not in sdata.shapes.keys(): + raise ValueError(f"Shape {shape_key} not found in sdata.shapes") + + if isinstance(metadata, list): + metadata = pd.Series(metadata, index=sdata.shapes[shape_key].index) + + if isinstance(metadata, pd.Series): + metadata = pd.DataFrame(metadata) + + if column_names is not None: + if isinstance(column_names, str): + column_names = [column_names] + for i in range(len(column_names)): + metadata = metadata.rename(columns={metadata.columns[i]: column_names[i]}) + + sdata.shapes[shape_key].loc[:, metadata.columns] = metadata.reindex( + sdata.shapes[shape_key].index + ).fillna("") + +def _check_points_sync(sdata, points_key): + """ + Check if points are synced to instance_key shape in a SpatialData object. - return points + Parameters + ---------- + sdata : SpatialData + The SpatialData object to check. + points_key : str + The name of the points to check. + Raises + ------ + ValueError + If the points are not synced to instance_key shape. + """ + points = sdata.points[points_key] + if points.attrs["spatialdata_attrs"]["instance_key"] not in points.columns: + raise ValueError( + f"Points {points_key} not synced to instance_key shape element. Run bento.io.format_sdata() to setup SpatialData object for bento-tools." + ) -def get_points_metadata( - data: AnnData, - metadata_key: str, - points_key: str = "points", -): - """Get points metadata synced to AnnData object. +def _check_shape_sync(sdata, shape_key, instance_key): + """ + Check if a shape is synced to instance_key shape in a SpatialData object. Parameters ---------- - data : AnnData - Spatial formatted AnnData object - metadata_key : str - Key for `data.uns[key]` to use - key : str, optional - Key for `data.uns` to use, by default "points" - - Returns - ------- - Series - Returns `data.uns[key][metadata_key]` as a `Series` + sdata : SpatialData + The SpatialData object to check. + shape_key : str + The name of the shape to check. + instance_key : str + The instance key of the shape to check. + + Raises + ------ + ValueError + If the shape is not synced to instance_key shape. """ - metadata = sync(data, copy=True).uns[metadata_key] - return metadata + if ( + shape_key != instance_key + and shape_key not in sdata.shapes[instance_key].columns + ): + raise ValueError( + f"Shape {shape_key} not synced to instance_key shape element. Run bento.io.format_sdata() to setup SpatialData object for bento-tools." 
+ ) diff --git a/bento/io/__init__.py b/bento/io/__init__.py old mode 100755 new mode 100644 index 411bba4..ab1950d --- a/bento/io/__init__.py +++ b/bento/io/__init__.py @@ -1 +1 @@ -from ._io import read_h5ad, write_h5ad, concatenate, prepare \ No newline at end of file +from ._io import format_sdata \ No newline at end of file diff --git a/bento/io/_io.py b/bento/io/_io.py old mode 100755 new mode 100644 index bad724a..ae9e141 --- a/bento/io/_io.py +++ b/bento/io/_io.py @@ -1,453 +1,87 @@ import warnings +from typing import List warnings.filterwarnings("ignore") -import geopandas as gpd -import numpy as np -import pandas as pd -from shapely import geometry, wkt -from shapely.geometry import Polygon -from tqdm.auto import tqdm -import anndata -import rasterio -import rasterio.features -import emoji +from spatialdata._core.spatialdata import SpatialData +from spatialdata.models import ShapesModel, TableModel -from .._utils import sc_format +from ..geometry import sjoin_points, sjoin_shapes -def read_h5ad(filename, backed=None): - """Load bento processed AnnData object from h5ad. +def format_sdata( + sdata: SpatialData, + points_key: str = "transcripts", + feature_key: str = "feature_name", + instance_key: str = "cell_boundaries", + shape_keys: List[str] = ["cell_boundaries", "nucleus_boundaries"], +) -> SpatialData: + """Converts shape indices to strings and indexes points to shapes and add as columns to `data.points[point_key]`. Parameters ---------- - filename : str - File name to load data file. - backed : 'r', 'r+', True, False, None - If 'r', load AnnData in backed mode instead of fully loading it into memory (memory mode). - If you want to modify backed attributes of the AnnData object, you need to choose 'r+'. - By default None. - Returns - ------- - AnnData - AnnData data object. - """ - - adata = anndata.read_h5ad(filename, backed=backed) - # Load obs columns that are shapely geometries - adata.obs = adata.obs.apply( - lambda col: gpd.GeoSeries( - col.astype(str).apply(lambda val: wkt.loads(val) if val != "None" else None) - ) - if col.astype(str).str.startswith("POLYGON").any() - else pd.Series(col) - ) - - adata.obs.index.name = "cell" - adata.var.index.name = "gene" - - return adata - - -def write_h5ad(data, filename): - """Write AnnData to h5ad. - - Parameters - ---------- - adata : AnnData - bento loaded AnnData - filename : str - File name to write data file. - """ - # Convert geometry from GeoSeries to list for h5ad serialization compatibility - adata = data.copy() - - sc_format(adata) - - adata.uns["points"] = adata.uns["points"].drop("geometry", axis=1, errors="ignore") - - # Write to h5ad - adata.write(filename) - - -def prepare( - molecules, - cell_seg, - x="x", - y="y", - gene="gene", - other_seg=dict(), -): - """Prepare AnnData with molecule-level spatial data. - - Parameters - ---------- - molecules : DataFrame - Molecule coordinates and annotations. - cell_seg : np.array - Cell segmentation masks represented as 2D numpy array where 1st and 2nd - dimensions correspond to x and y respectively. Connected regions must - have same value to be considered a valid shape. Data type must be one - of rasterio.int16, rasterio.int32, rasterio.uint8, rasterio.uint16, or - rasterio.float32. See rasterio.features.shapes for more details. - x : str - Column name for x coordinates, by default 'x'. - y : str - Column name for x coordinates, by default 'y'. - gene : str - Column name for gene name, by default 'gene'. 
- other_seg - Additional keyword arguments are interpreted as additional segmentation - masks. The user specified parameter name is used to store these masks as - {name}_shape in adata.obs. - Returns - ------- - AnnData object - """ - for var in [x, y, gene]: - if var not in molecules.columns: - return - - pbar = tqdm(total=6) - pbar.set_description(emoji.emojize(":test_tube: Loading inputs")) - points = molecules[[x, y, gene]] - points.columns = ["x", "y", "gene"] - points = gpd.GeoDataFrame( - points, geometry=gpd.points_from_xy(x=points.x, y=points.y) - ) - points["gene"] = points["gene"].astype("category") # Save memory - pbar.update() - - # Load each set of masks as GeoDataFrame - # shapes = Series where index = segs.keys() and values = GeoDataFrames - segs_dict = {"cell": cell_seg, **other_seg} - # Already formatted, select geometry column already - if isinstance(cell_seg, gpd.GeoDataFrame): - shapes_dict = { - shape_name: shape_seg[["geometry"]] - for shape_name, shape_seg in segs_dict.items() - } - # Load shapes from numpy array image - elif isinstance(cell_seg, np.ndarray): - shapes_dict = { - shape_name: _load_shapes_np(shape_seg) - for shape_name, shape_seg in segs_dict.items() - } - else: - print("Segmentation mask format not recognized.") - pbar.close() - return - pbar.update() - - # Index shapes to cell - pbar.set_description(emoji.emojize(":open_book: Indexing")) - obs_shapes = _index_shapes(shapes_dict, "cell") - pbar.update() - - # Index points for all shapes - # TODO: refactor to use geometry.sindex_points - point_index = dict() - for col in obs_shapes.columns: - shp_gdf = gpd.GeoDataFrame(geometry=obs_shapes[col]) - shp_name = "_".join(str(col).split("_")[:-1]) - point_index[shp_name] = _index_points(points, shp_gdf) - point_index = pd.DataFrame.from_dict(point_index) - pbar.update() - - # Main long dataframe for reformatting - pbar.set_description(emoji.emojize(":computer_disk: Formatting")) - uns_points = pd.concat( - [ - points[["x", "y", "gene"]].reset_index(drop=True), - point_index.reset_index(drop=True), - ], - axis=1, - ) - - # Remove extracellular points - uns_points = uns_points.loc[uns_points["cell"] != "-1"] - if len(uns_points) == 0: - print("No molecules found within cells. 
Data not processed.") - pbar.close() - return - uns_points[["cell", "gene"]] = uns_points[["cell", "gene"]].astype("category") - - # Aggregate points to counts - expression = ( - uns_points[["cell", "gene"]] - .groupby(["cell", "gene"]) - .apply(lambda x: x.shape[0]) - .reset_index() - ) - - # Create cell x gene matrix - cellxgene = expression.pivot_table( - index="cell", columns="gene", aggfunc="sum" - ).fillna(0) - cellxgene.columns = cellxgene.columns.get_level_values("gene") - pbar.update() - - # Create scanpy anndata object - pbar.set_description(emoji.emojize(":package: Create AnnData")) - adata = anndata.AnnData(X=cellxgene) - obs_shapes = obs_shapes.reindex(index=adata.obs.index) - adata.obs = pd.concat([adata.obs, obs_shapes], axis=1) - adata.obs.index = adata.obs.index.astype(str) - - # Save cell, gene, batch, and other shapes as categorical type to save memory - uns_points["cell"] = uns_points["cell"].astype("category") - uns_points["gene"] = uns_points["gene"].astype("category") - for shape_name in list(other_seg.keys()): - uns_points[shape_name] = uns_points[shape_name].astype("category") - - adata.uns = {"points": uns_points} - - pbar.set_description(emoji.emojize(":bento_box: Finished!")) - pbar.update() - pbar.close() - return adata - - -def _load_shapes_np(seg_img): - """Extract shapes from segmentation image. - - Parameters - ---------- - seg_img : np.array - Segmentation masks represented as 2D numpy array where 1st and 2nd dimensions correspond to x and y respectively. - - Returns - ------- - GeoDataFrame - Single column GeoDataFrame where each row is a single Polygon. - """ - seg_img = seg_img.astype("uint16") - contours = rasterio.features.shapes(seg_img) # rasterio to generate contours - # Convert to shapely Polygons - polygons = [Polygon(p["coordinates"][0]) for p, v in contours] - shapes = gpd.GeoDataFrame(geometry=gpd.GeoSeries(polygons)) # Cast to GeoDataFrame - shapes.drop( - shapes.area.sort_values().tail(1).index, inplace=True - ) # Remove extraneous shape - shapes = shapes[shapes.geom_type != "MultiPolygon"] - - shapes.index = shapes.index.astype(str) - - # Cleanup polygons - # mask.geometry = mask.geometry.buffer(2).buffer(-2) - # mask.geometry = mask.geometry.apply(unary_union) - - return shapes - - -def _load_shapes_json(seg_json): - """Extract shapes from python object loaded with json. - - Parameters - ---------- - seg_json : list - list loaded by json.load(file) - - Returns - ------- - GeoDataFrame - Each row represents a single shape, - """ - polys = [] - for i in range(len(seg_json)): - polys.append(Polygon(seg_json[i]["coordinates"][0])) - - shapes = gpd.GeoDataFrame(geometry=gpd.GeoSeries(polys)) - shapes = shapes[shapes.geom_type != "MultiPolygon"] - - shapes.index = shapes.index.astype(str) - - # Cleanup polygons - # mask.geometry = mask.geometry.buffer(2).buffer(-2) - # mask.geometry = mask.geometry.apply(unary_union) - - return shapes - - -def _index_shapes(shapes, cell_key): - """Spatially index other masks to cell mask. - - Parameters - ---------- - shapes : dict - Dictionary of GeoDataFrames. - - Returns - ------- - indexed_shapes : GeoDataFrame - Each column is - """ - cell_shapes = shapes[cell_key] + sdata : SpatialData + Spatial formatted SpatialData object + points_key : str + Key for points DataFrame in `sdata.points` + feature_key : str + Key for the feature name in the points DataFrame + instance_key : str + Key for the shape that will be used as the instance for all indexing. Usually the cell shape. 
+ shape_keys : str, list + List of shape names to index points to - indexed_shapes = cell_shapes.copy() - for shape_name, shape in shapes.items(): - - # Don't index cell to itself - if shape_name == "cell": - continue - - # For each cell, get all overlapping shapes - geometry = gpd.sjoin( - shape, cell_shapes, how="left", op="intersects", rsuffix="cell" - ).dropna() - - # Calculate fraction overlap for each pair SLOW - geometry["fraction_overlap"] = ( - geometry.intersection( - cell_shapes.loc[geometry["index_cell"]], align=False - ).area - / geometry.area - ) - - # Keep shape that overlaps with cell_shapes the most - geometry = ( - geometry.sort_values("fraction_overlap", ascending=False) - .drop_duplicates("index_cell") - .set_index("index_cell") - .reindex(cell_shapes.index)["geometry"] - ) - geometry.name = f"{shape_name}_shape" - - # Add indexed shapes as new column in GeoDataFrame - indexed_shapes[f"{shape_name}_shape"] = geometry - - # Cells are rows, intersecting shape sets are columns - indexed_shapes = indexed_shapes.rename(columns={"geometry": "cell_shape"}) - indexed_shapes.index = indexed_shapes.index.astype(str) - return indexed_shapes - - -def _index_points(points, shapes): - """Index points to each set of shapes item and save. Assumes non-overlapping shapes. - - Parameters - ---------- - points : GeoDataFrame - Point coordinates. - shapes : GeoDataFrame - Single column of Polygons. Returns ------- - Series - Return list of mask indices corresponding to each point. + SpatialData + .shapes[shape_key]: Updated shapes GeoDataFrame with string index + .points[points_key]: Updated points DataFrame with string index for each shape """ - index = gpd.sjoin(points.reset_index(), shapes, how="left", op="intersects") - - # remove multiple cells assigned to same point - index = ( - index.drop_duplicates(subset="index", keep="first") - .sort_index() - .reset_index()["index_right"] - .fillna(-1) - .astype(str) - ) - return pd.Series(index) - - -def concatenate(adatas): - # Read point registry to identify point sets to concatenate - # TODO - - uns_points = [] - for i, adata in enumerate(adatas): - points = adata.uns["points"].copy() - - if "batch" not in points.columns: - points["batch"] = i - - points["cell"] = points["cell"].astype(str) + "-" + str(i) - - uns_points.append(points) - - new_adata = adatas[0].concatenate(adatas[1:]) - - uns_points = pd.concat(uns_points) - uns_points["cell"] = uns_points["cell"].astype("category") - uns_points["gene"] = uns_points["gene"].astype("category") - uns_points["batch"] = uns_points["batch"].astype("category") - - new_adata.uns["points"] = uns_points - - return new_adata - - -def _to_spliced_expression(expression): - cell_nucleus = expression.pivot(index=["cell", "nucleus"], columns="gene") - unspliced = [] - spliced = [] - idx = pd.IndexSlice - spliced_index = "-1" - - def to_splice_layers(cell_df): - unspliced_index = ( - cell_df.index.get_level_values("nucleus") - .drop(spliced_index, errors="ignore") - .tolist() - ) - - unspliced.append( - cell_df.loc[idx[:, unspliced_index], :] - .sum() - .to_frame() - .T.reset_index(drop=True) + # Renames geometry column of shape element to match shape name + # Changes indices to strings + for shape_key, shape_gdf in sdata.shapes.items(): + if shape_key == instance_key: + shape_gdf[shape_key] = shape_gdf["geometry"] + shape_gdf.index = shape_gdf.index.astype(str, copy=False) + + # sindex points and sjoin shapes if they have not been indexed or joined + point_sjoin = [] + shape_sjoin = [] + + for shape_key in 
shape_keys: + # Compile list of shapes that need to be indexed to points + if shape_key not in sdata.points[points_key].columns: + point_sjoin.append(shape_key) + # Compile list of shapes that need to be joined to instance shape + if ( + shape_key != instance_key + and shape_key not in sdata.shapes[instance_key].columns + ): + shape_sjoin.append(shape_key) + + if len(point_sjoin) > 0: + sdata = sjoin_points( + sdata=sdata, points_key=points_key, shape_keys=point_sjoin ) - - # Extract spliced counts for this gene if there are any. - if spliced_index in cell_df.index.get_level_values("nucleus"): - spliced.append( - cell_df.xs(spliced_index, level="nucleus").reset_index(drop=True) - ) - else: - # Initialize empty zeros - spliced.append( - pd.DataFrame(np.zeros((1, cell_df.shape[1])), columns=cell_df.columns) - ) - - cell_nucleus.groupby("cell").apply(to_splice_layers) - - cells = cell_nucleus.index.get_level_values("cell").unique() - - unspliced = pd.concat(unspliced) - unspliced.index = cells - spliced = pd.concat(spliced) - spliced.index = cells - - spliced = spliced.fillna(0) - unspliced = unspliced.fillna(0) - - return spliced, unspliced - - -def to_scanpy(data): - # Extract points - expression = pd.DataFrame( - data.X, index=pd.MultiIndex.from_frame(data.obs[["cell", "gene"]]) - ) - - # Aggregate points to counts - expression = ( - data.obs[["cell", "gene"]] - .groupby(["cell", "gene"]) - .apply(lambda x: x.shape[0]) - .to_frame() - ) - expression = expression.reset_index() - - # Remove extracellular points - expression = expression.loc[expression["cell"] != "-1"] - - # Format as dense cell x gene counts matrix - expression = expression.pivot(index="cell", columns="gene").fillna(0) - expression.columns = expression.columns.droplevel(0) - expression.columns = expression.columns.str.upper() - - # Create anndata object - sc_data = anndata.AnnData(expression) - - return sc_data + if len(shape_sjoin) > 0: + sdata = sjoin_shapes( + sdata=sdata, instance_key=instance_key, shape_keys=shape_sjoin + ) + + # Recompute count table + table = TableModel.parse(sdata.aggregate( + values=points_key, + instance_key=instance_key, + by=instance_key, + value_key=feature_key, + aggfunc="count", + ).table) + + del sdata.table + sdata.table = table + # Set instance key to cell_shape_key for all points and table + sdata.points[points_key].attrs["spatialdata_attrs"]["instance_key"] = instance_key + sdata.points[points_key].attrs["spatialdata_attrs"]["feature_key"] = feature_key + + return sdata diff --git a/bento/plotting/__init__.py b/bento/plotting/__init__.py old mode 100755 new mode 100644 index 235a4a8..4c5ef3b --- a/bento/plotting/__init__.py +++ b/bento/plotting/__init__.py @@ -1,5 +1,5 @@ from ._multidimensional import flux_summary, obs_stats -from ._lp import lp_diff, lp_dist, lp_gene_dist, lp_genes +from ._lp import lp_diff_discrete, lp_dist, lp_gene_dist, lp_genes from ._plotting import points, density, shapes, flux, fluxmap, fe -from ._signatures import colocation, factor, signatures, signatures_error +from ._signatures import colocation, factor \ No newline at end of file diff --git a/bento/plotting/_layers.py b/bento/plotting/_layers.py index e42cf45..c07fb5e 100644 --- a/bento/plotting/_layers.py +++ b/bento/plotting/_layers.py @@ -56,27 +56,29 @@ def _kde(points, ax, hue=None, **kwargs): sns.kdeplot(data=points, x="x", y="y", hue=hue, ax=ax, **kde_kws) -def _polygons(adata, shape, ax, hue=None, hide_outside=False, **kwargs): +def _polygons(sdata, shape, ax, hue=None, hide_outside=False, 
sync_shapes=True, **kwargs): """Plot shapes with GeoSeries plot function.""" - - shapes = gpd.GeoDataFrame(geometry=get_shape(adata, shape)) - - edge_color = sns.axes_style()["axes.edgecolor"] + shapes = gpd.GeoDataFrame(geometry=get_shape(sdata, shape, sync=sync_shapes)) + edge_color = "none" face_color = "none" # If hue is specified, use it to color faces if hue: - shapes[hue] = adata.obs.reset_index()[hue].values + df = shapes.reset_index().merge(sdata.shapes[shape], how='left', left_on="geometry", right_on="geometry").set_index('index') + if hue == "cell": + shapes[hue] = df.index + else: + shapes[hue] = df.reset_index()[hue].values edge_color = sns.axes_style()["axes.edgecolor"] face_color = "none" # let GeoDataFrame plot function handle facecolor - + style_kwds = dict( linewidth=0.5, edgecolor=edge_color, facecolor=face_color, zorder=2 ) style_kwds.update(kwargs) shapes.plot(ax=ax, column=hue, **style_kwds) - - if hide_outside: + + if hide_outside and ax is not None: # get axes limits xmin, xmax = ax.get_xlim() ymin, ymax = ax.get_ylim() @@ -105,15 +107,15 @@ def _polygons(adata, shape, ax, hue=None, hide_outside=False, **kwargs): ) -def _raster(adata, res, color, alpha, points_key="cell_raster", cbar=False, ax=None, **kwargs): +def _raster(sdata, res, color, points_key="cell_raster", cbar=False, ax=None, **kwargs): """Plot gradient.""" if ax is None: ax = plt.gca() - points = get_points(adata, key=points_key) + points = get_points(sdata, points_key=points_key, astype="pandas") step = 1 / res - color_values = get_points_metadata(adata, metadata_key=color, points_key=points_key) + color_values = np.array(get_points_metadata(sdata, metadata_key=color, points_key=points_key)) # Infer value format and convert values to rgb # Handle color names and (r, g, b) tuples with matplotlib v1 = color_values[0] @@ -159,4 +161,4 @@ def _raster(adata, res, color, alpha, points_key="cell_raster", cbar=False, ax=N if cbar: cax = inset_axes(ax, width="20%", height="4%", loc="upper right", borderpad=1.5) cbar = plt.colorbar(im, orientation="horizontal", cax=cax) - # cbar.ax.tick_params(axis="x", direction="in", pad=-12) + # cbar.ax.tick_params(axis="x", direction="in", pad=-12) \ No newline at end of file diff --git a/bento/plotting/_lp.py b/bento/plotting/_lp.py old mode 100755 new mode 100644 index ac5689f..d0bbeea --- a/bento/plotting/_lp.py +++ b/bento/plotting/_lp.py @@ -6,23 +6,23 @@ import matplotlib.pyplot as plt import numpy as np import seaborn as sns -from anndata import AnnData +from spatialdata._core.spatialdata import SpatialData from upsetplot import UpSet, from_indicators from .._constants import PATTERN_COLORS, PATTERN_NAMES from ..tools import lp_stats +from ..geometry import get_points from ._utils import savefig from ._multidimensional import _radviz - @savefig -def lp_dist(data, percentage=False, scale=1, fname=None): +def lp_dist(sdata, percentage=False, scale=1, fname=None): """Plot pattern combination frequencies as an UpSet plot. 
Parameters ---------- - data : AnnData - Spatial formatted AnnData + sdata : SpatialData + Spatial formatted SpatialData percentage : bool, optional If True, label each bar as a percentage else label as a count, by default False scale : int, optional @@ -30,7 +30,7 @@ def lp_dist(data, percentage=False, scale=1, fname=None): fname : str, optional Save the figure to specified filename, by default None """ - sample_labels = data.uns["lp"] + sample_labels = sdata.table.uns["lp"] sample_labels = sample_labels == 1 # Sort by degree, then pattern name @@ -58,22 +58,21 @@ def lp_dist(data, percentage=False, scale=1, fname=None): upset.plot() plt.suptitle(f"Localization Patterns\n{sample_labels.shape[0]} samples") - @savefig -def lp_gene_dist(data, fname=None): +def lp_gene_dist(sdata, fname=None): """Plot the cell fraction distribution of each pattern as a density plot. Parameters ---------- - data : AnnData - Spatial formatted AnnData + sdata : SpatialData + Spatial formatted SpatialData fname : str, optional Save the figure to specified filename, by default None """ - lp_stats(data) + lp_stats(sdata) col_names = [f"{p}_fraction" for p in PATTERN_NAMES] - gene_frac = data.var[col_names] + gene_frac = sdata.table.var[col_names] gene_frac.columns = PATTERN_NAMES # Plot frequency distributions sns.displot( @@ -86,11 +85,11 @@ def lp_gene_dist(data, fname=None): plt.xlim(0, 1) sns.despine() - @savefig def lp_genes( - data: AnnData, + sdata: SpatialData, groupby: str = "gene", + points_key = "transcripts", annotate: Union[int, List[str], None] = None, sizes: Tuple[int] = (2, 100), size_norm: Tuple[int] = (0, 100), @@ -105,8 +104,8 @@ def lp_genes( Parameters ---------- - data : AnnData - Spatial formatted AnnData + sdata : SpatialData + Spatial formatted SpatialData groupby : str Grouping variable, default "gene" annotate : int, list of str, optional @@ -122,20 +121,21 @@ def lp_genes( **kwargs Options to pass to matplotlib plotting method. """ - lp_stats(data, groupby) + lp_stats(sdata) palette = dict(zip(PATTERN_NAMES, PATTERN_COLORS)) - n_cells = data.n_obs - gene_frac = data.uns["lp_stats"][PATTERN_NAMES] / n_cells - - gene_logcount = data.X.mean(axis=0, where=data.X > 0) + n_cells = sdata.table.n_obs + gene_frac = sdata.table.uns["lp_stats"][PATTERN_NAMES] / n_cells + genes = gene_frac.index + gene_expression_array = sdata.table[:,genes].X.toarray() + gene_logcount = gene_expression_array.mean(axis=0, where=gene_expression_array > 0) gene_logcount = np.log2(gene_logcount + 1) gene_frac["logcounts"] = gene_logcount - + cell_fraction = ( 100 - * data.uns["points"].groupby("gene", observed=True)["cell"].nunique() + * get_points(sdata, points_key, astype="pandas").groupby("gene", observed=True)["cell"].nunique() / n_cells ) gene_frac["cell_fraction"] = cell_fraction @@ -144,22 +144,21 @@ def lp_genes( scatter_kws.update(kwargs) _radviz(gene_frac, annotate=annotate, ax=ax, **scatter_kws) - @savefig -def lp_diff(data: AnnData, phenotype: str, fname: str = None): +def lp_diff_discrete(sdata: SpatialData, phenotype: str, fname: str = None): """Visualize gene pattern frequencies between groups of cells by plotting log2 fold change and -log10p, similar to volcano plot. Run after :func:`bento.tl.lp_diff()` Parameters ---------- - data : AnnData - Spatial formatted AnnData + sdata : SpatialData + Spatial formatted SpatialData phenotype : str Variable used to group cells when calling :func:`bento.tl.lp_diff()`. 
fname : str, optional Save the figure to specified filename, by default None """ - diff_stats = data.uns[f"diff_{phenotype}"] + diff_stats = sdata.table.uns[f"diff_{phenotype}"] palette = dict(zip(PATTERN_NAMES, PATTERN_COLORS)) g = sns.relplot( @@ -187,4 +186,4 @@ def lp_diff(data: AnnData, phenotype: str, fname: str = None): ) # line where FDR = 0.05 sns.despine() - return g + return g \ No newline at end of file diff --git a/bento/plotting/_plotting.py b/bento/plotting/_plotting.py old mode 100755 new mode 100644 index 4cb9d37..fd74f97 --- a/bento/plotting/_plotting.py +++ b/bento/plotting/_plotting.py @@ -8,16 +8,14 @@ from ..geometry import get_points from ._layers import _raster, _scatter, _hist, _kde, _polygons -from ._utils import savefig, vec2color -from .._utils import sync +from ._utils import savefig from ._colors import red2blue, red2blue_dark - -def _prepare_points_df(adata, semantic_vars=None): +def _prepare_points_df(sdata, semantic_vars=None, hue=None, hue_order=None): """ Prepare points DataFrame for plotting. This function will concatenate the appropriate semantic variables as columns to points data. """ - points = get_points(adata, key="points") + points = get_points(sdata, astype="pandas") cols = list(set(["x", "y", "cell"])) if semantic_vars is None or len(semantic_vars) == 0: @@ -26,22 +24,20 @@ def _prepare_points_df(adata, semantic_vars=None): vars = [v for v in semantic_vars if v is not None] cols.extend(vars) + if hue_order is not None: + points = points[points[hue].isin(hue_order)] + # Add semantic variables to points; priority: points, obs, points metadata for var in vars: if var in points.columns: continue - elif var in adata.obs.columns: - points[var] = adata.obs.reindex(points["cell"].values)[var].values - elif var in adata.uns["point_sets"]["points"]: - if len(adata.uns[var].shape) > 1: - raise ValueError(f"Variable {var} is not 1-dimensional") - points[var] = adata.uns[var] + elif var in sdata.shapes: + points[var] = sdata.shapes[var].reindex(points["cell"].values)[var].values else: raise ValueError(f"Variable {var} not found in points or obs") - + return points[cols] - def _setup_ax( ax=None, dx=0.1, @@ -94,12 +90,11 @@ def _setup_ax( return ax - @savefig def points( - data, - batch=None, + sdata, hue=None, + hue_order=None, size=None, style=None, shapes=None, @@ -111,17 +106,12 @@ def points( axis_visible=False, frame_visible=True, ax=None, + sync_shapes=True, shapes_kws=dict(), fname=None, **kwargs, ): - # Default use first obs batch - if batch is None: - batch = data.obs["batch"].iloc[0] - adata = data[data.obs["batch"] == batch] - sync(adata) - title = f"batch {batch}" if not title else title - + ax = _setup_ax( ax=ax, dx=dx, @@ -131,18 +121,16 @@ def points( frame_visible=frame_visible, title=title, ) - - points = _prepare_points_df(adata, semantic_vars=[hue, size, style]) + points = _prepare_points_df(sdata, semantic_vars=[hue, size, style], hue=hue, hue_order=hue_order) _scatter(points, hue=hue, size=size, style=style, ax=ax, **kwargs) - _shapes(adata, shapes=shapes, hide_outside=hide_outside, ax=ax, **shapes_kws) - + _shapes(sdata, shapes=shapes, hide_outside=hide_outside, ax=ax, sync_shapes=sync_shapes, **shapes_kws) @savefig def density( - data, - batch=None, + sdata, kind="hist", hue=None, + hue_order=None, shapes=None, hide_outside=True, axis_visible=False, @@ -152,16 +140,11 @@ def density( units="um", square=False, ax=None, + sync_shapes=True, shape_kws=dict(), fname=None, **kwargs, ): - # Default use first obs batch - if batch is None: - 
batch = data.obs["batch"].iloc[0] - adata = data[data.obs["batch"] == batch] - sync(adata) - title = f"batch {batch}" if title is None else title ax = _setup_ax( ax=ax, @@ -173,115 +156,17 @@ def density( title=title, ) - points = _prepare_points_df(adata, semantic_vars=[hue]) + points = _prepare_points_df(sdata, semantic_vars=[hue], hue=hue, hue_order=hue_order) if kind == "hist": _hist(points, hue=hue, ax=ax, **kwargs) elif kind == "kde": _kde(points, hue=hue, ax=ax, **kwargs) - _shapes(adata, shapes=shapes, hide_outside=hide_outside, ax=ax, **shape_kws) - - -@savefig -def flux( - data, - dims=[0, 1, 2], - alpha=True, - batch=None, - res=1, - shapes=None, - hide_outside=True, - axis_visible=False, - frame_visible=True, - title=None, - dx=0.1, - units="um", - square=False, - ax=None, - shape_kws=dict(), - fname=None, - **kwargs, -): - # Default use first obs batch - if batch is None: - batch = data.obs["batch"].iloc[0] - adata = data[data.obs["batch"] == batch] - sync(adata) - title = f"batch {batch}" if not title else title - - ax = _setup_ax( - ax=ax, - dx=dx, - units=units, - square=square, - axis_visible=axis_visible, - frame_visible=frame_visible, - title=title, - ) - - adata.uns["flux_color"] = vec2color( - adata.uns["flux_embed"][:, dims], alpha_vec=adata.uns["flux_counts"] - ) - - _raster(adata, res=res, color="flux_color", alpha=alpha, ax=ax, **kwargs) - _shapes(adata, shapes=shapes, hide_outside=hide_outside, ax=ax, **shape_kws) - - -@savefig -def fe( - data, - gs, - batch=None, - res=1, - alpha=True, - shapes=None, - cmap=None, - cbar=True, - hide_outside=True, - axis_visible=False, - frame_visible=True, - title=None, - dx=0.1, - units="um", - square=False, - ax=None, - shape_kws=dict(), - fname=None, - **kwargs, -): - # Default use first obs batch - if batch is None: - batch = data.obs["batch"].iloc[0] - adata = data[data.obs["batch"] == batch] - sync(adata) - title = f"batch {batch}" if not title else title - - ax = _setup_ax( - ax=ax, - dx=dx, - units=units, - square=square, - axis_visible=axis_visible, - frame_visible=frame_visible, - title=title, - ) - - if cmap is None: - if sns.axes_style()["axes.facecolor"] == "white": - cmap = red2blue - elif sns.axes_style()["axes.facecolor"] == "black": - cmap = red2blue_dark - - _raster( - adata, res=res, color=gs, alpha=alpha, cmap=cmap, cbar=cbar, ax=ax, **kwargs - ) - _shapes(adata, shapes=shapes, hide_outside=hide_outside, ax=ax, **shape_kws) - + _shapes(sdata, shapes=shapes, hide_outside=hide_outside, ax=ax, sync_shapes=sync_shapes, **shape_kws) @savefig def shapes( - data, - batch=None, + sdata, shapes=None, color=None, color_style="outline", @@ -293,15 +178,10 @@ def shapes( title=None, square=False, ax=None, + sync_shapes=True, fname=None, **kwargs, ): - # Default use first obs batch - if batch is None: - batch = data.obs["batch"].iloc[0] - adata = data[data.obs["batch"] == batch] - sync(adata) - title = f"batch {batch}" if not title else title ax = _setup_ax( ax=ax, @@ -317,31 +197,33 @@ def shapes( shapes = [shapes] _shapes( - adata, + sdata, shapes=shapes, color=color, color_style=color_style, hide_outside=hide_outside, ax=ax, + sync_shapes=sync_shapes, **kwargs, ) def _shapes( - data, + sdata, shapes=None, color=None, color_style="outline", hide_outside=True, ax=None, + sync_shapes=True, **kwargs, ): """Plot layer(s) of shapes. Parameters ---------- - data : AnnData - Spatial formatted AnnData + data : SpatialData + Spatial formatted SpatialData shapes : list, optional List of shapes to plot, by default None. 
If None, will plot cell and nucleus shapes by default. color : str, optional @@ -358,14 +240,14 @@ def _shapes( shape_names = [] for s in shapes: - if str(s).endswith("_shape"): + if str(s).endswith("_boundaries"): shape_names.append(s) else: - shape_names.append(f"{s}_shape") + shape_names.append(f"{s}_boundaries") # Save list of names to remove if not in data.obs - shape_names = [name for name in shape_names if name in data.obs.columns] - missing_names = [name for name in shape_names if name not in data.obs.columns] + shape_names = [name for name in shape_names if name in sdata.shapes.keys()] + missing_names = [name for name in shape_names if name not in sdata.shapes.keys()] if len(missing_names) > 0: warnings.warn("Shapes not found in data: " + ", ".join(missing_names)) @@ -381,21 +263,94 @@ def _shapes( for name in shape_names: hide = False - if name == "cell_shape" and hide_outside: + if name == "cell_boundaries" and hide_outside: hide = True _polygons( - data, + sdata, name, hide_outside=hide, ax=ax, + sync_shapes=sync_shapes, **geo_kws, ) +@savefig +def flux( + sdata, + res=0.05, + shapes=None, + hide_outside=True, + axis_visible=False, + frame_visible=True, + title=None, + dx=0.1, + units="um", + square=False, + ax=None, + sync_shapes=True, + shape_kws=dict(), + fname=None, + **kwargs, +): + + ax = _setup_ax( + ax=ax, + dx=dx, + units=units, + square=square, + axis_visible=axis_visible, + frame_visible=frame_visible, + title=title, + ) + + _raster(sdata, res=res, color="flux_color", ax=ax, **kwargs) + _shapes(sdata, shapes=shapes, hide_outside=hide_outside, ax=ax, sync_shapes=sync_shapes, **shape_kws) + +@savefig +def fe( + sdata, + gs, + res=0.05, + shapes=None, + cmap=None, + cbar=True, + hide_outside=True, + axis_visible=False, + frame_visible=True, + title=None, + dx=0.1, + units="um", + square=False, + ax=None, + sync_shapes=True, + shape_kws=dict(), + fname=None, + **kwargs, +): + + ax = _setup_ax( + ax=ax, + dx=dx, + units=units, + square=square, + axis_visible=axis_visible, + frame_visible=frame_visible, + title=title, + ) + + if cmap is None: + if sns.axes_style()["axes.facecolor"] == "white": + cmap = red2blue + elif sns.axes_style()["axes.facecolor"] == "black": + cmap = red2blue_dark + + _raster(sdata, res=res, color=gs, cmap=cmap, cbar=cbar, ax=ax, **kwargs) + _shapes(sdata, shapes=shapes, hide_outside=hide_outside, ax=ax, sync_shapes=sync_shapes, **shape_kws) +@savefig def fluxmap( - data, - batch=None, + sdata, palette="tab10", hide_outside=True, axis_visible=False, @@ -403,7 +358,7 @@ def fluxmap( title=None, dx=0.1, ax=None, - legend=True, + sync_shapes=False, fname=None, **kwargs, ): @@ -413,8 +368,6 @@ def fluxmap( ---------- data : AnnData Spatial formatted AnnData - batch : str, optional - Batch to plot, by default None. If None, will use first batch. palette : str or dict, optional Color palette, by default "tab10". If dict, will use dict to map shape names to colors. 
ax : matplotlib.axes.Axes, optional @@ -427,7 +380,7 @@ def fluxmap( if isinstance(palette, dict): colormap = palette else: - fluxmap_shapes = [s for s in data.obs.columns if s.startswith("fluxmap")] + fluxmap_shapes = [s for s in sdata.shapes.keys() if s.startswith("fluxmap")] fluxmap_shapes.sort() colors = sns.color_palette(palette, n_colors=len(fluxmap_shapes)) colormap = dict(zip(fluxmap_shapes, colors)) @@ -437,8 +390,7 @@ def fluxmap( for s, c in colormap.items(): shapes( - data, - batch=batch, + sdata, shapes=s, color=c, hide_outside=hide_outside, @@ -447,32 +399,9 @@ def fluxmap( title=title, dx=dx, ax=ax, + sync_shapes=sync_shapes, **shape_kws, ) - - # Add to legend - if legend: - plt.plot( - [], - [], - color=c, - label="_".join(s.split("_")[:-1]), - marker="s", - linestyle="none", - ) - - if legend: - plt.legend() - + # Plot base cell and nucleus shapes - shapes( - data, - batch=batch, - ax=ax, - hide_outside=hide_outside, - axis_visible=axis_visible, - frame_visible=frame_visible, - title=title, - dx=dx, - fname=fname, - ) + shapes(sdata, ax=ax, fname=fname) \ No newline at end of file diff --git a/bento/plotting/_signatures.py b/bento/plotting/_signatures.py index c96296a..35402f1 100644 --- a/bento/plotting/_signatures.py +++ b/bento/plotting/_signatures.py @@ -4,89 +4,11 @@ import seaborn as sns from scipy.stats import zscore -from .._constants import PATTERN_COLORS, PATTERN_PROBS from ._colors import red2blue, red_light from ._utils import savefig - -@savefig -def signatures(adata, rank, fname=None): - """Plot signatures for specified rank across each dimension. - - bento.tl.signatures() must be run first. - - Parameters - ---------- - adata : anndata.AnnData - Spatial formatted AnnData - rank : int - Rank of signatures to plot - fname : str, optional - Path to save figure, by default None - """ - sig_key = f"r{rank}_signatures" - layer_g = sns.clustermap( - np.log2(adata.uns[sig_key] + 1).T, - col_cluster=False, - row_cluster=False, - col_colors=pd.Series(PATTERN_COLORS, index=PATTERN_PROBS), - standard_scale=0, - cmap=red_light, - linewidth=1, - linecolor="black", - figsize=(adata.uns[sig_key].shape[0], adata.uns[sig_key].shape[1] + 1), - ) - sns.despine(ax=layer_g.ax_heatmap, top=False, right=False) - plt.suptitle("Layers") - - gs_shape = adata.varm[sig_key].shape - gene_g = sns.clustermap( - np.log2(adata.varm[sig_key] + 1).T, - row_cluster=False, - cmap=red_light, - standard_scale=0, - figsize=(gs_shape[0], gs_shape[1] + 1), - ) - sns.despine(ax=gene_g.ax_heatmap, top=False, right=False) - plt.suptitle("Genes") - - os_shape = adata.obsm[sig_key].shape - cell_g = sns.clustermap( - np.log2(adata.obsm[sig_key] + 1).T, - row_cluster=False, - col_cluster=True, - standard_scale=0, - xticklabels=False, - # col_colors=pheno_to_color(adata.obs["leiden"], palette="tab20")[1], - cmap=red_light, - figsize=(os_shape[0], os_shape[1] + 1), - ) - sns.despine(ax=cell_g.ax_heatmap, top=False, right=False) - plt.suptitle("Cells") - - -@savefig -def signatures_error(adata, fname=None): - """Plot error for each rank. - - bento.tl.signatures() must be run first. 
- - Parameters - ---------- - adata : anndata.AnnData - Spatial formatted AnnData - fname : str, optional - Path to save figure, by default None - """ - errors = adata.uns["signatures_error"] - sns.lineplot(data=errors, x="rank", y="rmse", ci=95, marker="o") - sns.despine() - - return errors - - def colocation( - adata, + sdata, rank, n_top=[None, None, 5], z_score=[False, True, True], @@ -101,8 +23,8 @@ def colocation( Parameters ---------- - adata : anndata.AnnData - Spatial formatted AnnData + sdata : spatialdata.SpatialData + Spatial formatted SpatialData rank : int Rank of signatures to plot n_top : int, optional @@ -120,9 +42,9 @@ def colocation( fname : str, optional Path to save figure, by default None """ - factors = adata.uns["factors"][rank].copy() - labels = adata.uns["tensor_labels"].copy() - names = adata.uns["tensor_names"].copy() + factors = sdata.table.uns["factors"][rank].copy() + labels = sdata.table.uns["tensor_labels"].copy() + names = sdata.table.uns["tensor_names"].copy() # Perform z-scaling upfront for i in range(len(factors)): @@ -326,4 +248,4 @@ def _plot_loading(df, name, n_top, cut, show_labels, cluster, ax, **kwargs): ax.set_yticklabels(ax.get_yticklabels(), rotation=0) ax.set_title(f"{name}: [{df.shape[0]} x {df.shape[1]}]") - sns.despine(ax=ax, right=False, top=False) + sns.despine(ax=ax, right=False, top=False) \ No newline at end of file diff --git a/bento/query/__init__.py b/bento/query/__init__.py new file mode 100644 index 0000000..91a9aa2 --- /dev/null +++ b/bento/query/__init__.py @@ -0,0 +1,3 @@ +from ._query import ( + bounding_box_query, +) \ No newline at end of file diff --git a/bento/query/_query.py b/bento/query/_query.py new file mode 100644 index 0000000..76a8ccf --- /dev/null +++ b/bento/query/_query.py @@ -0,0 +1,754 @@ +from __future__ import annotations + +from abc import abstractmethod +from dataclasses import dataclass +from functools import singledispatch +from typing import Any, Callable, Optional, Union +import warnings + +warnings.filterwarnings("ignore") + +import dask.array as da +import numpy as np +import pandas as pd +from dask.dataframe.core import DataFrame as DaskDataFrame +from datatree import DataTree +from geopandas import GeoDataFrame +from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage +from shapely.geometry import Polygon +from spatial_image import SpatialImage +from tqdm import tqdm +from xarray import DataArray + +from spatialdata._core.spatialdata import SpatialData +from spatialdata._logging import logger +from spatialdata._types import ArrayLike +from spatialdata._utils import Number, _parse_list_into_array +from spatialdata.models import ( + SpatialElement, + get_axes_names, +) +from spatialdata.models._utils import ValidAxis_t, get_spatial_axes +from spatialdata.transformations._utils import compute_coordinates +from spatialdata.transformations.transformations import ( + Affine, + BaseTransformation, + Sequence, + Translation, + _get_affine_for_element, +) + + +def get_bounding_box_corners( + axes: tuple[str, ...], + min_coordinate: Union[list[Number], ArrayLike], + max_coordinate: Union[list[Number], ArrayLike], +) -> DataArray: + """Get the coordinates of the corners of a bounding box from the min/max values. + Parameters + ---------- + axes + The axes that min_coordinate and max_coordinate refer to. + min_coordinate + The upper left hand corner of the bounding box (i.e., minimum coordinates + along all dimensions). 
+ max_coordinate + The lower right hand corner of the bounding box (i.e., the maximum coordinates + along all dimensions + Returns + ------- + (N, D) array of coordinates of the corners. N = 4 for 2D and 8 for 3D. + """ + min_coordinate = _parse_list_into_array(min_coordinate) + max_coordinate = _parse_list_into_array(max_coordinate) + + if len(min_coordinate) not in (2, 3): + raise ValueError("bounding box must be 2D or 3D") + + if len(min_coordinate) == 2: + # 2D bounding box + assert len(axes) == 2 + return DataArray( + [ + [min_coordinate[0], min_coordinate[1]], + [min_coordinate[0], max_coordinate[1]], + [max_coordinate[0], max_coordinate[1]], + [max_coordinate[0], min_coordinate[1]], + ], + coords={"corner": range(4), "axis": list(axes)}, + ) + + # 3D bounding cube + assert len(axes) == 3 + return DataArray( + [ + [min_coordinate[0], min_coordinate[1], min_coordinate[2]], + [min_coordinate[0], min_coordinate[1], max_coordinate[2]], + [min_coordinate[0], max_coordinate[1], max_coordinate[2]], + [min_coordinate[0], max_coordinate[1], min_coordinate[2]], + [max_coordinate[0], min_coordinate[1], min_coordinate[2]], + [max_coordinate[0], min_coordinate[1], max_coordinate[2]], + [max_coordinate[0], max_coordinate[1], max_coordinate[2]], + [max_coordinate[0], max_coordinate[1], min_coordinate[2]], + ], + coords={"corner": range(8), "axis": list(axes)}, + ) + + +def _get_bounding_box_corners_in_intrinsic_coordinates( + element: SpatialElement, + axes: tuple[str, ...], + min_coordinate: Union[list[Number], ArrayLike], + max_coordinate: Union[list[Number], ArrayLike], + target_coordinate_system: str, +) -> tuple[ArrayLike, tuple[str, ...]]: + """Get all corners of a bounding box in the intrinsic coordinates of an element. + Parameters + ---------- + element + The SpatialElement to get the intrinsic coordinate system from. + axes + The axes that min_coordinate and max_coordinate refer to. + min_coordinate + The upper left hand corner of the bounding box (i.e., minimum coordinates + along all dimensions). + max_coordinate + The lower right hand corner of the bounding box (i.e., the maximum coordinates + along all dimensions + target_coordinate_system + The coordinate system the bounding box is defined in. + Returns ------- All the corners of the bounding box in the intrinsic coordinate system of the element. The shape + is (2, 4) when axes has 2 spatial dimensions, and (2, 8) when axes has 3 spatial dimensions. + The axes of the intrinsic coordinate system. 
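# Minimal sketch of what the 2D branch of get_bounding_box_corners() defined above
# returns (coordinate values are arbitrary examples):
corners = get_bounding_box_corners(
    axes=("x", "y"), min_coordinate=[0, 0], max_coordinate=[10, 20]
)
# -> xarray.DataArray with dims ("corner", "axis") and data
#    [[0, 0], [0, 20], [10, 20], [10, 0]]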
+ """ + from spatialdata.transformations import get_transformation + + min_coordinate = _parse_list_into_array(min_coordinate) + max_coordinate = _parse_list_into_array(max_coordinate) + # get the transformation from the element's intrinsic coordinate system + # to the query coordinate space + transform_to_query_space = get_transformation(element, to_coordinate_system=target_coordinate_system) + + # get the coordinates of the bounding box corners + bounding_box_corners = get_bounding_box_corners( + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + axes=axes, + ).data + + # transform the coordinates to the intrinsic coordinate system + intrinsic_axes = get_axes_names(element) + transform_to_intrinsic = transform_to_query_space.inverse().to_affine_matrix( # type: ignore[union-attr] + input_axes=axes, output_axes=intrinsic_axes + ) + rotation_matrix = transform_to_intrinsic[0:-1, 0:-1] + translation = transform_to_intrinsic[0:-1, -1] + + intrinsic_bounding_box_corners = bounding_box_corners @ rotation_matrix.T + translation + + return intrinsic_bounding_box_corners, intrinsic_axes + + +@dataclass(frozen=True) +class BaseSpatialRequest: + """Base class for spatial queries.""" + + target_coordinate_system: str + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + pass + + +@dataclass(frozen=True) +class BoundingBoxRequest(BaseSpatialRequest): + """Query with an axis-aligned bounding box. + Attributes + ---------- + axes + The axes the coordinates are expressed in. + min_coordinate + The coordinate of the lower left hand corner (i.e., minimum values) + of the bounding box. + max_coordinate + The coordinate of the upper right hand corner (i.e., maximum values) + of the bounding box + """ + + min_coordinate: ArrayLike + max_coordinate: ArrayLike + axes: tuple[ValidAxis_t, ...] + + def __post_init__(self) -> None: + # validate the axes + spatial_axes = get_spatial_axes(self.axes) + non_spatial_axes = set(self.axes) - set(spatial_axes) + if len(non_spatial_axes) > 0: + raise ValueError(f"Non-spatial axes specified: {non_spatial_axes}") + + # validate the axes + if len(self.axes) != len(self.min_coordinate) or len(self.axes) != len(self.max_coordinate): + raise ValueError("The number of axes must match the number of coordinates.") + + # validate the coordinates + if np.any(self.min_coordinate > self.max_coordinate): + raise ValueError("The minimum coordinate must be less than the maximum coordinate.") + + def to_dict(self) -> dict[str, Any]: + return { + "target_coordinate_system": self.target_coordinate_system, + "axes": self.axes, + "min_coordinate": self.min_coordinate, + "max_coordinate": self.max_coordinate, + } + + +def _bounding_box_mask_points( + points: DaskDataFrame, + axes: tuple[str, ...], + min_coordinate: Union[list[Number], ArrayLike], + max_coordinate: Union[list[Number], ArrayLike], +) -> da.Array: + """Compute a mask that is true for the points inside of an axis-aligned bounding box.. + Parameters + ---------- + points + The points element to perform the query on. + axes + The axes that min_coordinate and max_coordinate refer to. + min_coordinate + The upper left hand corner of the bounding box (i.e., minimum coordinates + along all dimensions). + max_coordinate + The lower right hand corner of the bounding box (i.e., the maximum coordinates + along all dimensions + Returns + ------- + The mask for the points inside of the bounding box. 
+ """ + min_coordinate = _parse_list_into_array(min_coordinate) + max_coordinate = _parse_list_into_array(max_coordinate) + in_bounding_box_masks = [] + for axis_index, axis_name in enumerate(axes): + min_value = min_coordinate[axis_index] + in_bounding_box_masks.append(points[axis_name].gt(min_value).to_dask_array(lengths=True)) + for axis_index, axis_name in enumerate(axes): + max_value = max_coordinate[axis_index] + in_bounding_box_masks.append(points[axis_name].lt(max_value).to_dask_array(lengths=True)) + in_bounding_box_masks = da.stack(in_bounding_box_masks, axis=-1) + return da.all(in_bounding_box_masks, axis=1) + + +def _dict_query_dispatcher( + elements: dict[str, SpatialElement], query_function: Callable[[SpatialElement], SpatialElement], **kwargs: Any +) -> dict[str, SpatialElement]: + from spatialdata.transformations import get_transformation + + queried_elements = {} + for key, element in elements.items(): + target_coordinate_system = kwargs["target_coordinate_system"] + d = get_transformation(element, get_all=True) + assert isinstance(d, dict) + if target_coordinate_system in d: + result = query_function(element, **kwargs) + if result is not None: + # query returns None if it is empty + queried_elements[key] = result + return queried_elements + + +@singledispatch +def bounding_box_query( + element: Union[SpatialElement, SpatialData], + axes: tuple[str, ...], + min_coordinate: Union[list[Number], ArrayLike], + max_coordinate: Union[list[Number], ArrayLike], + target_coordinate_system: str, + **kwargs: Any, +) -> Optional[Union[SpatialElement, SpatialData]]: + """ + Perform a bounding box query on the SpatialData object. + Parameters + ---------- + axes + The axes `min_coordinate` and `max_coordinate` refer to. + min_coordinate + The minimum coordinates of the bounding box. + max_coordinate + The maximum coordinates of the bounding box. + target_coordinate_system + The coordinate system the bounding box is defined in. + filter_table + If `True`, the table is filtered to only contain rows that are annotating regions + contained within the bounding box. + Returns + ------- + The SpatialData object containing the requested data. + Elements with no valid data are omitted. 
+ """ + raise RuntimeError("Unsupported type for bounding_box_query: " + str(type(element)) + ".") + + +@bounding_box_query.register(SpatialData) +def _( + sdata: SpatialData, + axes: tuple[str, ...], + min_coordinate: Union[list[Number], ArrayLike], + max_coordinate: Union[list[Number], ArrayLike], + target_coordinate_system: str, + filter_table: bool = True, +) -> SpatialData: + from spatialdata import SpatialData + from spatialdata._core.query.relational_query import _filter_table_by_elements + + min_coordinate = _parse_list_into_array(min_coordinate) + max_coordinate = _parse_list_into_array(max_coordinate) + new_elements = {} + for element_type in ["points", "images", "labels", "shapes"]: + elements = getattr(sdata, element_type) + queried_elements = _dict_query_dispatcher( + elements, + bounding_box_query, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + target_coordinate_system=target_coordinate_system, + ) + new_elements[element_type] = queried_elements + + table = _filter_table_by_elements(sdata.table, new_elements) if filter_table else sdata.table + return SpatialData(**new_elements, table=table) + + +@bounding_box_query.register(SpatialImage) +@bounding_box_query.register(MultiscaleSpatialImage) +def _( + image: Union[SpatialImage, MultiscaleSpatialImage], + axes: tuple[str, ...], + min_coordinate: Union[list[Number], ArrayLike], + max_coordinate: Union[list[Number], ArrayLike], + target_coordinate_system: str, +) -> Optional[Union[SpatialImage, MultiscaleSpatialImage]]: + """Implement bounding box query for SpatialImage. + Notes + ----- + _____ + See https://github.com/scverse/spatialdata/pull/151 for a detailed overview of the logic of this code, + and for the cases the comments refer to. + """ + from spatialdata.transformations import get_transformation, set_transformation + + min_coordinate = _parse_list_into_array(min_coordinate) + max_coordinate = _parse_list_into_array(max_coordinate) + + # for triggering validation + _ = BoundingBoxRequest( + target_coordinate_system=target_coordinate_system, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + ) + + # get the transformation from the element's intrinsic coordinate system to the query coordinate space + transform_to_query_space = get_transformation(image, to_coordinate_system=target_coordinate_system) + assert isinstance(transform_to_query_space, BaseTransformation) + m = _get_affine_for_element(image, transform_to_query_space) + input_axes_without_c = tuple([ax for ax in m.input_axes if ax != "c"]) + output_axes_without_c = tuple([ax for ax in m.output_axes if ax != "c"]) + m_without_c = m.to_affine_matrix(input_axes=input_axes_without_c, output_axes=output_axes_without_c) + m_without_c_linear = m_without_c[:-1, :-1] + + transform_dimension = np.linalg.matrix_rank(m_without_c_linear) + transform_coordinate_length = len(output_axes_without_c) + data_dim = len(input_axes_without_c) + + assert data_dim in [2, 3] + assert transform_dimension in [2, 3] + assert transform_coordinate_length in [2, 3] + assert not (data_dim == 2 and transform_dimension == 3) + assert not (transform_dimension == 3 and transform_coordinate_length == 2) + # see explanation in https://github.com/scverse/spatialdata/pull/151 + if data_dim == 2 and transform_dimension == 2 and transform_coordinate_length == 2: + case = 1 + elif data_dim == 2 and transform_dimension == 2 and transform_coordinate_length == 3: + case = 2 + elif data_dim == 3 and transform_dimension == 2 and 
transform_coordinate_length == 2: + case = 3 + elif data_dim == 3 and transform_dimension == 2 and transform_coordinate_length == 3: + case = 4 + elif data_dim == 3 and transform_dimension == 3 and transform_coordinate_length == 3: + case = 5 + else: + raise RuntimeError("This should not happen") + + if case in [3, 4]: + error_message = ( + f"This case is not supported (data with dimension" + f"{data_dim} but transformation with rank {transform_dimension}." + f"Please open a GitHub issue if you want to discuss a case." + ) + raise ValueError(error_message) + + if set(axes) != set(output_axes_without_c): + if set(axes).issubset(output_axes_without_c): + logger.warning( + f"The element has axes {output_axes_without_c}, but the query has axes {axes}. Excluding the element " + f"from the query result. In the future we can add support for this case. If you are interested, " + f"please open a GitHub issue." + ) + return None + error_messeage = ( + f"Invalid case. The bounding box axes are {axes}," + f"the spatial axes in {target_coordinate_system} are" + f"{output_axes_without_c}" + ) + raise ValueError(error_messeage) + + spatial_transform = Affine(m_without_c, input_axes=input_axes_without_c, output_axes=output_axes_without_c) + spatial_transform_bb_axes = Affine( + spatial_transform.to_affine_matrix(input_axes=input_axes_without_c, output_axes=axes), + input_axes=input_axes_without_c, + output_axes=axes, + ) + assert case in [1, 2, 5] + if case in [1, 5]: + bounding_box_corners = get_bounding_box_corners( + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + axes=axes, + ) + else: + assert case == 2 + # TODO: we need to intersect the plane in the extrinsic coordiante system with the 3D bounding box. The + # vertices of this polygons needs to be transformed to the intrinsic coordinate system + raise NotImplementedError( + "Case 2 (the transformation is embedding 2D data in the 3D space, is not " + "implemented yet. Please open a Github issue about this and we will prioritize the " + "development." 
+ ) + inverse = spatial_transform_bb_axes.inverse() + assert isinstance(inverse, Affine) + rotation_matrix = inverse.matrix[0:-1, 0:-1] + translation = inverse.matrix[0:-1, -1] + + intrinsic_bounding_box_corners = DataArray( + bounding_box_corners.data @ rotation_matrix.T + translation, + coords={"corner": range(len(bounding_box_corners)), "axis": list(inverse.output_axes)}, + ) + + # build the request + selection = {} + translation_vector = [] + for axis_name in axes: + # get the min value along the axis + min_value = intrinsic_bounding_box_corners.sel(axis=axis_name).min().item() + + # get max value, slices are open half interval + max_value = intrinsic_bounding_box_corners.sel(axis=axis_name).max().item() + + # add the + selection[axis_name] = slice(min_value, max_value) + + if min_value > 0: + translation_vector.append(np.ceil(min_value).item()) + else: + translation_vector.append(0) + + query_result = image.sel(selection) + if isinstance(image, SpatialImage): + if 0 in query_result.shape: + return None + assert isinstance(query_result, SpatialImage) + else: + assert isinstance(image, MultiscaleSpatialImage) + assert isinstance(query_result, DataTree) + # we need to convert query_result it to MultiscaleSpatialImage, dropping eventual collapses scales (or even + # the whole object if the first scale is collapsed) + d = {} + for k, data_tree in query_result.items(): + v = data_tree.values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + if 0 in xdata.shape: + if k == "scale0": + return None + else: + d[k] = xdata + query_result = MultiscaleSpatialImage.from_dict(d) + query_result = compute_coordinates(query_result) + + # the bounding box, mapped back to the intrinsic coordinate system is a set of points. The bounding box of these + # points is likely starting away from the origin (this is described by translation_vector), so we need to prepend + # this translation to every transformation in the new queries elements (unless the translation_vector is zero, + # in that case the translation is not needed) + if not np.allclose(np.array(translation_vector), 0): + translation_transform = Translation(translation=translation_vector, axes=axes) + + transformations = get_transformation(query_result, get_all=True) + assert isinstance(transformations, dict) + + new_transformations = {} + for coordinate_system, initial_transform in transformations.items(): + new_transformation: BaseTransformation = Sequence( + [translation_transform, initial_transform], + ) + new_transformations[coordinate_system] = new_transformation + set_transformation(query_result, new_transformations, set_all=True) + return query_result + + +@bounding_box_query.register(DaskDataFrame) +def _( + points: DaskDataFrame, + axes: tuple[str, ...], + min_coordinate: Union[list[Number], ArrayLike], + max_coordinate: Union[list[Number], ArrayLike], + target_coordinate_system: str, +) -> Optional[DaskDataFrame]: + from spatialdata import transform + from spatialdata.transformations import BaseTransformation, get_transformation + + min_coordinate = _parse_list_into_array(min_coordinate) + max_coordinate = _parse_list_into_array(max_coordinate) + # for triggering validation + _ = BoundingBoxRequest( + target_coordinate_system=target_coordinate_system, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + ) + + # get the four corners of the bounding box (2D case), or the 8 corners of the "3D bounding box" (3D case) + (intrinsic_bounding_box_corners, intrinsic_axes) = 
_get_bounding_box_corners_in_intrinsic_coordinates( + element=points, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + target_coordinate_system=target_coordinate_system, + ) + min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(axis=0) + max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(axis=0) + # get the points in the intrinsic coordinate bounding box + in_intrinsic_bounding_box = _bounding_box_mask_points( + points=points, + axes=axes, + min_coordinate=min_coordinate_intrinsic, + max_coordinate=max_coordinate_intrinsic, + ) + points_in_intrinsic_bounding_box = points.loc[in_intrinsic_bounding_box] + + if in_intrinsic_bounding_box.sum() == 0: + # if there aren't any points, just return + return None + + # we have to reset the index since we have subset + # https://stackoverflow.com/questions/61395351/how-to-reset-index-on-concatenated-dataframe-in-dask + points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.assign(idx=1) + points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.set_index( + points_in_intrinsic_bounding_box.idx.cumsum() - 1 + ) + points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.map_partitions( + lambda df: df.rename(index={"idx": None}) + ) + points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.drop(columns=["idx"]) + + # transform the element to the query coordinate system + transform_to_query_space = get_transformation(points, to_coordinate_system=target_coordinate_system) + assert isinstance(transform_to_query_space, BaseTransformation) + points_query_coordinate_system = transform( + points_in_intrinsic_bounding_box, transform_to_query_space, maintain_positioning=False + ) # type: ignore[union-attr] + + # get a mask for the points in the bounding box + bounding_box_mask = _bounding_box_mask_points( + points=points_query_coordinate_system, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + ) + if bounding_box_mask.sum() == 0: + return None + return points_in_intrinsic_bounding_box.loc[bounding_box_mask] + + +@bounding_box_query.register(GeoDataFrame) +def _( + polygons: GeoDataFrame, + axes: tuple[str, ...], + min_coordinate: Union[list[Number], ArrayLike], + max_coordinate: Union[list[Number], ArrayLike], + target_coordinate_system: str, +) -> Optional[GeoDataFrame]: + min_coordinate = _parse_list_into_array(min_coordinate) + max_coordinate = _parse_list_into_array(max_coordinate) + # for triggering validation + _ = BoundingBoxRequest( + target_coordinate_system=target_coordinate_system, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + ) + + # get the four corners of the bounding box + (intrinsic_bounding_box_corners, intrinsic_axes) = _get_bounding_box_corners_in_intrinsic_coordinates( + element=polygons, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + target_coordinate_system=target_coordinate_system, + ) + + bounding_box_non_axes_aligned = Polygon(intrinsic_bounding_box_corners) + queried = polygons[polygons.geometry.within(bounding_box_non_axes_aligned)] + if len(queried) == 0: + return None + return queried + + +def _polygon_query( + sdata: SpatialData, polygon: Polygon, target_coordinate_system: str, filter_table: bool, shapes: bool, points: bool +) -> SpatialData: + from spatialdata._core.query._utils import circles_to_polygons + from spatialdata._core.query.relational_query import _filter_table_by_elements + from spatialdata.models import ( + PointsModel, 
+ ShapesModel, + points_dask_dataframe_to_geopandas, + points_geopandas_to_dask_dataframe, + ) + from spatialdata.transformations import get_transformation, set_transformation + + new_shapes = {} + if shapes: + for shapes_name, s in sdata.shapes.items(): + buffered = circles_to_polygons(s) if ShapesModel.RADIUS_KEY in s.columns else s + + if "__old_index" in buffered.columns: + assert np.all(s["__old_index"] == buffered.index) + else: + buffered["__old_index"] = buffered.index + indices = buffered.geometry.apply(lambda x: x.intersects(polygon)) + if np.sum(indices) == 0: + raise ValueError("we expect at least one shape") + queried_shapes = s[indices] + queried_shapes.index = buffered[indices]["__old_index"] + queried_shapes.index.name = None + del buffered["__old_index"] + if "__old_index" in queried_shapes.columns: + del queried_shapes["__old_index"] + transformation = get_transformation(buffered, target_coordinate_system) + queried_shapes = ShapesModel.parse(queried_shapes) + set_transformation(queried_shapes, transformation, target_coordinate_system) + new_shapes[shapes_name] = queried_shapes + + new_points = {} + if points: + for points_name, p in sdata.points.items(): + points_gdf = points_dask_dataframe_to_geopandas(p, suppress_z_warning=True) + indices = points_gdf.geometry.intersects(polygon) + if np.sum(indices) == 0: + raise ValueError("we expect at least one point") + queried_points = points_gdf[indices] + ddf = points_geopandas_to_dask_dataframe(queried_points, suppress_z_warning=True) + transformation = get_transformation(p, target_coordinate_system) + if "z" in ddf.columns: + ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"}) + else: + ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y"}) + set_transformation(ddf, transformation, target_coordinate_system) + new_points[points_name] = ddf + + if filter_table: + table = _filter_table_by_elements(sdata.table, {"shapes": new_shapes, "points": new_points}) + else: + table = sdata.table + return SpatialData(shapes=new_shapes, points=new_points, table=table) + + +# this function is currently excluded from the API documentation. TODO: add it after the refactoring +def polygon_query( + sdata: SpatialData, + polygons: Union[Polygon, list[Polygon]], + target_coordinate_system: str, + filter_table: bool = True, + shapes: bool = True, + points: bool = True, +) -> SpatialData: + """ + Query a spatial data object by a polygon, filtering shapes and points. + Parameters + ---------- + sdata + The SpatialData object to query + polygon + The polygon (or list of polygons) to query by + target_coordinate_system + The coordinate system of the polygon + shapes + Whether to filter shapes + points + Whether to filter points + Returns + ------- + The queried SpatialData object with filtered shapes and points. + Notes + ----- + This function will be refactored to be more general. + The table is not filtered by this function, but is passed as is, this will also changed during the refactoring + making this function more general and ergonomic. 
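# Illustrative call (sketch): the ROI vertices and coordinate-system name are placeholders.
# polygon_query is defined in this module but is not re-exported in bento/query/__init__.py.
from shapely.geometry import Polygon
from bento.query._query import polygon_query

roi = Polygon([(0, 0), (0, 500), (500, 500), (500, 0)])
subset = polygon_query(
    sdata, roi, target_coordinate_system="global", filter_table=True
)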
+ """ + from spatialdata._core.query.relational_query import _filter_table_by_elements + + if isinstance(polygons, Polygon): + polygons = [polygons] + + if len(polygons) == 1: + return _polygon_query( + sdata=sdata, + polygon=polygons[0], + target_coordinate_system=target_coordinate_system, + filter_table=filter_table, + shapes=shapes, + points=points, + ) + # TODO: the performance for this case can be greatly improved by using the geopandas queries only once, and not + # in a loop as done preliminarily here + if points: + raise NotImplementedError( + "points=True is not implemented when querying by multiple polygons. If you encounter this error, please" + " open an issue on GitHub and we will prioritize the implementation." + ) + sdatas = [] + for polygon in tqdm(polygons): + try: + # not filtering now, we filter below + queried_sdata = _polygon_query( + sdata=sdata, + polygon=polygon, + target_coordinate_system=target_coordinate_system, + filter_table=False, + shapes=shapes, + points=points, + ) + sdatas.append(queried_sdata) + except ValueError as e: + if str(e) != "we expect at least one shape": + raise e + # print("skipping", end="") + geodataframe_pieces: dict[str, list[GeoDataFrame]] = {} + + for sdata in sdatas: + for shapes_name, shapes in sdata.shapes.items(): + if shapes_name not in geodataframe_pieces: + geodataframe_pieces[shapes_name] = [] + geodataframe_pieces[shapes_name].append(shapes) + + geodataframes = {} + for k, v in geodataframe_pieces.items(): + vv = pd.concat(v) + vv = vv[~vv.index.duplicated(keep="first")] + geodataframes[k] = vv + + table = _filter_table_by_elements(sdata.table, {"shapes": geodataframes}) if filter_table else sdata.table + + return SpatialData(shapes=geodataframes, table=table) \ No newline at end of file diff --git a/bento/tools/__init__.py b/bento/tools/__init__.py old mode 100755 new mode 100644 index b43d86d..8133f7d --- a/bento/tools/__init__.py +++ b/bento/tools/__init__.py @@ -1,8 +1,8 @@ from ._colocation import coloc_quotient, colocation from ._composition import comp_diff from ._flux import flux, fluxmap -from ._flux_enrichment import fe, fe_kegg, fe_xia2019, fe_fazal2019, gene_sets, load_gene_sets -from ._lp import lp, lp_diff, lp_stats +from ._flux_enrichment import fe, fe_fazal2019, fe_xia2019, gene_sets, load_gene_sets +from ._lp import lp, lp_stats, lp_diff_discrete, lp_diff_continuous from ._point_features import analyze_points, list_point_features, register_point_feature from ._shape_features import ( analyze_shapes, @@ -10,4 +10,4 @@ register_shape_feature, list_shape_features, ) -from ._decomposition import decompose, to_tensor +from ._decomposition import decompose \ No newline at end of file diff --git a/bento/tools/_colocation.py b/bento/tools/_colocation.py index 142a952..b28628e 100644 --- a/bento/tools/_colocation.py +++ b/bento/tools/_colocation.py @@ -5,50 +5,52 @@ import pandas as pd import seaborn as sns import sparse -from anndata import AnnData +from spatialdata._core.spatialdata import SpatialData from kneed import KneeLocator from tqdm.auto import tqdm -from .._utils import track -from ..geometry import get_points +#from .._utils import track +from bento.geometry import get_points from ._neighborhoods import _count_neighbors from ._decomposition import decompose -@track def colocation( - data: AnnData, + sdata: SpatialData, ranks: List[int], + instance_key: str = "cell_boundaries", + feature_key: str = "feature_name", iterations: int = 3, plot_error: bool = True, - copy: bool = False, ): """Decompose a tensor of 
pairwise colocalization quotients into signatures. Parameters ---------- - adata : AnnData - Spatial formatted AnnData object. + sdata : SpatialData + Spatial formatted SpatialData object. ranks : list List of ranks to decompose the tensor. + instance_key : str + Key that specifies cell_boundaries instance in sdata. + feature_key : str + Key that specifies genes in sdata. iterations : int Number of iterations to run the decomposition. plot_error : bool Whether to plot the error of the decomposition. - copy : bool - Whether to return a copy of the AnnData object. Default False. + Returns ------- - adata : AnnData - .uns['factors']: Decomposed tensor factors. - .uns['factors_error']: Decomposition error. + sdata : SpatialData + .table.uns['factors']: Decomposed tensor factors. + .table.uns['factors_error']: Decomposition error. """ - adata = data.copy() if copy else data print("Preparing tensor...") - _colocation_tensor(adata, copy=copy) + _colocation_tensor(sdata, instance_key, feature_key) - tensor = adata.uns["tensor"] + tensor = sdata.table.uns["tensor"] print(emoji.emojize(":running: Decomposing tensor...")) factors, errors = decompose(tensor, ranks, iterations=iterations) @@ -60,27 +62,27 @@ def colocation( kl.plot_knee() sns.lineplot(data=errors, x="rank", y="rmse", ci=95, marker="o") - adata.uns["factors"] = factors - adata.uns["factors_error"] = errors + sdata.table.uns["factors"] = factors + sdata.table.uns["factors_error"] = errors print(emoji.emojize(":heavy_check_mark: Done.")) - return adata if copy else None -def _colocation_tensor(data: AnnData, copy: bool = False): +def _colocation_tensor(sdata: SpatialData, instance_key: str, feature_key: str): """ Convert a dictionary of colocation quotient values in long format to a dense tensor. Parameters ---------- - data : AnnData - Spatial formatted AnnData object. - copy : bool - Whether to return a copy of the AnnData object. Default False. + sdata : SpatialData + Spatial formatted SpatialData object. + instance_key : str + Key that specifies cell_boundaries instance in sdata. + feature_key : str + Key that specifies genes in sdata. """ - adata = data.copy() if copy else data - clqs = adata.uns["clq"] + clqs = sdata.table.uns["clq"] clq_long = [] for shape, clq in clqs.items(): @@ -89,10 +91,10 @@ def _colocation_tensor(data: AnnData, copy: bool = False): clq_long = pd.concat(clq_long, axis=0) clq_long["pair"] = ( - clq_long["gene"].astype(str) + "_" + clq_long["neighbor"].astype(str) + clq_long[feature_key].astype(str) + "_" + clq_long["neighbor"].astype(str) ) - label_names = ["compartment", "cell", "pair"] + label_names = ["compartment", instance_key, "pair"] labels = dict() label_orders = [] for name in label_names: @@ -106,28 +108,32 @@ def _colocation_tensor(data: AnnData, copy: bool = False): tensor = s.todense() print(tensor.shape) - adata.uns["tensor"] = tensor - adata.uns["tensor_labels"] = labels - adata.uns["tensor_names"] = label_names - - return adata + sdata.table.uns["tensor"] = tensor + sdata.table.uns["tensor_labels"] = labels + sdata.table.uns["tensor_names"] = label_names - -@track def coloc_quotient( - data: AnnData, - shapes: List[str] = ["cell_shape"], + sdata: SpatialData, + points_key: str = "transcripts", + instance_key: str = "cell_boundaries", + feature_key: str = "feature_name", + shapes: List[str] = ["cell_boundaries"], radius: int = 20, min_points: int = 10, min_cells: int = 0, - copy: bool = False, ): """Calculate pairwise gene colocalization quotient in each cell. 
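# Illustrative workflow sketch using the defaults from the signatures above; the
# rank list is an arbitrary example.
from bento.tools import coloc_quotient, colocation

coloc_quotient(sdata, shapes=["cell_boundaries"], radius=20, min_points=10)
# per-cell colocation quotients now sit in sdata.table.uns["clq"]
colocation(sdata, ranks=[2, 3, 4, 5], iterations=3)
# decomposed tensor factors land in sdata.table.uns["factors"]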
Parameters ---------- - adata : AnnData - Spatial formatted AnnData object. + sdata : SpatialData + Spatial formatted SpatialData object. + points_key: str + Key that specifies transcript points in sdata. + instance_key : str + Key that specifies cell_boundaries instance in sdata. + feature_key : str + Key that specifies genes in sdata. shapes : list Specify which shapes to compute colocalization separately. radius : int @@ -136,35 +142,30 @@ def coloc_quotient( Minimum number of points for sample to be considered for colocalization, default 10 min_cells : int Minimum number of cells for gene to be considered for colocalization, default 0 - copy : bool - Whether to return a copy of the AnnData object. Default False. + Returns ------- - adata : AnnData - .uns['clq']: Pairwise gene colocalization similarity within each cell formatted as a long dataframe. + sdata : SpatialData + .table.uns['clq']: Pairwise gene colocalization similarity within each cell formatted as a long dataframe. """ - adata = data.copy() if copy else data - all_clq = dict() for shape in shapes: - shape_col = "_".join(str(shape).split("_")[:-1]) - points = get_points(adata, asgeo=False) - points[shape_col] = points[shape_col].astype(str) + points = get_points(sdata, points_key=points_key, astype="pandas", sync=True) points = ( - points.query(f"{shape_col} != '-1'") - .sort_values("cell")[["cell", "gene", "x", "y"]] + points.query(f"{instance_key} != ''") + .sort_values(instance_key)[[instance_key, feature_key, "x", "y"]] .reset_index(drop=True) ) # Keep genes expressed in at least min_cells cells - gene_counts = points.groupby("gene").size() + gene_counts = points.groupby(feature_key).size() valid_genes = gene_counts[gene_counts >= min_cells].index - points = points[points["gene"].isin(valid_genes)] + points = points[points[feature_key].isin(valid_genes)] # Partition so {chunksize} cells per partition cells, group_loc = np.unique( - points["cell"].astype(str), + points[instance_key].astype(str), return_index=True, ) @@ -175,29 +176,26 @@ def coloc_quotient( zip(cells, group_loc, end_loc), desc=shape, total=len(cells) ): cell_points = points.iloc[start:end] - cell_clq = _cell_clq(cell_points, adata.n_vars, radius, min_points) - cell_clq["cell"] = cell + cell_clq = _cell_clq(cell_points, radius, min_points, feature_key) + cell_clq[instance_key] = cell cell_clqs.append(cell_clq) cell_clqs = pd.concat(cell_clqs) - cell_clqs[["cell", "gene", "neighbor"]] = ( - cell_clqs[["cell", "gene", "neighbor"]].astype(str).astype("category") + cell_clqs[[instance_key, feature_key, "neighbor"]] = ( + cell_clqs[[instance_key, feature_key, "neighbor"]].astype(str).astype("category") ) cell_clqs["log_clq"] = cell_clqs["clq"].replace(0, np.nan).apply(np.log2) # Save to uns['clq'] as adjacency list all_clq[shape] = cell_clqs - adata.uns["clq"] = all_clq - - return adata if copy else None - + sdata.table.uns["clq"] = all_clq -def _cell_clq(cell_points, n_genes, radius, min_points): +def _cell_clq(cell_points, radius, min_points, feature_key): # Count number of points for each gene - gene_counts = cell_points["gene"].value_counts() + gene_counts = cell_points[feature_key].value_counts() # Keep genes with at least min_count gene_counts = gene_counts[gene_counts >= min_points] @@ -206,30 +204,30 @@ def _cell_clq(cell_points, n_genes, radius, min_points): return pd.DataFrame() # Get points - valid_points = cell_points[cell_points["gene"].isin(gene_counts.index)] + valid_points = cell_points[cell_points[feature_key].isin(gene_counts.index)] # Cleanup 
gene categories # valid_points["gene"] = valid_points["gene"].cat.remove_unused_categories() # Count number of source points that have neighbor gene point_neighbors = _count_neighbors( - valid_points, n_genes, radius=radius, agg="binary" + valid_points, len(valid_points[feature_key].cat.categories), radius=radius, agg="binary" ).toarray() neighbor_counts = ( - pd.DataFrame(point_neighbors, columns=valid_points["gene"].cat.categories) - .groupby(valid_points["gene"].values) + pd.DataFrame(point_neighbors, columns=valid_points[feature_key].cat.categories) + .groupby(valid_points[feature_key].values) .sum() .reset_index() .melt(id_vars="index") .query("value > 0") ) - neighbor_counts.columns = ["gene", "neighbor", "count"] - clq_df = _clq_statistic(neighbor_counts, gene_counts) + neighbor_counts.columns = [feature_key, "neighbor", "count"] + clq_df = _clq_statistic(neighbor_counts, gene_counts, feature_key) return clq_df -def _clq_statistic(neighbor_counts, counts): +def _clq_statistic(neighbor_counts, counts, feature_key): """ Compute the colocation quotient for each gene pair. @@ -241,7 +239,7 @@ def _clq_statistic(neighbor_counts, counts): Series of raw gene counts. """ clq_df = neighbor_counts.copy() - clq_df["clq"] = (clq_df["count"] / counts.loc[clq_df["gene"]].values) / ( + clq_df["clq"] = (clq_df["count"] / counts.loc[clq_df[feature_key]].values) / ( counts.loc[clq_df["neighbor"]].values / counts.sum() ) - return clq_df.drop("count", axis=1) + return clq_df.drop("count", axis=1) \ No newline at end of file diff --git a/bento/tools/_composition.py b/bento/tools/_composition.py index df9a364..b079556 100644 --- a/bento/tools/_composition.py +++ b/bento/tools/_composition.py @@ -5,9 +5,9 @@ import numpy as np from ..geometry import get_points -from .._utils import track +#from .._utils import track -from anndata import AnnData +from spatialdata._core.spatialdata import SpatialData def _get_compositions(points: pd.DataFrame, shape_names: list) -> pd.DataFrame: @@ -57,35 +57,24 @@ def _get_compositions(points: pd.DataFrame, shape_names: list) -> pd.DataFrame: return comp_stats -@track def comp_diff( - data: AnnData, shape_names: list, groupby: str, ref_group: str, copy: bool = False + sdata: SpatialData, shape_names: list, groupby: str, ref_group: str ): """Calculate the average difference in gene composition for shapes across batches of cells. Uses the Wasserstein distance. Parameters ---------- - data : anndata.AnnData - Spatial formatted AnnData object. + sdata : spatialdata.SpatialData + Spatial formatted SpatialData object. shape_names : list of str Names of shapes to calculate compositions for. groupby : str - Key in `adata.obs` to group cells by. + Key in `sdata.points['transcripts]` to group cells by. ref_group : str Reference group to compare other groups to. - copy : bool - Return a copy of `data` instead of writing to data, by default False. 
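# Illustrative call of comp_diff (sketch): the grouping column "batch", its reference
# level "control", and the nucleus shape key are placeholder names.
from bento.tools import comp_diff

comp_diff(
    sdata,
    shape_names=["cell_boundaries", "nucleus_boundaries"],
    groupby="batch",
    ref_group="control",
)
# summary statistics land in sdata.table.uns["batch_comp_stats"]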
- - Returns - ------- - adata : anndata.AnnData - Returns `adata` if `copy=True`, otherwise adds fields to `data`: """ - - adata = data.copy() if copy else data - - points = get_points(data) + points = get_points(sdata, astype="pandas") # Get average gene compositions for each batch comp_stats = dict() @@ -94,7 +83,7 @@ def comp_diff( ref_comp = comp_stats[ref_group] - dims = [s.replace("_shape", "") for s in shape_names] + dims = [s.replace("_boundaries", "") for s in shape_names] for group in comp_stats.keys(): if group == ref_group: continue @@ -109,4 +98,4 @@ def comp_diff( index=ref_comp.index, ) - adata.uns[f"{groupby}_comp_stats"] = comp_stats + sdata.table.uns[f"{groupby}_comp_stats"] = comp_stats diff --git a/bento/tools/_decomposition.py b/bento/tools/_decomposition.py index 1909fa1..fa0a6b5 100644 --- a/bento/tools/_decomposition.py +++ b/bento/tools/_decomposition.py @@ -3,12 +3,12 @@ import numpy as np import pandas as pd import tensorly as tl -from anndata import AnnData +from spatialdata._core.spatialdata import SpatialData from scipy.stats import zscore from tensorly.decomposition import non_negative_parafac from tqdm.auto import tqdm -from .._utils import track +#from .._utils import track def decompose( @@ -111,47 +111,3 @@ def decompose( def rmse(tensor, tensor_mu): return np.sqrt((tensor[tensor != 0] - tensor_mu[tensor != 0]) ** 2).mean() - - -@track -def to_tensor( - data: AnnData, layers: List[str], scale: bool = False, copy: bool = False -): - """ - Generate tensor from data where dimensions are (layers, cells, genes). - - Parameters - ---------- - data : AnnData - Spatial formatted AnnData - layers : list of str - Keys in data.layers to build tensor. - scale : bool - Z scale across cells for each layer, by default False. - copy : bool - Return a copy of `data` instead of writing to data, by default False. 
- - Returns - ------- - adata : anndata.AnnData - `uns['tensor']` : np.ndarray - 3D numpy array of shape (len(layers), adata.n_obs, adata.n_vars) - """ - adata = data.copy() if copy else data - - # Build tensor from specified layers - tensor = [] - for l in layers: - tensor.append(adata.to_df(l).values) - - # Save tensor values - tensor = np.array(tensor) - - # Z scale across cells for each layer - if scale: - for i, layer in enumerate(tensor): - tensor[i] = zscore(layer, axis=1, nan_policy="omit") - - adata.uns["tensor"] = np.array(tensor) - - return adata diff --git a/bento/tools/_flux.py b/bento/tools/_flux.py index e132c8b..3bd79c7 100644 --- a/bento/tools/_flux.py +++ b/bento/tools/_flux.py @@ -7,10 +7,13 @@ import numpy as np import pandas as pd import rasterio +import rasterio.features import shapely -from anndata import AnnData +from spatialdata._core.spatialdata import SpatialData +from spatialdata.models import PointsModel, ShapesModel from kneed import KneeLocator from minisom import MiniSom +from shapely import Polygon from scipy.sparse import csr_matrix, vstack from sklearn.decomposition import TruncatedSVD, IncrementalPCA from sklearn.preprocessing import StandardScaler, minmax_scale, quantile_transform @@ -18,17 +21,15 @@ from tqdm.auto import tqdm from rich.progress import Progress -from bento._settings import settings -from bento._utils import register_points, track -from bento.geometry import get_points, sindex_points -from bento.tools._neighborhoods import _count_neighbors -from bento.tools._shape_features import analyze_shapes +from ..geometry import get_points, sjoin_points, set_points_metadata +from ..tools._neighborhoods import _count_neighbors +from ..tools._shape_features import analyze_shapes - -@track -@register_points("cell_raster", ["flux", "flux_embed", "flux_counts"]) def flux( - data: AnnData, + sdata: SpatialData, + points_key: str = "transcripts", + instance_key: str = "cell_boundaries", + feature_key: str = "feature_name", method: Literal["knn", "radius"] = "radius", n_neighbors: Optional[int] = None, radius: Optional[int] = 0.5, @@ -37,7 +38,7 @@ def flux( train_size: float = 1, use_highly_variable: bool = False, random_state: int = 11, - copy: bool = False, + recompute: bool = False ): """ RNAflux: Embedding each pixel as normalized local composition normalized by cell composition. @@ -46,8 +47,14 @@ def flux( Parameters ---------- - data : AnnData - Spatial formatted AnnData object. + sdata : SpatialData + Spatial formatted SpatialData object. + points_key : str + key for points element that holds transcript coordinates + instance_key : str + Key for cell_boundaries instances + feature_key : str + Key for gene instances method: str Method to use for local neighborhood. Either 'knn' or 'radius'. n_neighbors : int @@ -56,31 +63,29 @@ def flux( Radius to use for local neighborhood. Uses cell radius / 2 if None. res : float Resolution to use for rendering embedding. Default 0.05 samples at 5% original resolution (5 units between pixels) - copy : bool - Whether to return a copy the AnnData object. Default False. Returns ------- - adata : AnnData - .uns["flux"] : scipy.csr_matrix - [pixels x genes] sparse matrix of normalized local composition. - .uns["flux_embed"] : np.ndarray - [pixels x components] array of embedded flux values. - .uns["flux_color"] : np.ndarray - [pixels x 3] array of RGB values for visualization. 
- .uns["flux_genes"] : list + sdata : SpatialData + .points["{instance_key}_raster"]: pd.DataFrame + Length pixels DataFrame containing all computed flux values, embeddings, and colors as columns in a single DataFrame. + flux values: for each gene used in embedding. + embeddings: flux_embed_ for each component of the embedding. + colors: hex color codes for each pixel. + .table.uns["flux_genes"] : list List of genes used for embedding. - .uns["flux_variance_ratio"] : np.ndarray + .table.uns["flux_variance_ratio"] : np.ndarray [components] array of explained variance ratio for each component. """ - adata = data.copy() if copy else data - - settings.log.start("Running flux().") + if f"{instance_key}_raster" in sdata.points and len(sdata.points[f"{instance_key}_raster"].columns) > 3 and not recompute: + return + + if n_neighbors is None and radius is None: + radius = 50 - # Get points - adata.uns["points"] = get_points(adata).sort_values("cell") - points = get_points(adata)[["cell", "gene", "x", "y"]] + points = get_points(sdata, points_key=points_key, astype="pandas", sync=True) + points = points[[instance_key, feature_key, "x", "y"]].sort_values(instance_key) # Only use highly variable genes if use_highly_variable: @@ -113,23 +118,26 @@ def flux( step = 1 / res # Get grid rasters analyze_shapes( - adata, - "cell_shape", + sdata, + instance_key, "raster", progress=False, feature_kws=dict(raster={"step": step}), ) - # Long dataframe of raster points - adata.uns["cell_raster"] = adata.uns["cell_raster"].sort_values("cell") - raster_points = adata.uns["cell_raster"] + raster_points = get_points(sdata, points_key=f"{instance_key}_raster", astype="pandas", sync=True).sort_values(instance_key) + # Extract gene names and codes + gene_names = points[feature_key].cat.categories.tolist() + n_genes = len(gene_names) - points_grouped = points.groupby("cell") - rpoints_grouped = raster_points.groupby("cell") + points_grouped = points.groupby(instance_key) + rpoints_grouped = raster_points.groupby(instance_key) cells = list(points_grouped.groups.keys()) - # Compute cell composition for each cell - cell_composition = adata[cells, gene_names].X.toarray() + + cell_composition = sdata.table[cells, gene_names].X.toarray() + + # Compute cell composition cell_composition = cell_composition / (cell_composition.sum(axis=1).reshape(-1, 1)) cell_composition = np.nan_to_num(cell_composition) @@ -202,32 +210,47 @@ def flux( flux_sv = svd_model.components_ variance_ratio = svd_model.explained_variance_ratio_ - # Use the elbow method to determine the number of components to keep - kl = KneeLocator( - range(len(variance_ratio)), variance_ratio, curve="convex", direction="decreasing" - ) - if kl.elbow is not None: - n_components = kl.elbow - else: - n_components = len(variance_ratio) + # For color visualization of flux embeddings + flux_color = vec2color(flux_embed, fmt="hex", vmin=0.1, vmax=0.9) + pbar.update() + pbar.set_description(emoji.emojize("Saving")) + + flux_df = pd.DataFrame(cell_fluxs.todense().tolist(), columns=gene_names) + flux_embed_df = pd.DataFrame(flux_embed.tolist(), columns=[f'flux_embed_{i}' for i in range(len(flux_embed.tolist()[0]))]) + raster_points = pd.concat([raster_points, flux_df, flux_embed_df], axis=1, join='outer', ignore_index=False) - settings.log.step("Saving results") - adata.uns["flux"] = cell_fluxs # sparse gene embedding - adata.uns["flux_genes"] = gene_names # gene names - adata.uns["flux_embed"] = flux_embed - adata.uns["flux_sv"] = flux_sv - adata.uns["flux_n_components"] = 
n_components - adata.uns["flux_counts"] = rpoints_counts - adata.uns["flux_variance_ratio"] = variance_ratio + raster_points["flux_color"] = flux_color + flux_df = raster_points.drop(columns=["x", "y", instance_key]) + set_points_metadata(sdata, points_key=f"{instance_key}_raster", metadata=flux_df) + + sdata.table.uns["flux_variance_ratio"] = variance_ratio + sdata.table.uns["flux_genes"] = gene_names # gene names settings.log.end("Done.") - return adata if copy else None +def vec2color( + vec: np.ndarray, + fmt: Literal[ + "rgb", + "hex", + ] = "hex", + vmin: float = 0, + vmax: float = 1, +): + """Convert vector to color.""" + color = quantile_transform(vec[:, :3]) + color = minmax_scale(color, feature_range=(vmin, vmax)) + if fmt == "rgb": + pass + elif fmt == "hex": + color = np.apply_along_axis(mpl.colors.to_hex, 1, color, keep_alpha=True) + return color -@track def fluxmap( - data: AnnData, + sdata: SpatialData, + points_key: str = "transcripts", + instance_key: str = "cell_boundaries", n_clusters: Union[Iterable[int], int] = range(2, 9), n_components: Optional[int] = None, num_iterations: int = 1000, @@ -236,14 +259,17 @@ def fluxmap( res: float = 0.1, random_state: int = 11, plot_error: bool = True, - copy: bool = False, ): """Cluster flux embeddings using self-organizing maps (SOMs) and vectorize clusters as Polygon shapes. Parameters ---------- - data : AnnData - Spatial formatted AnnData object. + sdata : SpatialData + Spatial formatted SpatialData object. + points_key : str + key for points element that holds transcript coordinates + instance_key : str + Key for cell_boundaries instances n_clusters : int or list Number of clusters to use. If list, will pick best number of clusters using the elbow heuristic evaluated on the quantization error. @@ -257,35 +283,27 @@ def fluxmap( Random state to use for SOM training. Default 11. plot_error : bool Whether to plot quantization error. Default True. - copy : bool - Whether to return a copy the AnnData object. Default False. Returns ------- - adata : AnnData - .uns["cell_raster"] : DataFrame + sdata : SpatialData + .points["points"] : DataFrame Adds "fluxmap" column denoting cluster membership. - .uns["points"] : DataFrame - Adds "fluxmap#" columns for each cluster. - .obs : GeoSeries + .shapes["fluxmap#_shape"] : GeoSeries Adds "fluxmap#_shape" columns for each cluster rendered as (Multi)Polygon shapes. """ - adata = data.copy() if copy else data + + raster_points = get_points(sdata, points_key=f"{instance_key}_raster", astype="pandas", sync=True) # Check if flux embedding has been computed - if "flux_embed" not in adata.uns: + if "flux_embed_0" not in raster_points.columns: raise ValueError( "Flux embedding has not been computed. Run `bento.tl.flux()` first." 
) - - if n_components is None: - n_components = adata.uns["flux_n_components"] - flux_embed = adata.uns["flux_embed"][:, :n_components] - raster_points = adata.uns["cell_raster"] - flux_counts = adata.uns["flux_counts"] - - # Exclude points with low counts - valid_points = flux_counts > min_points + + flux_embed = raster_points.filter(like='flux_embed_') + sorted_column_names = sorted(flux_embed.columns.tolist(), key=lambda x: int(x.split('_')[-1])) + flux_embed = flux_embed[sorted_column_names].to_numpy() if isinstance(n_clusters, int): n_clusters = [n_clusters] @@ -346,22 +364,18 @@ def fluxmap( qnt_index = np.ravel_multi_index(winner_coordinates, (1, best_k)) + 1 qnt_index[~valid_points] = 0 raster_points["fluxmap"] = qnt_index - adata.uns["cell_raster"] = raster_points.copy() + set_points_metadata(sdata, points_key=f"{instance_key}_raster", metadata=list(qnt_index), column_names="fluxmap") pbar.update() # Vectorize polygons in each cell pbar.set_description(emoji.emojize("Vectorizing domains")) - cells = raster_points["cell"].unique().tolist() - # Scale down to render resolution - # raster_points[["x", "y"]] = raster_points[["x", "y"]] * res + cells = raster_points[instance_key].unique().tolist() # Cast to int - raster_points[["x", "y", "fluxmap"]] = raster_points[["x", "y", "fluxmap"]].astype( - int - ) + raster_points[["x", "y", "fluxmap"]] = raster_points[["x", "y", "fluxmap"]].astype(int) - rpoints_grouped = raster_points.groupby("cell") + rpoints_grouped = raster_points.groupby(instance_key) fluxmap_df = dict() for cell in tqdm(cells, leave=False): rpoints = rpoints_grouped.get_group(cell) @@ -396,11 +410,10 @@ def fluxmap( # Group same fields as MultiPolygons shapes = shapes.dissolve("fluxmap")["geometry"] - fluxmap_df[cell] = shapes fluxmap_df = pd.DataFrame.from_dict(fluxmap_df).T - fluxmap_df.columns = "fluxmap" + fluxmap_df.columns.astype(str) + "_shape" + fluxmap_df.columns = "fluxmap" + fluxmap_df.columns.astype(str) + "_boundaries" # Upscale to match original resolution fluxmap_df = fluxmap_df.apply( @@ -411,20 +424,26 @@ def fluxmap( pbar.update() pbar.set_description("Saving") - old_cols = adata.obs.columns[adata.obs.columns.str.startswith("fluxmap")] - adata.obs = adata.obs.drop(old_cols, axis=1, errors="ignore") - - adata.obs[fluxmap_df.columns] = fluxmap_df.reindex(adata.obs_names) - old_cols = adata.uns["points"].columns[ - adata.uns["points"].columns.str.startswith("fluxmap") + old_shapes = [k for k in sdata.shapes.keys() if k.startswith("fluxmap")] + for key in old_shapes: + del sdata.shapes[key] + + transform = sdata.shapes[instance_key].attrs + fluxmap_df = fluxmap_df.reindex(sdata.table.obs_names).where(fluxmap_df.notna(), other=Polygon()) + for fluxmap in fluxmap_df.columns: + sdata.shapes[fluxmap] = ShapesModel.parse(gpd.GeoDataFrame(geometry=fluxmap_df[fluxmap])) + sdata.shapes[fluxmap].attrs = transform + + old_cols = sdata.points[points_key].columns[ + sdata.points[points_key].columns.str.startswith("fluxmap") ] - adata.uns["points"] = adata.uns["points"].drop(old_cols, axis=1) + sdata.points[points_key] = sdata.points[points_key].drop(old_cols, axis=1) # TODO SLOW - sindex_points(adata, "points", fluxmap_df.columns.tolist()) + sjoin_points(sdata=sdata, shape_keys=fluxmap_df.columns.tolist(), points_key=points_key) pbar.update() pbar.set_description("Done") pbar.close() - return adata if copy else None + \ No newline at end of file diff --git a/bento/tools/_flux_enrichment.py b/bento/tools/_flux_enrichment.py index f57c57b..7abeb95 100644 --- 
a/bento/tools/_flux_enrichment.py +++ b/bento/tools/_flux_enrichment.py @@ -2,104 +2,72 @@ import decoupler as dc +import numpy as np import pandas as pd +import dask.dataframe as dd import pkg_resources -from anndata import AnnData +from scipy import sparse +from spatialdata._core.spatialdata import SpatialData +from spatialdata.models import PointsModel -from bento._utils import track, _register_points +from ..geometry import get_points, set_points_metadata -def fe_kegg(data: AnnData, copy: bool = False, **kwargs) -> Optional[AnnData]: - """Compute enrichment scores from KEGG gene sets. - See `bento.tl.fe` docs for parameter details. - - Parameters - ---------- - data : AnnData - Spatial formatted AnnData object. - copy : bool - Return a copy instead of writing to `adata`. Default False. - Returns - ------- - DataFrame - Enrichment scores for each gene set. - """ - adata = data.copy() if copy else data - - msigdb = dc.get_resource('MSigDB') - msigdb = msigdb[msigdb['collection']=='KEGG'] - msigdb = msigdb[~msigdb.duplicated(['geneset', 'genesymbol'])] - - fe(adata, net=msigdb, weight=None, **kwargs) - - return adata if copy else None - - - -def fe_fazal2019(data: AnnData, copy: bool = False, **kwargs) -> Optional[AnnData]: +def fe_fazal2019(sdata: SpatialData, **kwargs): """Compute enrichment scores from subcellular compartment gene sets from Fazal et al. 2019 (APEX-seq). See `bento.tl.fe` docs for parameter details. Parameters ---------- - data : AnnData - Spatial formatted AnnData object. - copy : bool - Return a copy instead of writing to `adata`. Default False. + data : SpatialData + Spatial formatted SpatialData object. + Returns ------- DataFrame Enrichment scores for each gene set. """ - adata = data.copy() if copy else data gene_sets = load_gene_sets("fazal2019") - fe(adata, net=gene_sets, **kwargs) + fe(sdata, net=gene_sets, **kwargs) - return adata if copy else None - -def fe_xia2019(data: AnnData, copy: bool = False, **kwargs) -> Optional[AnnData]: +def fe_xia2019(sdata: SpatialData, **kwargs): """Compute enrichment scores from subcellular compartment gene sets from Xia et al. 2019 (MERFISH 10k U2-OS). See `bento.tl.fe` docs for parameters details. Parameters ---------- - data : AnnData - Spatial formatted AnnData object. - copy : bool - Return a copy instead of writing to `adata`. Default False. + data : SpatialData + Spatial formatted SpatialData object. + Returns ------- DataFrame Enrichment scores for each gene set. """ - adata = data.copy() if copy else data gene_sets = load_gene_sets("xia2019") - fe(adata, gene_sets, **kwargs) + fe(sdata, gene_sets, **kwargs) - return adata if copy else None - -@track def fe( - data: AnnData, + sdata: SpatialData, net: pd.DataFrame, + instance_key: Optional[str] = "cell_boundaries", source: Optional[str] = "source", target: Optional[str] = "target", weight: Optional[str] = "weight", batch_size: int = 10000, min_n: int = 0, - copy: bool = False, -) -> Optional[AnnData]: +): """ Perform functional enrichment on point embeddings. Wrapper for decoupler wsum function. Parameters ---------- - data : AnnData - Spatial formatted AnnData object. + sdata : SpatialData + Spatial formatted SpatialData object. net : DataFrame DataFrame with columns "source", "target", and "weight". See decoupler API for more details. source : str, optional @@ -112,79 +80,59 @@ def fe( Number of points to process in each batch. Default 10000. min_n : int Minimum number of targets per source. If less, sources are removed. 
- copy : bool - Return a copy instead of writing to `adata`. Default False. Returns ------- - adata : AnnData - uns["flux_fe"] : DataFrame + sdata : SpatialData + .points["cell_boundaries_raster"]["flux_fe"] : DataFrame Enrichment scores for each gene set. """ - - adata = data.copy() if copy else data - # Make sure embedding is run first - if "flux" not in data.uns: + if "flux_genes" in sdata.table.uns: + flux_genes = set(sdata.table.uns["flux_genes"]) + cell_raster_columns = set(sdata.points[f"{instance_key}_raster"].columns) + if len(flux_genes.intersection(cell_raster_columns)) != len(flux_genes): + print("Recompute bento.tl.flux first.") + return + else: print("Run bento.tl.flux first.") return + + flux_genes = sdata.table.uns["flux_genes"] + cell_raster_points = get_points(sdata, points_key=f"{instance_key}_raster", astype="pandas", sync=True)[flux_genes].values + cell_raster_matrix = np.mat(cell_raster_points) + mat = sparse.csr_matrix(cell_raster_matrix) # sparse matrix in csr format - mat = adata.uns["flux"] # sparse matrix in csr format - zero_rows = mat.getnnz(1) == 0 - - samples = adata.uns["cell_raster"].index.astype(str) - features = adata.uns["flux_genes"] - - if weight: - enrichment = dc.run_wsum( - mat=[mat, samples, features], - net=net, - source=source, - target=target, - weight=weight, - batch_size=batch_size, - min_n=min_n, - verbose=True, - ) - else: - enrichment = dc.run_gsea( - mat=[mat, samples, features], - net=net, - source=source, - target=target, - times=1000, - batch_size=batch_size, - min_n=min_n, - verbose=True, - ) - - scores = enrichment[1].reindex(index=samples) + samples = sdata.points[f"{instance_key}_raster"].index.astype(str) + features = sdata.table.uns["flux_genes"] - for col in scores.columns: - score_key = f"flux_{col}" - adata.uns[score_key] = scores[col].values + enrichment = dc.run_wsum( + mat=[mat, samples, features], + net=net, + source=source, + target=target, + weight=weight, + batch_size=batch_size, + min_n=min_n, + verbose=True, + ) - # Manually call register_points since it is dynamic - _register_points(adata, "cell_raster", [score_key]) + scores = enrichment[1].reindex(index=samples).add_prefix('flux_') + set_points_metadata(sdata, points_key=f"{instance_key}_raster", metadata=scores) - _fe_stats(adata, net, source=source, target=target, copy=copy) - - return adata if copy else None + _fe_stats(sdata, net, source=source, target=target) def _fe_stats( - data: AnnData, + sdata: SpatialData, net: pd.DataFrame, source: str = "source", target: str = "target", - copy: bool = False, ): - adata = data.copy() if copy else data - # rows = cells, columns = pathways, values = count of genes in pathway - expr_binary = adata.to_df() >= 5 + expr_binary = sdata.table.to_df() >= 5 # {cell : present gene list} - expr_genes = expr_binary.apply(lambda row: adata.var_names[row], axis=1) + expr_genes = expr_binary.apply(lambda row: sdata.table.var_names[row], axis=1) # Count number of genes present in each pathway net_ngenes = net.groupby(source).size().to_frame().T.rename(index={0: "n_genes"}) @@ -195,17 +143,13 @@ def _fe_stats( for source, group in net.groupby(source): sources.append(source) common = expr_genes.apply(lambda genes: set(genes).intersection(group[target])) - # common_genes[source] = np.array(common) common_ngenes.append(common.apply(len)) fe_stats = pd.concat(common_ngenes, axis=1) fe_stats.columns = sources - adata.uns["fe_stats"] = fe_stats - # adata.uns["fe_genes"] = common_genes - adata.uns["fe_ngenes"] = net_ngenes - - return adata if 
copy else None + sdata.table.uns["fe_stats"] = fe_stats + sdata.table.uns["fe_ngenes"] = net_ngenes gene_sets = dict( @@ -235,4 +179,4 @@ def load_gene_sets(name): stream = pkg_resources.resource_stream(__name__, f"gene_sets/{fname}") gs = pd.read_csv(stream) - return gs + return gs \ No newline at end of file diff --git a/bento/tools/_lp.py b/bento/tools/_lp.py old mode 100755 new mode 100644 index c2ff668..82af1ba --- a/bento/tools/_lp.py +++ b/bento/tools/_lp.py @@ -1,71 +1,94 @@ -import os +from typing import List, Optional, Union import pickle +import warnings + +warnings.filterwarnings("ignore") import bento import numpy as np import pandas as pd import statsmodels.formula.api as sfm +from pandas.api.types import is_numeric_dtype from patsy import PatsyError from statsmodels.tools.sm_exceptions import PerfectSeparationError from tqdm.auto import tqdm -from anndata import AnnData -from .._utils import track -from .._constants import PATTERN_NAMES, PATTERN_FEATURES +from spatialdata._core.spatialdata import SpatialData -tqdm.pandas() +from .._constants import PATTERN_NAMES +tqdm.pandas() -@track -def lp(data: AnnData, groupby: str = "gene", copy: bool = False): +def lp( + sdata: SpatialData, + instance_key: str = "cell_boundaries", + nucleus_key: str = "nucleus_boundaries", + groupby: Optional[Union[str, List[str]]] = "gene" +): """Predict transcript subcellular localization patterns. Patterns include: cell edge, cytoplasmic, nuclear edge, nuclear, none Parameters ---------- - data : AnnData - Spatial formatted AnnData object + sdata : SpatialData + Spatial formatted SpatialData object + groupby : str or list of str, optional (default: None) - Key in `data.uns['points'] to groupby, by default None. Always treats each cell separately - copy : bool - Return a copy of `data` instead of writing to data, by default False. + Key in `sdata.points[points_key] to groupby, by default None. Always treats each cell separately Returns ------- - adata : AnnData - .uns['lp']: DataFrame + sdata : SpatialData + .table.uns['lp']: DataFrame Localization pattern indicator matrix. - .uns['lpp']: DataFrame + .table.uns['lpp']: DataFrame Localization pattern probabilities. """ - adata = data.copy() if copy else data if isinstance(groupby, str): groupby = [groupby] + pattern_features = [ # Do not change order of features! 
+ f"{instance_key}_inner_proximity", + f"{nucleus_key}_inner_proximity", + f"{nucleus_key}_outer_proximity", + f"{instance_key}_inner_asymmetry", + f"{nucleus_key}_inner_asymmetry", + f"{nucleus_key}_outer_asymmetry", + "l_max", + "l_max_gradient", + "l_min_gradient", + "l_monotony", + "l_half_radius", + "point_dispersion_norm", + f"{nucleus_key}_dispersion_norm", + ] + + # Compute features - feature_key = f"cell_{'_'.join(groupby)}_features" - if feature_key not in adata.uns.keys() or not all( - f in adata.uns[feature_key].columns for f in PATTERN_FEATURES + feature_key = f"{instance_key}_{'_'.join(groupby)}_features" + if feature_key not in sdata.table.uns.keys() or not all( + f in sdata.table.uns[feature_key].columns for f in pattern_features ): bento.tl.analyze_points( - adata, - "cell_shape", + sdata, + instance_key, ["proximity", "asymmetry", "ripley", "point_dispersion_norm"], groupby=groupby, recompute=True, ) bento.tl.analyze_points( - adata, - "nucleus_shape", + sdata, + nucleus_key, ["proximity", "asymmetry", "shape_dispersion_norm"], groupby=groupby, recompute=True, ) - X_df = adata.uns[feature_key][PATTERN_FEATURES] + + X_df = sdata.table.uns[feature_key][pattern_features] # Load trained model - model_path = os.path.join(os.path.split(bento.__file__)[0], "models", "rf_calib_20220514.pkl") - model = pickle.load(open(model_path, "rb")) + model_dir = "/".join(bento.__file__.split("/")[:-1]) + "/models" + model = pickle.load(open(f"{model_dir}/rf_calib_20220514.pkl", "rb")) # Compatibility with newer versions of scikit-learn for cls in model.calibrated_classifiers_: @@ -78,79 +101,71 @@ def lp(data: AnnData, groupby: str = "gene", copy: bool = False): ) # Add cell and groupby identifiers - pattern_prob.index = adata.uns[feature_key].set_index(["cell", *groupby]).index + pattern_prob.index = sdata.table.uns[feature_key].set_index([instance_key, *groupby]).index # Threshold probabilities to get indicator matrix thresholds = [0.45300, 0.43400, 0.37900, 0.43700, 0.50500] indicator_df = (pattern_prob >= thresholds).replace({True: 1, False: 0}) - adata.uns["lp"] = indicator_df.reset_index() - adata.uns["lpp"] = pattern_prob.reset_index() - return adata if copy else None - + sdata.table.uns["lp"] = indicator_df.reset_index() + sdata.table.uns["lpp"] = pattern_prob.reset_index() -@track -def lp_stats(data: AnnData, copy: bool = False): +def lp_stats(sdata: SpatialData, instance_key: str = "cell_boundaries"): """Computes frequencies of localization patterns across cells and genes. Parameters ---------- - data : AnnData - Spatial formatted AnnData object. - copy : bool - Whether to return a copy of the AnnData object. Default False. + sdata : SpatialData + Spatial formatted SpatialData object. + instance_key : str + cell boundaries instance key + Returns ------- - adata : AnnData - .uns['lp_stats']: DataFrame of localization pattern frequencies. + sdata : SpatialData + .table.uns['lp_stats']: DataFrame of localization pattern frequencies. 
""" - adata = data.copy() if copy else data - - lp = adata.uns["lp"] + lp = sdata.table.uns["lp"] cols = lp.columns groupby = list(cols[~cols.isin(PATTERN_NAMES)]) - groupby.remove("cell") - - g_pattern_counts = lp.groupby(groupby).apply(lambda df: df[PATTERN_NAMES].sum()) - adata.uns["lp_stats"] = g_pattern_counts + groupby.remove(instance_key) - return adata if copy else None + g_pattern_counts = lp.groupby(groupby).apply(lambda df: df[PATTERN_NAMES].sum().astype(int)) + sdata.table.uns["lp_stats"] = g_pattern_counts - -def _lp_logfc(data, phenotype=None): +def _lp_logfc(sdata, instance_key, phenotype=None): """Compute pairwise log2 fold change of patterns between groups in phenotype. Parameters ---------- - data : AnnData - Spatial formatted AnnData object. + data : SpatialData + Spatial formatted SpatialData object. + instance_key: str + cell boundaries instance key phenotype : str - Variable grouping cells for differential analysis. Must be in data.obs.columns. + Variable grouping cells for differential analysis. Must be in sdata.shapes["cell_boundaries"].columns. Returns ------- gene_fc_stats : DataFrame log2 fold change of patterns between groups in phenotype. """ - stats = data.uns["lp_stats"] + stats = sdata.table.uns["lp_stats"] - if phenotype not in data.obs.columns: + if phenotype not in sdata.shapes[instance_key].columns: raise ValueError("Phenotype is invalid.") - phenotype_vector = data.obs[phenotype] + phenotype_vector = sdata.shapes[instance_key][phenotype] - pattern_df = data.uns["lp"].copy() + pattern_df = sdata.table.uns["lp"].copy() groups_name = stats.index.name - pattern_df[["cell", groups_name]] = data.uns[f"cell_{groups_name}_features"][ - ["cell", groups_name] - ] gene_fc_stats = [] for c in PATTERN_NAMES: # save pattern frequency to new column, one for each group group_freq = ( - pattern_df.pivot(index="cell", columns=groups_name, values=c) + pattern_df.pivot(index=instance_key, columns=groups_name, values=c) .replace("none", np.nan) .astype(float) .groupby(phenotype_vector) @@ -193,16 +208,17 @@ def log2fc(group_col): return gene_fc_stats - -def _lp_diff_gene(cell_by_pattern, phenotype_vector): +def _lp_diff_gene(cell_by_pattern, phenotype_series, instance_key): """Perform pairwise comparison between groupby and every class. Parameters ---------- cell_by_pattern : DataFrame Cell by pattern matrix. - phenotype_vector : Series + phenotype_series : Series Series of cell groupings. 
+ instance_key : str + cell boundaries instance key Returns ------- @@ -212,10 +228,9 @@ def _lp_diff_gene(cell_by_pattern, phenotype_vector): cell_by_pattern = cell_by_pattern.dropna().reset_index(drop=True) # One hot encode categories - group_dummies = pd.get_dummies(pd.Series(phenotype_vector)) - # group_dummies.columns = [f"{phenotype}_{g}" for g in group_dummies.columns] + group_dummies = pd.get_dummies(phenotype_series) group_names = group_dummies.columns.tolist() - group_data = pd.concat([cell_by_pattern, group_dummies], axis=1) + group_data = cell_by_pattern.set_index(instance_key).join(group_dummies, how='inner') group_data.columns = group_data.columns.astype(str) # Perform one group vs rest logistic regression @@ -255,66 +270,45 @@ def _lp_diff_gene(cell_by_pattern, phenotype_vector): return results if len(results) > 0 else None - -@track -def lp_diff( - data: AnnData, phenotype: str = None, continuous: bool = False, copy: bool = False +def lp_diff_discrete( + sdata: SpatialData, + instance_key: str = "cell_boundaries", + phenotype: str = None ): """Gene-wise test for differential localization across phenotype of interest. Parameters ---------- - data : AnnData - Spatial formatted AnnData object. + sdata : SpatialData + Spatial formatted SpatialData object. + instance_key : str + cell boundaries instance key phenotype : str - Variable grouping cells for differential analysis. Must be in data.obs.columns. - continuous : bool - Whether the phenotype is continuous or categorical. By default False. - copy : bool - Return a copy of `data` instead of writing to data, by default False. + Variable grouping cells for differential analysis. Must be in sdata.shape["cell_boundaries].columns. Returns ------- - adata : AnnData - Spatial formatted AnnData object. - .uns['diff_{phenotype}'] : DataFrame + sdata : SpatialData + Spatial formatted SpatialData object. + .table.uns['diff_{phenotype}'] : DataFrame Long DataFrame with differential localization test results across phenotype groups. 
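+
+    Examples
+    --------
+    Illustrative sketch; ``bt.tl.lp`` and ``bt.tl.lp_stats`` are assumed to have been run, and
+    "cell_type" stands in for any hypothetical non-numeric per-cell column of
+    ``sdata.shapes["cell_boundaries"]``:
+
+    >>> bt.tl.lp_diff_discrete(sdata, instance_key="cell_boundaries", phenotype="cell_type")
+    >>> sdata.table.uns["diff_cell_type"].head()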
""" - adata = data.copy() if copy else data - - stats = adata.uns["lp_stats"] + stats = sdata.table.uns["lp_stats"] # Retrieve cell phenotype - phenotype_vector = adata.obs[phenotype].tolist() - - # TODO untested/incomplete - if continuous: - pattern_dfs = {} - - # Compute correlation for each point group along cells - for p in PATTERN_NAMES: - p_labels = adata.uns["lp"][p] - groups_name = stats.index.name - p_labels[["cell", groups_name]] = adata.uns[f"cell_{groups_name}_features"][ - ["cell", groups_name] - ] - p_labels = p_labels.pivot(index="cell", columns="gene", values=p) - p_corr = p_labels.corrwith(phenotype_vector, drop=True) - pattern_dfs[p] = p_labels + phenotype_series = sdata.shapes[instance_key][phenotype] + if is_numeric_dtype(phenotype_series): + raise KeyError(f"Phenotype dtype must not be numeric | dtype: {phenotype_series.dtype}") - else: - # [Sample by patterns] where sample id = [cell, group] pair - pattern_df = adata.uns["lp"].copy() - groups_name = stats.index.name - pattern_df[["cell", groups_name]] = adata.uns[f"cell_{groups_name}_features"][ - ["cell", groups_name] - ] - - diff_output = ( - pattern_df.groupby(groups_name) - .progress_apply(lambda gp: _lp_diff_gene(gp, phenotype_vector)) - .reset_index() - ) + # [Sample by patterns] where sample id = [cell, group] pair + pattern_df = sdata.table.uns["lp"].copy() + groups_name = stats.index.name + + diff_output = ( + pattern_df.groupby(groups_name) + .progress_apply(lambda gp: _lp_diff_gene(gp, phenotype_series, instance_key)) + .reset_index() + ) # FDR correction diff_output["padj"] = diff_output["pvalue"] * diff_output[groups_name].nunique() @@ -326,11 +320,11 @@ def lp_diff( results["-log10padj"] = -np.log10(results["padj"].astype(np.float32)) # Cap significance values - results.loc[results["-log10p"] > 20, "-log10p"] = 20 - results.loc[results["-log10padj"] > 12, "-log10padj"] = 12 + results.loc[results["-log10p"] == np.inf, "-log10p"] = results.loc[results["-log10p"] != np.inf]["-log10p"].max() + results.loc[results["-log10padj"] == np.inf, "-log10padj"] = results.loc[results["-log10padj"] != np.inf]["-log10padj"].max() # Group-wise log2 fold change values - log2fc_stats = _lp_logfc(adata, phenotype) + log2fc_stats = _lp_logfc(sdata, instance_key, phenotype) # Join log2fc results to p value df results = ( @@ -340,9 +334,58 @@ def lp_diff( ) # Sort results - results = results.sort_values("pvalue") + results = results.sort_values("pvalue").reset_index(drop=True) + del results["level_1"] + # Save back to SpatialData + sdata.table.uns[f"diff_{phenotype}"] = results + +def lp_diff_continuous( + sdata: SpatialData, + instance_key: str = "cell_boundaries", + phenotype: str = None +): + """Gene-wise test for differential localization across phenotype of interest. - # Save back to AnnData - adata.uns[f"diff_{phenotype}"] = results + Parameters + ---------- + sdata : SpatialData + Spatial formatted SpatialData object. + instance_key : str + cell boundaries instance key + phenotype : str + Variable grouping cells for differential analysis. Must be in sdata.shape["cell_boundaries].columns. - return adata if copy else None + Returns + ------- + sdata : SpatialData + Spatial formatted SpatialData object. + .table.uns['diff_{phenotype}'] : DataFrame + Long DataFrame with differential localization test results across phenotype groups. 
+ """ + stats = sdata.table.uns["lp_stats"] + lpp = sdata.table.uns["lpp"] + # Retrieve cell phenotype + phenotype_series = sdata.shapes[instance_key][phenotype] + + + pattern_dfs = {} + # Compute correlation for each point group along cells + for p in PATTERN_NAMES: + groups_name = stats.index.name + p_labels = lpp.pivot(index=instance_key, columns=groups_name, values=p) + p_corr = p_labels.corrwith(phenotype_series, axis=0, drop=True) + + pattern_df = pd.DataFrame(p_corr).reset_index(drop = False) + pattern_df.insert(loc = 1, column = 'pattern', value = p) + pattern_df = pattern_df.rename(columns = {0:'pearson_correlation'}) + pattern_dfs[p] = pattern_df + + # Concatenate all pattern_dfs into one + pattern_dfs = ( + pd.concat(pattern_dfs.values(), ignore_index=True) + .sort_values(by=['pearson_correlation'], ascending=False) + .reset_index(drop=True) + ) + + pattern_dfs = pattern_dfs.loc[~pattern_dfs['pearson_correlation'].isna()] + sdata.table.uns[f"diff_{phenotype}"] = pattern_dfs \ No newline at end of file diff --git a/bento/tools/_neighborhoods.py b/bento/tools/_neighborhoods.py index 40a9ed6..2f80dd0 100644 --- a/bento/tools/_neighborhoods.py +++ b/bento/tools/_neighborhoods.py @@ -5,7 +5,7 @@ def _count_neighbors( - points, n_genes, query_points=None, n_neighbors=None, radius=None, agg="gene" + points, n_genes, query_points=None, n_neighbors=None, radius=None, agg="feature_name" ): """Build nearest neighbor index for points. @@ -60,8 +60,8 @@ def _count_neighbors( print(points.shape, query_points.shape) # Get gene-level neighbor counts for each gene - if agg == "gene": - gene_code = points["gene"].values + if agg == "feature_name": + gene_code = points["feature_name"].values source_genes, source_indices = np.unique(gene_code, return_index=True) gene_index = [] @@ -78,13 +78,13 @@ def _count_neighbors( for neighbor, count in zip(neighbor_names, neighbor_counts): gene_index.append([g, neighbor, count]) - gene_index = pd.DataFrame(gene_index, columns=["gene", "neighbor", "count"]) + gene_index = pd.DataFrame(gene_index, columns=["feature_name", "neighbor", "count"]) return gene_index else: # Get gene-level neighbor counts for each point - gene_codes = points["gene"].cat.codes.values + gene_codes = points["feature_name"].cat.codes.values neighborhood_sizes = np.array([len(n) for n in neighbor_index]) flat_nindex = np.concatenate(neighbor_index) # Get gene name for each neighbor @@ -111,4 +111,4 @@ def _count_neighbors( point_ncounts = np.array(point_ncounts) point_ncounts = csr_matrix(point_ncounts) - return point_ncounts + return point_ncounts \ No newline at end of file diff --git a/bento/tools/_point_features.py b/bento/tools/_point_features.py index 21d29af..45f2118 100644 --- a/bento/tools/_point_features.py +++ b/bento/tools/_point_features.py @@ -3,83 +3,86 @@ warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning) - from abc import ABCMeta, abstractmethod from typing import List, Optional, Union import numpy as np import pandas as pd -from anndata import AnnData +from spatialdata._core.spatialdata import SpatialData from astropy.stats.spatial import RipleysKEstimator from scipy.spatial import distance -from scipy.stats.stats import spearmanr +from scipy.stats import spearmanr from tqdm.auto import tqdm +from math import isnan import re from .. 
import tools as tl -from .._utils import track from ..geometry import get_points - -@track def analyze_points( - data: AnnData, - shape_names: List[str], + sdata: SpatialData, + shape_keys: List[str], feature_names: List[str], + points_key: str = "transcripts", + instance_key: str = "cell_boundaries", groupby: Optional[Union[str, List[str]]] = None, recompute=False, progress: bool = False, - copy: bool = False, ): """Calculate the set of specified `features` for each point group. Groups are within each cell. + + When creating the points_df, it first grabs sdata.points[points_key] and joins shape polygons from sdata.shapes[shape_keys]. + The second join is to sdata.shapes[instance_key] to pull in cell polygons and cell features. + The shape indices in the points object are renamed to have _index as a suffix to avoid conflicts. + The joined polygons are named with it's respective shape_key. Parameters ---------- - data : AnnData - Spatially formatted AnnData - shape_names : str or list of str + sdata : SpatialData + Spatially formatted SpatialData + shape_keys : str or list of str Names of the shapes to analyze. feature_names : str or list of str Names of the features to analyze. groupby : str or list of str, optional - Key(s) in `data.uns['points'] to groupby, by default None. Always treats each cell separately - copy : bool - Return a copy of `data` instead of writing to data, by default False. + Key(s) in `data.points['points'] to groupby, by default None. Always treats each cell separately Returns ------- - adata : anndata.AnnData - .uns["point_features"] + sdata : spatialdata.SpatialData + table.uns["point_feature"] See the output of each :class:`PointFeature` in `features` for keys added. - `.obsm[`cell_`]` - DataFrame with rows aligned to `adata.obs_names` and `features` as columns. """ - adata = data.copy() if copy else data # Cast to list if not already - if isinstance(shape_names, str): - shape_names = [shape_names] - + if isinstance(shape_keys, str): + shape_keys = [shape_keys] + # Cast to list if not already if isinstance(feature_names, str): feature_names = [feature_names] # Make sure groupby is a list if isinstance(groupby, str): - groupby = ["cell", groupby] + groupby = [f"{instance_key}_index", groupby] elif isinstance(groupby, list): - groupby = ["cell"] + groupby + groupby = [f"{instance_key}_index"] + groupby else: - groupby = ["cell"] + groupby = [f"{instance_key}_index"] + + # Make sure points are sjoined to all shapes in shape_keys + shapes_found = set(shape_keys).intersection(set(sdata.points[points_key].columns)) + if len(shapes_found) != len(shape_keys): + raise KeyError(f"sdata.points[{points_key}] does not have all columns: {shape_keys}. 
Please run sjoin_points first.") # Make sure all groupby keys are in point columns for g in groupby: - if g not in get_points(adata).columns: + if g != f"{instance_key}_index" and g not in get_points(sdata, points_key=points_key, astype="dask", sync=True).columns: raise ValueError(f"Groupby key {g} not found in point columns.") - + # Generate feature x shape combinations - feature_combos = [point_features[f](s) for f in feature_names for s in shape_names] + feature_combos = [point_features[f](instance_key, s) for f in feature_names for s in shape_keys] # Compile dependency set of features and attributes cell_features = set() @@ -87,39 +90,44 @@ def analyze_points( for f in feature_combos: cell_features.update(f.cell_features) obs_attrs.update(f.attributes) - + cell_features = list(cell_features) obs_attrs = list(obs_attrs) - + print("Crunching shape features...") tl.analyze_shapes( - adata, "cell_shape", cell_features, progress=progress, recompute=recompute + sdata=sdata, + shape_keys=instance_key, + feature_names=cell_features, + progress=progress, + recompute=recompute ) - # Make sure attributes are present - attrs_found = set(obs_attrs).intersection(set(adata.obs.columns.tolist())) - if len(attrs_found) != len(obs_attrs): + # Make sure points are sjoined to all shapes in shape_keys + attributes = [attr for attr in obs_attrs if attr not in shape_keys] + attributes.append(instance_key) + attrs_found = set(attributes).intersection(set(sdata.shapes[instance_key].columns.tolist())) + if len(attrs_found) != len(attributes): raise KeyError(f"df does not have all columns: {obs_attrs}.") - # extract cell attributes + points_df = get_points(sdata, points_key=points_key, astype="geopandas", sync=True).set_index(instance_key) + # Pull all shape polygons into the points dataframe + for shape in list(set(obs_attrs).intersection(set([x for x in shape_keys if x != instance_key]))): + points_df = ( + points_df.join(sdata.shapes[shape], on=instance_key, lsuffix="", rsuffix=f"_{shape}") + .drop("cell_boundaries", axis=1) + .rename(columns={shape: f"{shape}_index", f"geometry_{shape}": shape}) + ) + + # Pull cell_boundaries shape features into the points dataframe points_df = ( - get_points(adata, asgeo=True) - .set_index("cell") - .join(data.obs[obs_attrs]) + points_df.join(sdata.shapes[instance_key][attributes]) + .rename_axis(f"{instance_key}_index") .reset_index() ) for g in groupby: points_df[g] = points_df[g].astype("category") - # Handle categories as strings to avoid ambiguous cat types - # for col in points_df.loc[:, (points_df.dtypes == "category").values]: - # points_df[col] = points_df[col].astype(str) - - # Handle shape indexes as strings to avoid ambiguous types - for shape_name in adata.obs.columns[adata.obs.columns.str.endswith("_shape")]: - shape_prefix = "_".join(shape_name.split("_")[:-1]) - if shape_prefix in points_df.columns: - points_df[shape_prefix] = points_df[shape_prefix].astype(str) # Calculate features for a sample def process_sample(df): @@ -136,10 +144,10 @@ def process_partition(partition_df): # Process points of each cell separately cells, group_loc = np.unique( - points_df["cell"], + points_df[f"{instance_key}_index"], return_index=True, ) - + end_loc = np.append(group_loc[1:], points_df.shape[0]) output = [] @@ -151,19 +159,19 @@ def process_partition(partition_df): for start, end in group_locs: cell_points = points_df.iloc[start:end] output.append(process_partition(cell_points)) - output = pd.concat(output) + output = pd.concat(output) # Save and overwrite existing 
print("Saving results...") + groupby[groupby.index(f"{instance_key}_index")] = instance_key output_key = "_".join([*groupby, "features"]) - if output_key in adata.uns: - adata.uns[output_key][output.columns] = output.reset_index(drop=True) + if output_key in sdata.table.uns: + sdata.table.uns[output_key][output.columns] = output.reset_index(drop=True).rename(columns={f"{instance_key}_index": instance_key}) else: - adata.uns[output_key] = output.reset_index() + sdata.table.uns[output_key] = output.reset_index().rename(columns={f"{instance_key}_index": instance_key}) print("Done.") - return adata if copy else None - + return sdata class PointFeature(metaclass=ABCMeta): """Abstract class for calculating sample features. A sample is defined as the set of @@ -177,14 +185,14 @@ class PointFeature(metaclass=ABCMeta): Names (keys) used to store computed cell-level features. """ - def __init__(self, shape_name): + def __init__(self, instance_key, shape_key): self.cell_features = set() self.attributes = set() - - if shape_name: - self.attributes.add(shape_name) - self.shape_name = shape_name - self.shape_prefix = "_".join(shape_name.split("_")[:-1]) + self.instance_key = instance_key + + if shape_key: + self.attributes.add(shape_key) + self.shape_key = shape_key @abstractmethod def extract(self, df): @@ -199,11 +207,11 @@ def extract(self, df): class ShapeProximity(PointFeature): - """For a set of points, computes the proximity of points within `shape_name` - as well as the proximity of points outside `shape_name`. Proximity is defined as - the average absolute distance to the specified `shape_name` normalized by cell - radius. Values closer to 0 denote farther from the `shape_name`, values closer - to 1 denote closer to the `shape_name`. + """For a set of points, computes the proximity of points within `shape_key` + as well as the proximity of points outside `shape_key`. Proximity is defined as + the average absolute distance to the specified `shape_key` normalized by cell + radius. Values closer to 0 denote farther from the `shape_key`, values closer + to 1 denote closer to the `shape_key`. 
Attributes ---------- @@ -215,37 +223,45 @@ class ShapeProximity(PointFeature): Returns ------- dict - `"{shape_prefix}_inner_proximity"`: proximity of points inside `shape_name` - `"{shape_prefix}_outer_proximity"`: proximity of points outside `shape_name` + `"{shape_key}_inner_proximity"`: proximity of points inside `shape_key` + `"{shape_key}_outer_proximity"`: proximity of points outside `shape_key` """ - def __init__(self, shape_name): - super().__init__(shape_name) + def __init__(self, instance_key, shape_key): + super().__init__(instance_key, shape_key) self.cell_features.add("radius") - self.attributes.add("cell_radius") - self.shape_name = shape_name + self.attributes.add(f"{self.instance_key}_radius") def extract(self, df): df = super().extract(df) # Get shape polygon - shape = df[self.shape_name].values[0] + shape = df[self.shape_key].values[0] + + # Skip if no shape or if shape is nan + try: + if isnan(shape): + return { + f"{self.shape_key}_inner_proximity": 0, + f"{self.shape_key}_outer_proximity": 0, + } + except: + pass - # Skip if no shape if not shape: return { - f"{self.shape_prefix}_inner_proximity": 0, - f"{self.shape_prefix}_outer_proximity": 0, + f"{self.shape_key}_inner_proximity": 0, + f"{self.shape_key}_outer_proximity": 0, } # Get points points_geo = df["geometry"] # Check for points within shape, assume all are intracellular - if self.shape_prefix == "cell": + if self.shape_key == self.instance_key: inner = np.array([True] * len(df)) else: - inner = df[self.shape_prefix] != "-1" + inner = df[f"{self.shape_key}_index"] != "" outer = ~inner inner_dist = np.nan @@ -258,7 +274,7 @@ def extract(self, df): outer_dist = points_geo[outer].distance(shape.boundary).mean() # Scale from [0, 1], where 1 is close and 0 is far. - cell_radius = df["cell_radius"].values[0] + cell_radius = df[f"{self.instance_key}_radius"].values[0] inner_proximity = (cell_radius - inner_dist) / cell_radius outer_proximity = (cell_radius - outer_dist) / cell_radius @@ -269,16 +285,16 @@ def extract(self, df): outer_proximity = 0 return { - f"{self.shape_prefix}_inner_proximity": inner_proximity, - f"{self.shape_prefix}_outer_proximity": outer_proximity, + f"{self.shape_key}_inner_proximity": inner_proximity, + f"{self.shape_key}_outer_proximity": outer_proximity, } class ShapeAsymmetry(PointFeature): - """For a set of points, computes the asymmetry of points within `shape_name` - as well as the asymmetry of points outside `shape_name`. Asymmetry is defined as + """For a set of points, computes the asymmetry of points within `shape_key` + as well as the asymmetry of points outside `shape_key`. Asymmetry is defined as the offset between the centroid of points to the centroid of the specified - `shape_name`, normalized by cell radius. Values closer to 0 denote symmetry, + `shape_key`, normalized by cell radius. Values closer to 0 denote symmetry, values closer to 1 denote asymmetry. 
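+
+    Concretely, restating ``extract`` below: asymmetry is the mean distance from the points to
+    the centroid of ``shape_key`` divided by the cell radius, computed separately for points
+    inside and outside ``shape_key``.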
Attributes @@ -287,43 +303,51 @@ class ShapeAsymmetry(PointFeature): Set of cell-level features needed for computing sample-level features cell_attributes : int Names (keys) used to store computed cell-level features - shape_name : str + shape_key : str Name of shape to use, must be column name in input DataFrame Returns ------- dict - `"{shape_prefix}_inner_asymmetry"`: asymmetry of points inside `shape_name` - `"{shape_prefix}_outer_asymmetry"`: asymmetry of points outside `shape_name` + `"{shape_key}_inner_asymmetry"`: asymmetry of points inside `shape_key` + `"{shape_key}_outer_asymmetry"`: asymmetry of points outside `shape_key` """ - def __init__(self, shape_name): - super().__init__(shape_name) + def __init__(self, instance_key, shape_key): + super().__init__(instance_key, shape_key) self.cell_features.add("radius") - self.attributes.add("cell_radius") - self.shape_name = shape_name + self.attributes.add(f"{self.instance_key}_radius") def extract(self, df): df = super().extract(df) # Get shape polygon - shape = df[self.shape_name].values[0] + shape = df[self.shape_key].values[0] + + # Skip if no shape or shape is nan + try: + if isnan(shape): + return { + f"{self.shape_key}_inner_asymmetry": 0, + f"{self.shape_key}_outer_asymmetry": 0, + } + except: + pass - # Skip if no shape if shape is None: return { - f"{self.shape_prefix}_inner_asymmetry": 0, - f"{self.shape_prefix}_outer_asymmetry": 0, + f"{self.shape_key}_inner_asymmetry": 0, + f"{self.shape_key}_outer_asymmetry": 0, } # Get points points_geo = df["geometry"] # Check for points within shape, assume all are intracellular - if self.shape_prefix == "cell": + if self.shape_key == self.instance_key: inner = np.array([True] * len(df)) else: - inner = df[self.shape_prefix] != "-1" + inner = df[f"{self.shape_key}_index"] != "" outer = ~inner inner_to_centroid = np.nan @@ -336,7 +360,7 @@ def extract(self, df): outer_to_centroid = points_geo[outer].distance(shape.centroid).mean() # Values [0, 1], where 1 is asymmetrical and 0 is symmetrical. - cell_radius = df["cell_radius"].values[0] + cell_radius = df[f"{self.instance_key}_radius"].values[0] inner_asymmetry = inner_to_centroid / cell_radius outer_asymmetry = outer_to_centroid / cell_radius @@ -347,8 +371,8 @@ def extract(self, df): outer_asymmetry = 0 return { - f"{self.shape_prefix}_inner_asymmetry": inner_asymmetry, - f"{self.shape_prefix}_outer_asymmetry": outer_asymmetry, + f"{self.shape_key}_inner_asymmetry": inner_asymmetry, + f"{self.shape_key}_outer_asymmetry": outer_asymmetry, } @@ -370,19 +394,17 @@ class PointDispersionNorm(PointFeature): `"point_dispersion"`: measure of point dispersion """ - def __init__(self, shape_name): - super().__init__(shape_name) + def __init__(self, instance_key, shape_key): + super().__init__(instance_key, shape_key) self.cell_features.add("raster") - - attrs = ["cell_raster"] - self.attributes.update(attrs) + self.attributes.add(f"{self.instance_key}_raster") def extract(self, df): df = super().extract(df) # Get precomputed cell centroid and raster pt_centroid = df[["x", "y"]].values.mean(axis=0).reshape(1, 2) - cell_raster = df["cell_raster"].values[0] + cell_raster = df[f"{self.instance_key}_raster"].values[0] # Skip if no raster if not np.array(cell_raster).flatten().any(): @@ -400,7 +422,7 @@ def extract(self, df): class ShapeDispersionNorm(PointFeature): """For a set of points, calculates the second moment of all points in a cell relative to the - centroid of `shape_name`. 
This value is normalized by the second moment of a uniform + centroid of `shape_key`. This value is normalized by the second moment of a uniform distribution within the cell boundary. Attributes @@ -413,28 +435,33 @@ class ShapeDispersionNorm(PointFeature): Returns ------- dict - `"{shape_prefix}_dispersion"`: measure of point dispersion relative to `shape_name` + `"{shape_key}_dispersion"`: measure of point dispersion relative to `shape_key` """ - def __init__(self, shape_name): - super().__init__(shape_name) + def __init__(self, instance_key, shape_key): + super().__init__(instance_key, shape_key) self.cell_features.add("raster") - attrs = ["cell_raster"] - self.attributes.update(attrs) + self.attributes.add(f"{self.instance_key}_raster") def extract(self, df): df = super().extract(df) # Get shape polygon - shape = df[self.shape_name].values[0] + shape = df[self.shape_key].values[0] + + # Skip if no shape or if shape is nan + try: + if isnan(shape): + return {f"{self.shape_key}_dispersion_norm": np.nan} + except: + pass - # Skip if no shape if not shape: - return {f"{self.shape_prefix}_dispersion_norm": np.nan} + return {f"{self.shape_key}_dispersion_norm": np.nan} # Get precomputed shape centroid and raster - cell_raster = df["cell_raster"].values[0] + cell_raster = df[f"{self.instance_key}_raster"].values[0] # calculate points moment point_moment = _second_moment(shape.centroid, df[["x", "y"]].values) @@ -443,12 +470,12 @@ def extract(self, df): # Normalize by cell moment norm_moment = point_moment / cell_moment - return {f"{self.shape_prefix}_dispersion_norm": norm_moment} + return {f"{self.shape_key}_dispersion_norm": norm_moment} class ShapeDistance(PointFeature): - """For a set of points, computes the distance of points within `shape_name` - as well as the distance of points outside `shape_name`. + """For a set of points, computes the distance of points within `shape_key` + as well as the distance of points outside `shape_key`. 
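+
+    Unlike ``ShapeProximity``, the distances are returned without cell-radius normalization, so
+    values are in the same coordinate units as the points.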
Attributes ---------- @@ -460,35 +487,44 @@ class ShapeDistance(PointFeature): Returns ------- dict - `"{shape_prefix}_inner_distance"`: distance of points inside `shape_name` - `"{shape_prefix}_outer_distance"`: distance of points outside `shape_name` + `"{shape_key}_inner_distance"`: distance of points inside `shape_key` + `"{shape_key}_outer_distance"`: distance of points outside `shape_key` """ # Cell-level features needed for computing sample-level features - def __init__(self, shape_name): - super().__init__(shape_name) + def __init__(self, instance_key, shape_key): + super().__init__(instance_key, shape_key) def extract(self, df): df = super().extract(df) # Get shape polygon - shape = df[self.shape_name].values[0] + shape = df[self.shape_key].values[0] + + # Skip if no shape or if shape is nan + try: + if isnan(shape): + return { + f"{self.shape_key}_inner_distance": np.nan, + f"{self.shape_key}_outer_distance": np.nan, + } + except: + pass - # Skip if no shape if not shape: return { - f"{self.shape_prefix}_inner_distance": np.nan, - f"{self.shape_prefix}_outer_distance": np.nan, + f"{self.shape_key}_inner_distance": np.nan, + f"{self.shape_key}_outer_distance": np.nan, } # Get points points_geo = df["geometry"].values # Check for points within shape, assume all are intracellular - if self.shape_prefix == "cell": + if self.shape_key == self.instance_key: inner = np.array([True] * len(df)) else: - inner = df[self.shape_prefix] != "-1" + inner = df[f"{self.shape_key}_index"] != "" outer = ~inner if inner.sum() > 0: @@ -502,16 +538,16 @@ def extract(self, df): outer_dist = np.nan return { - f"{self.shape_prefix}_inner_distance": inner_dist, - f"{self.shape_prefix}_outer_distance": outer_dist, + f"{self.shape_key}_inner_distance": inner_dist, + f"{self.shape_key}_outer_distance": outer_dist, } class ShapeOffset(PointFeature): - """For a set of points, computes the offset of points within `shape_name` - as well as the offset of points outside `shape_name`. Offset is defined as + """For a set of points, computes the offset of points within `shape_key` + as well as the offset of points outside `shape_key`. Offset is defined as the offset between the centroid of points to the centroid of the specified - `shape_name`. + `shape_key`. 
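+
+    Unlike ``ShapeAsymmetry``, the offset is not normalized by cell radius, so values are in the
+    same coordinate units as the points.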
Attributes ---------- @@ -519,40 +555,49 @@ class ShapeOffset(PointFeature): Set of cell-level features needed for computing sample-level features attributes : int Names (keys) used to store computed cell-level features - shape_name : str + shape_key : str Name of shape to use, must be column name in input DataFrame Returns ------- dict - `"{shape_prefix}_inner_offset"`: offset of points inside `shape_name` - `"{shape_prefix}_outer_offset"`: offset of points outside `shape_name` + `"{shape_key}_inner_offset"`: offset of points inside `shape_key` + `"{shape_key}_outer_offset"`: offset of points outside `shape_key` """ - def __init__(self, shape_name): - super().__init__(shape_name) + def __init__(self, instance_key, shape_key): + super().__init__(instance_key, shape_key) def extract(self, df): df = super().extract(df) # Get shape polygon - shape = df[self.shape_name].values[0] + shape = df[self.shape_key].values[0] # Skip if no shape + try: + if isnan(shape): + return { + f"{self.shape_key}_inner_offset": np.nan, + f"{self.shape_key}_outer_offset": np.nan, + } + except: + pass + if not shape: return { - f"{self.shape_prefix}_inner_offset": np.nan, - f"{self.shape_prefix}_outer_offset": np.nan, + f"{self.shape_key}_inner_offset": np.nan, + f"{self.shape_key}_outer_offset": np.nan, } # Get points points_geo = df["geometry"].values # Check for points within shape, assume all are intracellular - if self.shape_prefix == "cell": + if self.shape_key == self.instance_key: inner = np.array([True] * len(df)) else: - inner = df[self.shape_prefix] != "-1" + inner = df[f"{self.shape_key}_index"] != "" outer = ~inner if inner.sum() > 0: @@ -566,8 +611,8 @@ def extract(self, df): outer_to_centroid = np.nan return { - f"{self.shape_prefix}_inner_offset": inner_to_centroid, - f"{self.shape_prefix}_outer_offset": outer_to_centroid, + f"{self.shape_key}_inner_offset": inner_to_centroid, + f"{self.shape_key}_outer_offset": outer_to_centroid, } @@ -588,9 +633,9 @@ class PointDispersion(PointFeature): `"point_dispersion"`: measure of point dispersion """ - # shape_name set to None to follow the same convention as other shape features - def __init__(self, shape_name=None): - super().__init__(shape_name) + # shape_key set to None to follow the same convention as other shape features + def __init__(self, instance_key, shape_key=None): + super().__init__(instance_key, shape_key) def extract(self, df): df = super().extract(df) @@ -606,7 +651,7 @@ def extract(self, df): class ShapeDispersion(PointFeature): """For a set of points, calculates the second moment of all points in a cell relative to the - centroid of `shape_name`. + centroid of `shape_key`. 
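+
+    Concretely, the second moment computed by the ``_second_moment`` helper below is
+    ``sum(d_i ** 2) / n`` over the distances ``d_i`` from each of the ``n`` points to the
+    centroid of ``shape_key``; unlike ``ShapeDispersionNorm``, it is not normalized by the
+    moment of a uniform distribution over the cell.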
Attributes ---------- @@ -618,26 +663,32 @@ class ShapeDispersion(PointFeature): Returns ------- dict - `"{shape_prefix}_dispersion"`: measure of point dispersion relative to `shape_name` + `"{shape_key}_dispersion"`: measure of point dispersion relative to `shape_key` """ - def __init__(self, shape_name): - super().__init__(shape_name) + def __init__(self, instance_key, shape_key): + super().__init__(instance_key, shape_key) def extract(self, df): df = super().extract(df) # Get shape polygon - shape = df[self.shape_name].values[0] + shape = df[self.shape_key].values[0] + + # Skip if no shape or if shape is nan + try: + if isnan(shape): + return {f"{self.shape_key}_dispersion": np.nan} + except: + pass - # Skip if no shape if not shape: - return {f"{self.shape_prefix}_dispersion": np.nan} + return {f"{self.shape_key}_dispersion": np.nan} # calculate points moment point_moment = _second_moment(shape.centroid, df[["x", "y"]].values) - return {f"{self.shape_prefix}_dispersion": point_moment} + return {f"{self.shape_key}_dispersion": point_moment} class RipleyStats(PointFeature): @@ -662,18 +713,18 @@ class RipleyStats(PointFeature): """ - def __init__(self, shape_name=None): - super().__init__(shape_name) + def __init__(self, instance_key, shape_key=None): + super().__init__(instance_key, shape_key) self.cell_features.update(["span", "bounds", "area"]) self.attributes.update( [ - "cell_span", - "cell_minx", - "cell_miny", - "cell_maxx", - "cell_maxy", - "cell_area", + f"{instance_key}_span", + f"{instance_key}_minx", + f"{instance_key}_miny", + f"{instance_key}_maxx", + f"{instance_key}_maxy", + f"{instance_key}_area", ] ) @@ -681,12 +732,12 @@ def extract(self, df): df = super().extract(df) # Get precomputed centroid and cell moment - cell_span = df["cell_span"].values[0] - cell_minx = df["cell_minx"].values[0] - cell_miny = df["cell_miny"].values[0] - cell_maxx = df["cell_maxx"].values[0] - cell_maxy = df["cell_maxy"].values[0] - cell_area = df["cell_area"].values[0] + cell_span = df[f"{self.instance_key}_span"].values[0] + cell_minx = df[f"{self.instance_key}_minx"].values[0] + cell_miny = df[f"{self.instance_key}_miny"].values[0] + cell_maxx = df[f"{self.instance_key}_maxx"].values[0] + cell_maxy = df[f"{self.instance_key}_maxy"].values[0] + cell_area = df[f"{self.instance_key}_area"].values[0] estimator = RipleysKEstimator( area=cell_area, @@ -741,7 +792,7 @@ def extract(self, df): class ShapeEnrichment(PointFeature): - """For a set of points, calculates the fraction of points within `shape_name` + """For a set of points, calculates the fraction of points within `shape_key` out of all points in the cell. 
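+
+    Concretely, restating ``extract`` below: enrichment is ``n_points_inside / n_points_total``
+    for the cell, where a point counts as inside when its ``{shape_key}_index`` value is
+    non-empty; when ``shape_key`` is the cell itself, the value is fixed at 1.0.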
Attributes @@ -750,17 +801,17 @@ class ShapeEnrichment(PointFeature): Set of cell-level features needed for computing sample-level features attributes : int Names (keys) used to store computed cell-level features - shape_name : str + shape_key : str Name of shape to use, must be column name in input DataFrame Returns ------- dict - `"{shape_prefix}_enrichment"`: enrichment fraction of points in `shape_name` + `"{shape_key}_enrichment"`: enrichment fraction of points in `shape_key` """ - def __init__(self, shape_name): - super().__init__(shape_name) + def __init__(self, instance_key, shape_key): + super().__init__(instance_key, shape_key) def extract(self, df): df = super().extract(df) @@ -769,13 +820,13 @@ def extract(self, df): points_geo = df["geometry"] # Check for points within shape, assume all are intracellular - if self.shape_prefix == "cell": + if self.shape_key == self.instance_key: enrichment = 1.0 else: - inner_count = (df[self.shape_prefix] != "-1").sum() + inner_count = (df[f"{self.shape_key}_index"] != "").sum() enrichment = inner_count / float(len(points_geo)) - return {f"{self.shape_prefix}_enrichment": enrichment} + return {f"{self.shape_key}_enrichment": enrichment} def _second_moment(centroid, pts): @@ -787,6 +838,8 @@ def _second_moment(centroid, pts): centroid : [1 x 2] float pts : [n x 2] float """ + if type(centroid) != np.ndarray: + centroid = centroid.coords centroid = np.array(centroid).reshape(1, 2) radii = distance.cdist(centroid, pts) second_moment = np.sum(radii * radii / len(pts)) @@ -839,4 +892,4 @@ def register_point_feature(name: str, FeatureClass: PointFeature): point_features[name] = FeatureClass - print(f"Registered point feature '{name}' to `bento.tl.shape_features`.") + print(f"Registered point feature '{name}' to `bento.tl.shape_features`.") \ No newline at end of file diff --git a/bento/tools/_shape_features.py b/bento/tools/_shape_features.py old mode 100755 new mode 100644 index a06f005..9eac1dc --- a/bento/tools/_shape_features.py +++ b/bento/tools/_shape_features.py @@ -1,7 +1,9 @@ import warnings + from shapely.errors import ShapelyDeprecationWarning warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning) +warnings.filterwarnings("ignore") from typing import Callable, Dict, List, Union @@ -9,39 +11,42 @@ import matplotlib.path as mplPath import numpy as np import pandas as pd -from anndata import AnnData from scipy.spatial import distance, distance_matrix from shapely.geometry import MultiPolygon, Point +from spatialdata._core.spatialdata import SpatialData +from spatialdata.models import PointsModel, ShapesModel from tqdm.auto import tqdm -from .._utils import sync, track -from ..geometry import get_points, get_shape +from ..geometry import get_points, get_shape, set_shape_metadata -def _area(data: AnnData, shape_name: str, recompute: bool = False): - """Compute the area of each shape. +def _area(sdata: SpatialData, shape_key: str, recompute: bool = False): + """ + Compute the area of each shape. Parameters ---------- - data : AnnData - Spatial formatted AnnData - shape_name : str - Key in `data.obs` that contains the shape information. + sdata : SpatialData + Spatial formatted SpatialData + shape_key : str + Key in `sdata.shapes[shape_key]` that contains the shape information. + recompute : bool, optional + If True, forces the computation of the area even if it already exists in the shape metadata. + If False (default), the computation is skipped if the area already exists. 
     Fields
     ------
-    .obs['{shape}_area'] : float
-        Area of each polygon
+    .shapes[shape_key]['{shape}_area'] : float
+        Area of each polygon
     """
-    feature_key = f"{shape_name.split('_')[0]}_area"
-    if feature_key in data.obs.keys() and not recompute:
+    feature_key = f"{shape_key}_area"
+    if feature_key in sdata.shapes[shape_key].columns and not recompute:
         return
 
     # Calculate pixel-wise area
-    area = get_shape(data, shape_name).area
-
-    data.obs[feature_key] = area
+    area = get_shape(sdata=sdata, shape_key=shape_key, sync=False).area
+    set_shape_metadata(sdata=sdata, shape_key=shape_key, metadata=area, column_names=feature_key)
 
 
 def _poly_aspect_ratio(poly):
@@ -66,123 +71,128 @@ def _poly_aspect_ratio(poly):
     return length / width
 
 
-def _aspect_ratio(data: AnnData, shape_name: str, recompute: bool = False):
+def _aspect_ratio(sdata: SpatialData, shape_key: str, recompute: bool = False):
     """Compute the aspect ratio of the minimum rotated rectangle that contains each shape.
 
-    Parameters
-    ----------
-    data : AnnData
-        Spatial formatted AnnData
-    shape_name : str
-        Key in `data.obs` that contains the shape information.
-
-    Fields
-    ------
-    .obs['{shape}_aspect_ratio'] : float
-        Ratio of major to minor axis for each polygon
+    Parameters
+    ----------
+    sdata : SpatialData
+        Spatial formatted SpatialData
+    shape_key : str
+        Key in `sdata.shapes[shape_key]` that contains the shape information.
+
+    Fields
+    ------
+    .shapes[shape_key]['{shape}_aspect_ratio'] : float
+        Ratio of major to minor axis for each polygon
     """
-    feature_key = f"{shape_name.split('_')[0]}_aspect_ratio"
-    if feature_key in data.obs.keys() and not recompute:
+    feature_key = f"{shape_key}_aspect_ratio"
+    if feature_key in sdata.shapes[shape_key].keys() and not recompute:
         return
 
-    ar = get_shape(data, shape_name).apply(_poly_aspect_ratio)
-    data.obs[feature_key] = ar
+    ar = get_shape(sdata, shape_key, sync=False).apply(_poly_aspect_ratio)
+    set_shape_metadata(sdata=sdata, shape_key=shape_key, metadata=ar, column_names=feature_key)
 
 
-def _bounds(data: AnnData, shape_name: str, recompute: bool = False):
+def _bounds(sdata: SpatialData, shape_key: str, recompute: bool = False):
     """Compute the minimum and maximum coordinate values that bound each shape.
 
     Parameters
     ----------
-    data : AnnData
-        Spatial formatted AnnData
-    shape_name : str
-        Key in `data.obs` that contains the shape information.
+    sdata : SpatialData
+        Spatial formatted SpatialData
+    shape_key : str
+        Key in `sdata.shapes[shape_key]` that contains the shape information.
Fields ------ - .obs['{shape}_minx'] : float + .shapes[shape_key]['{shape}_minx'] : float x-axis lower bound of each polygon - .obs['{shape}_miny'] : float + .shapes[shape_key]['{shape}_miny'] : float y-axis lower bound of each polygon - .obs['{shape}_maxx'] : float + .shapes[shape_key]['{shape}_maxx'] : float x-axis upper bound of each polygon - .obs['{shape}_maxy'] : float + .shapes[shape_key]['{shape}_maxy'] : float y-axis upper bound of each polygon """ + feat_names = ["minx", "miny", "maxx", "maxy"] feature_keys = [ - f"{shape_name.split('_')[0]}_{k}" for k in ["minx", "miny", "maxx", "maxy"] + f"{shape_key}_{k}" for k in feat_names ] - if all([k in data.obs.keys() for k in feature_keys]) and not recompute: + if ( + all([k in sdata.shapes[shape_key].keys() for k in feature_keys]) + and not recompute + ): return - bounds = get_shape(data, shape_name).bounds + bounds = get_shape(sdata, shape_key, sync=False).bounds - data.obs[feature_keys[0]] = bounds["minx"] - data.obs[feature_keys[1]] = bounds["miny"] - data.obs[feature_keys[2]] = bounds["maxx"] - data.obs[feature_keys[3]] = bounds["maxy"] + set_shape_metadata(sdata=sdata, shape_key=shape_key, metadata=bounds[feat_names], column_names=feature_keys) -# TODO move to point_features -def _density(data: AnnData, shape_name: str, recompute: bool = False): +def _density(sdata: SpatialData, shape_key: str, recompute: bool = False): """Compute the RNA density of each shape. Parameters ---------- - data : AnnData - Spatial formatted AnnData - shape_name : str - Key in `data.obs` that contains the shape information. + sdata : SpatialData + Spatial formatted SpatialData + shape_key : str + Key in `sdata.shapes[shape_key]` that contains the shape information. Fields ------ - .obs['{shape}_density'] : float + .shapes[shape_key]['{shape}_density'] : float Density (molecules / shape area) of each polygon """ - shape_prefix = shape_name.split("_")[0] - feature_key = f"{shape_prefix}_density" - if feature_key in data.obs.keys() and not recompute: + + feature_key = f"{shape_key}_density" + if feature_key in sdata.shapes[shape_key].keys() and not recompute: return - count = get_points(data).query(f"{shape_prefix} != '-1'")["cell"].value_counts() - _area(data, shape_name) + count = ( + get_points(sdata, astype="dask", sync=False) + .query(f"{shape_key} != 'None'")[shape_key] + .value_counts() + .compute() + ) + _area(sdata, shape_key) - data.obs[feature_key] = count / data.obs[f"{shape_prefix}_area"] + set_shape_metadata( + sdata=sdata, + shape_key=shape_key, + metadata=count / sdata.shapes[shape_key][f"{shape_key}_area"], + column_names=feature_key + ) -def _opening(data: AnnData, proportion: float, recompute: bool = False): +def _opening(sdata: SpatialData, shape_key: str, proportion: float, recompute: bool = False): """Compute the opening (morphological) of distance d for each cell. 
     Parameters
     ----------
-    data : AnnData
-        Spatial formatted AnnData
+    sdata : SpatialData
+        Spatial formatted SpatialData
 
-    Returns
+    Fields
     -------
-    data : anndata.AnnData
-        Returns `data` if `copy=True`, otherwise adds fields to `data`:
-
-        `obs['cell_open_{d}_shape']` : Polygons
-            Ratio of long / short axis for each polygon in `obs['cell_shape']`
+    .shapes[shape_key]['{shape_key}_open_{proportion}_shape'] : Polygons
+        Morphological opening of each polygon in `.shapes[shape_key]`, using distance d = proportion * shape radius
     """
-    shape_name = f"cell_open_{proportion}_shape"
+    feature_key = f"{shape_key}_open_{proportion}_shape"
 
-    if shape_name in data.obs.keys() and not recompute:
+    if feature_key in sdata.shapes[shape_key].keys() and not recompute:
         return
 
-    _radius(data, "cell_shape")
-
-    cells = get_shape(data, "cell_shape")
-    d = proportion * data.obs["cell_radius"]
+    _radius(sdata, shape_key)
 
-    # Opening
-    data.obs[shape_name] = cells.buffer(-d).buffer(d)
+    shapes = get_shape(sdata, shape_key, sync=False)
+    d = proportion * sdata.shapes[shape_key][f"{shape_key}_radius"]
+    set_shape_metadata(sdata=sdata, shape_key=shape_key, metadata=shapes.buffer(-d).buffer(d), column_names=feature_key)
 
 
 def _second_moment_polygon(centroid, pts):
@@ -194,48 +204,45 @@ def _second_moment_polygon(centroid, pts):
     centroid : 2D Point object
     pts : [n x 2] float
     """
-    if not centroid or isinstance(pts, np.ndarray):
+
+    if not centroid or not isinstance(pts, np.ndarray):
         return
-    centroid = np.array(centroid).reshape(1, 2)
+    centroid = np.array(centroid.coords).reshape(1, 2)
     radii = distance.cdist(centroid, pts)
     second_moment = np.sum(radii * radii / len(pts))
     return second_moment
 
 
-def _second_moment(data: AnnData, shape_name: str, recompute: bool = False):
+def _second_moment(sdata: SpatialData, shape_key: str, recompute: bool = False):
     """Compute the second moment of each shape.
 
     Parameters
     ----------
-    data : AnnData
-        Spatial formatted AnnData
+    sdata : SpatialData
+        Spatial formatted SpatialData
 
-    Returns
+    Fields
     -------
-    data : anndata.AnnData
-        Returns `data` if `copy=True`, otherwise adds fields to `data`:
-
-        `obs['{shape}_moment']` : float
+    .shapes[shape_key]['{shape}_moment'] : float
         The second moment for each polygon
     """
-    shape_prefix = shape_name.split("_")[0]
-    feature_key = f"{shape_prefix}_moment"
-    if feature_key in data.obs.keys() and not recompute:
+
+    feature_key = f"{shape_key}_moment"
+    if feature_key in sdata.shapes[shape_key].keys() and not recompute:
         return
 
-    _raster(data, shape_name, recompute=recompute)
+    _raster(sdata, shape_key, recompute=recompute)
 
-    rasters = data.obs[f"{shape_prefix}_raster"]
-    shape_centroids = get_shape(data, shape_name).centroid
+    rasters = sdata.shapes[shape_key][f"{shape_key}_raster"]
+    shape_centroids = get_shape(sdata, shape_key, sync=False).centroid
 
     moments = [
         _second_moment_polygon(centroid, r)
         for centroid, r in zip(shape_centroids, rasters)
     ]
 
-    data.obs[f"{shape_prefix}_moment"] = moments
-
+    set_shape_metadata(sdata=sdata, shape_key=shape_key, metadata=moments, column_names=feature_key)
 
 def _raster_polygon(poly, step=1):
     """
@@ -271,103 +278,110 @@ def _raster_polygon(poly, step=1):
     return xy
 
 
-def _raster(data: AnnData, shape_name: str, step: int = 1, recompute: bool = False):
+def _raster(
+    sdata: SpatialData,
+    shape_key: str,
+    points_key: str = "transcripts",
+    step: int = 1,
+    recompute: bool = False,
+):
     """Generate a grid of points contained within each shape. The points lie on
     a 2D grid, with vertices spaced `step` distance apart.
Parameters ---------- - data : AnnData - Spatial formatted AnnData + sdata : SpatialData + Spatial formatted SpatialData - Returns + Fields ------- - data : anndata.AnnData - Returns `data` if `copy=True`, otherwise adds fields to `data`: - - `uns['{shape}_raster']` : np.array - Long DataFrame of points annotated by shape from `obs['{shape_name}']` + .shapes[shape_key]['{shape}_raster'] : np.array + Long DataFrame of points annotated by shape from `.shapes[shape_key]['{shape_key}']` """ - shape_prefix = shape_name.split("_")[0] - feature_key = f"{shape_prefix}_raster" + + feature_key = f"{shape_key}_raster" - if feature_key in data.obs.keys() and not recompute: + if feature_key in sdata.shapes[shape_key].keys() and not recompute: return - raster = data.obs[f"{shape_name}"].apply( - lambda poly: _raster_polygon(poly, step=step) - ) + shapes = get_shape(sdata, shape_key, sync=False) + raster = shapes.apply(lambda poly: _raster_polygon(poly, step=step)) raster_all = [] for s, r in raster.items(): raster_df = pd.DataFrame(r, columns=["x", "y"]) - raster_df[shape_prefix] = s + raster_df[shape_key] = s raster_all.append(raster_df) - # Add raster to data.obs as 2d array per cell (for point_features compatibility) - data.obs[feature_key] = [df[["x", "y"]].values for df in raster_all] + # Add raster to sdata.shapes as 2d array per cell (for point_features compatibility) + set_shape_metadata( + sdata=sdata, + shape_key=shape_key, + metadata=[df[["x", "y"]].values for df in raster_all], + column_names=feature_key + ) - # Add raster to data.uns as long dataframe (for flux compatibility) + # Add raster to sdata.points as long dataframe (for flux compatibility) raster_all = pd.concat(raster_all).reset_index(drop=True) - data.uns[feature_key] = raster_all + transform = sdata.points[points_key].attrs + sdata.points[feature_key] = PointsModel.parse( + raster_all, coordinates={"x": "x", "y": "y"} + ) + sdata.points[feature_key].attrs = transform -def _perimeter(data: AnnData, shape_name: str, recompute: bool = False): +def _perimeter(sdata: SpatialData, shape_key: str, recompute: bool = False): """Compute the perimeter of each shape. Parameters ---------- - data : AnnData - Spatial formatted AnnData + sdata : SpatialData + Spatial formatted SpatialData - Returns + Fields ------- - data : anndata.AnnData - Returns `data` if `copy=True`, otherwise adds fields to `data`: - - `obs['{shape}_perimeter']` : np.array + `.shapes[shape_key]['{shape}_perimeter']` : np.array Perimeter of each polygon """ - shape_prefix = shape_name.split("_")[0] - feature_key = f"{shape_prefix}_perimeter" - - if feature_key in data.obs.keys() and not recompute: + + feature_key = f"{shape_key}_perimeter" + if feature_key in sdata.shapes[shape_key].keys() and not recompute: return - data.obs[feature_key] = get_shape(data, shape_name).length + set_shape_metadata( + sdata=sdata, + shape_key=shape_key, + metadata=get_shape(sdata, shape_key, sync=False).length, + column_names=feature_key + ) -def _radius(data: AnnData, shape_name: str, recompute: bool = False): +def _radius(sdata: SpatialData, shape_key: str, recompute: bool = False): """Compute the radius of each cell. 
Parameters ---------- - data : AnnData - Spatial formatted AnnData + sdata : SpatialData + Spatial formatted SpatialData - Returns + Fields ------- - data : anndata.AnnData - Returns `data` if `copy=True`, otherwise adds fields to `data`: - - `obs['{shape}_radius']` : np.array + .shapes[shape_key]['{shape}_radius'] : np.array Radius of each polygon in `obs['cell_shape']` """ - shape_prefix = shape_name.split("_")[0] - feature_key = f"{shape_prefix}_radius" - - if feature_key in data.obs.keys() and not recompute: + + feature_key = f"{shape_key}_radius" + if feature_key in sdata.shapes[shape_key].keys() and not recompute: return - shapes = get_shape(data, shape_name) + shapes = get_shape(sdata, shape_key, sync=False) # Get average distance from boundary to centroid shape_radius = shapes.apply(_shape_radius) - - data.obs[feature_key] = shape_radius + set_shape_metadata(sdata=sdata, shape_key=shape_key, metadata=shape_radius, column_names=feature_key) def _shape_radius(poly): @@ -375,31 +389,28 @@ def _shape_radius(poly): return np.nan return distance.cdist( - np.array(poly.centroid).reshape(1, 2), np.array(poly.exterior.xy).T + np.array(poly.centroid.coords).reshape(1, 2), np.array(poly.exterior.xy).T ).mean() -def _span(data: AnnData, shape_name: str, recompute: bool = False): +def _span(sdata: SpatialData, shape_key: str, recompute: bool = False): """Compute the length of the longest diagonal of each shape. Parameters ---------- - data : AnnData - Spatial formatted AnnData + sdata : SpatialData + Spatial formatted SpatialData - Returns + Fields ------- - data : anndata.AnnData - Returns `data` if `copy=True`, otherwise adds fields to `data`: - - `obs['{shape}_span']` : float + .shapes[shape_key]['{shape}_span'] : float Length of longest diagonal for each polygon """ - shape_prefix = shape_name.split("_")[0] - feature_key = f"{shape_prefix}_span" + + feature_key = f"{shape_key}_span" - if feature_key in data.obs.keys() and not recompute: + if feature_key in sdata.shapes[shape_key].keys() and not recompute: return def get_span(poly): @@ -409,18 +420,17 @@ def get_span(poly): shape_coo = np.array(poly.coords.xy).T return int(distance_matrix(shape_coo, shape_coo).max()) - span = get_shape(data, shape_name).exterior.apply(get_span) - - data.obs[feature_key] = span - + span = get_shape(sdata, shape_key, sync=False).exterior.apply(get_span) + set_shape_metadata(sdata=sdata, shape_key=shape_key, metadata=span, column_names=feature_key) + def list_shape_features(): - """Return a DataFrame of available shape features. Pulls descriptions from function docstrings. + """Return a dictionary of available shape features and their descriptions. Returns ------- - list - List of available shape features. + dict + A dictionary where keys are shape feature names and values are their corresponding descriptions. """ # Get shape feature descriptions from docstrings @@ -437,6 +447,7 @@ def list_shape_features(): aspect_ratio=_aspect_ratio, bounds=_bounds, density=_density, + opening=_opening, perimeter=_perimeter, radius=_radius, raster=_raster, @@ -446,86 +457,69 @@ def list_shape_features(): def obs_stats( - data: AnnData, + sdata: SpatialData, feature_names: List[str] = ["area", "aspect_ratio", "density"], - copy=False, ): """Compute features for each cell shape. Convenient wrapper for `bento.tl.shape_features`. See list of available features in `bento.tl.shape_features`. 
Parameters ---------- - data : AnnData - Spatial formatted AnnData + sdata : SpatialData + Spatial formatted SpatialData feature_names : list List of features to compute. See list of available features in `bento.tl.shape_features`. - copy : bool, optional - Return a copy of `data` instead of writing to data, by default False. - Returns + Fields ------- - data : anndata.AnnData - Returns `data` if `copy=True`, otherwise adds fields to `data`: - - `obs['{shape}_{feature}']` : np.array + .shapes[shape_key]['{shape}_{feature}'] : np.array Feature of each polygon """ - adata = data.copy() if copy else data # Compute features - analyze_shapes(adata, "cell_shape", feature_names, copy=copy) - if "nucleus_shape" in adata.obs.columns: - analyze_shapes(adata, "nucleus_shape", feature_names, copy=copy) + analyze_shapes(sdata, "cell_boundaries", feature_names) + if "nucleus_boundaries" in sdata.shapes.keys(): + analyze_shapes(sdata, "nucleus_boundaries", feature_names) - return adata if copy else None - -@track def analyze_shapes( - data: AnnData, - shape_names: Union[str, List[str]], + sdata: SpatialData, + shape_keys: Union[str, List[str]], feature_names: Union[str, List[str]], feature_kws: Dict[str, Dict] = None, recompute: bool = False, progress: bool = True, - copy: bool = False, ): """Analyze features of shapes. Parameters ---------- - data : AnnData - Spatial formatted AnnData - shape_names : list of str + sdata : SpatialData + Spatial formatted SpatialData + shape_keys : list of str List of shapes to analyze. feature_names : list of str List of features to analyze. feature_kws : dict, optional (default: None) Keyword arguments for each feature. - copy : bool, optional - Return a copy of `data` instead of writing to data, by default False. Returns ------- - adata : AnnData + sdata : SpatialData See specific feature function docs for fields added. """ - adata = data.copy() if copy else data # Cast to list if not already - if isinstance(shape_names, str): - shape_names = [shape_names] - - # Add _shape suffix if shape names don't have it - shape_names = [s if s.endswith("_shape") else f"{s}_shape" for s in shape_names] + if isinstance(shape_keys, str): + shape_keys = [shape_keys] # Cast to list if not already if isinstance(feature_names, str): feature_names = [feature_names] # Generate feature x shape combinations - combos = [(f, s) for f in feature_names for s in shape_names] + combos = [(f, s) for f in feature_names for s in shape_keys] # Set up progress bar if progress: @@ -537,21 +531,22 @@ def analyze_shapes( if feature_kws and feature in feature_kws: kws.update(feature_kws[feature]) - shape_features[feature](adata, shape, **kws) + shape_features[feature](sdata, shape, **kws) - return adata if copy else None + return sdata def register_shape_feature(name: str, func: Callable): - """Register a shape feature function. The function should take an AnnData object and a shape name as input. - The function should add the feature to the AnnData object as a column in AnnData.obs. This should be done in place and not return anything. + """Register a shape feature function. The function should take an SpatialData object and a shape name as input. + The function should add the feature to the SpatialData object as a column in SpatialData.table.obs. + This should be done in place and not return anything. Parameters ---------- name : str Name of the feature function. func : function - Function that takes an AnnData object and a shape name as arguments. 
+ Function that takes a SpatialData object and a shape name as arguments. """ shape_features[name] = func diff --git a/docs/Makefile b/docs/Makefile old mode 100755 new mode 100644 diff --git a/docs/make.bat b/docs/make.bat old mode 100755 new mode 100644 diff --git a/docs/source/conf.py b/docs/source/conf.py old mode 100755 new mode 100644 diff --git a/docs/source/favicon.ico b/docs/source/favicon.ico old mode 100755 new mode 100644 diff --git a/docs/source/index.md b/docs/source/index.md old mode 100755 new mode 100644 diff --git a/images/anndata.png b/images/anndata.png old mode 100755 new mode 100644 diff --git a/pyproject.toml b/pyproject.toml old mode 100755 new mode 100644 index 165ac19..119189b --- a/pyproject.toml +++ b/pyproject.toml @@ -13,17 +13,16 @@ license = "BSD-2-Clause" readme = "README.md" [tool.poetry.dependencies] -python = ">=3.8, <3.10" -anndata = "^0.8" +python = ">=3.9, <3.10" +anndata = "^0.9.2" astropy = "^5.0" -geopandas = "^0.10.0" -matplotlib = "^3.2" +geopandas = "^0.14.0" +matplotlib = "~3.7" matplotlib-scalebar = "^0.8.1" -pygeos = "^0.12.0" scanpy = "^1.9.1" scipy = "^1.7.0" seaborn = "^0.12.1" -Shapely = "^1.8.2" +Shapely = "^2.0.1" Sphinx = { version = "^4.1.2", extras = ["docs"] } sphinx-autobuild = { version = "^2021.3.14", extras = ["docs"] } sphinx-book-theme = {version = "^1.0.0", extras = ["docs"] } @@ -35,16 +34,17 @@ emoji = "^1.7.0" tensorly = "^0.7.0" rasterio = "^1.3.0" ipywidgets = "^8.0" -decoupler = "^1.2.0" +decoupler = "1.4.0" MiniSom = "^2.3.0" kneed = "^0.8.1" adjustText = "^0.7.3" sparse = "^0.13.0" pandas = "^1.5.3" -xgboost = "1.4.0" +xgboost = "2.0.0" myst-nb = {version = "^0.17.1", extras = ["docs"]} sphinx_design = {version = "^0.3.0", extras = ["docs"]} -rich = "^13.5.2" +rtree = "^1.0.1" +spatialdata = "0.0.15" [tool.poetry.extras] docs = [ diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 diff --git a/tests/__init__.py b/tests/__init__.py old mode 100755 new mode 100644 diff --git a/tests/test_colocation.py b/tests/test_colocation.py index 5244e85..d472883 100644 --- a/tests/test_colocation.py +++ b/tests/test_colocation.py @@ -1,21 +1,60 @@ import unittest import bento as bt +import spatialdata as sd -data = bt.ds.sample_data() -rank = 3 +class TestColocation(unittest.TestCase): + def setUp(self): + self.data = sd.read_zarr("/mnt/d/spatial_datasets/small_data.zarr") + self.data = bt.io.format_sdata( + self.data, + points_key="transcripts", + feature_key="feature_name", + instance_key="cell_boundaries", + shape_keys=["cell_boundaries", "nucleus_boundaries"], + ) - -class TestColocation(unittest.TestCase): + bt.tl.coloc_quotient(self.data, shapes=["cell_boundaries"]) + bt.tl.colocation(self.data, ranks=range(1, 6)) + + def test_coloc_quotient(self): - bt.tl.coloc_quotient(data) - self.assertTrue("clq" in data.uns) + # Check that clq is in self.data.table.uns + self.assertTrue("clq" in self.data.table.uns) + # Check that cell_boundaries is in self.data.table.uns["clq"] + self.assertTrue("cell_boundaries" in self.data.table.uns["clq"]) + + coloc_quotient_features = ['feature_name', 'neighbor', 'clq', 'cell_boundaries', 'log_clq', 'compartment'] + # Check columns are in clq["cell_boundaries"] + for feature in coloc_quotient_features: + self.assertTrue(feature in self.data.table.uns["clq"]["cell_boundaries"]) + def test_colocation(self): - bt.tl.coloc_quotient(data, radius=20, min_points=10, min_cells=0) - bt.tl.colocation(data, ranks=[rank], iterations=3) - self.assertTrue("clq" in data.uns) - def test_plot(self): - 
bt.pl.colocation(data, rank=rank) - self.assertTrue(True) + # Check that tensor is in self.data.table.uns + self.assertTrue("tensor" in self.data.table.uns) + + # Check that tensor_labels is in self.data.table.uns + self.assertTrue("tensor_labels" in self.data.table.uns) + + # Check that tensor_names is in self.data.table.uns + self.assertTrue("tensor_names" in self.data.table.uns) + + # Check keys are in tensor_labels + for feature in self.data.table.uns["tensor_names"]: + self.assertTrue(feature in self.data.table.uns["tensor_labels"]) + + # Check that factors is in self.data.table.uns + self.assertTrue("factors" in self.data.table.uns) + + # Check that keys are in factors + for i in range(1, 6): + self.assertTrue(i in self.data.table.uns["factors"]) + + # Check that factors_error is in self.data.table.uns + self.assertTrue("factors_error" in self.data.table.uns) + self.assertTrue("rmse" in self.data.table.uns["factors_error"]) + self.assertTrue("rank" in self.data.table.uns["factors_error"]) + + \ No newline at end of file diff --git a/tests/test_flux.py b/tests/test_flux.py index 590208b..d4a4dd1 100644 --- a/tests/test_flux.py +++ b/tests/test_flux.py @@ -1,46 +1,45 @@ import unittest import bento as bt +import spatialdata as sd -data = bt.ds.sample_data()[0] -bt.sync(data) -radius = None -n_neighbors = 20 -res = 0.2 +class TestFlux(unittest.TestCase): + def setUp(self): + self.data = sd.read_zarr("/mnt/d/spatial_datasets/small_data.zarr") + self.data = bt.io.format_sdata( + self.data, + points_key="transcripts", + feature_key="feature_name", + instance_key="cell_boundaries", + shape_keys=["cell_boundaries", "nucleus_boundaries"], + ) -class TestFlux(unittest.TestCase): - def test_flux_radius(self): - # Calculate flux using radius method - bt.tl.flux(data, method="radius", radius=radius, res=res) + bt.tl.flux(sdata=self.data, points_key="transcripts", instance_key="cell_boundaries", feature_key="feature_name") + bt.tl.fluxmap(sdata=self.data, points_key="transcripts", instance_key="cell_boundaries", n_clusters=3) + - # Check that the flux data is present and in the correct format - self.assertTrue(data.uns["flux"].shape[0] == data.uns["cell_raster"].shape[0]) - self.assertTrue( - data.uns["flux_embed"].shape[0] == data.uns["cell_raster"].shape[0] - ) + def test_flux(self): + # Check that cell_boundaries_raster is in self.data.points + self.assertTrue("cell_boundaries_raster" in self.data.points) - def test_flux_knn(self): - bt.tl.flux(data, method="knn", n_neighbors=n_neighbors, res=res) + # Check that flux_genes is in self.data.table.uns + self.assertTrue("flux_genes" in self.data.table.uns) + genes = self.data.table.uns["flux_genes"] - self.assertTrue(data.uns["flux"].shape[0] == data.uns["cell_raster"].shape[0]) - self.assertTrue( - data.uns["flux_embed"].shape[0] == data.uns["cell_raster"].shape[0] - ) + # Check that flux_variance_ratio is in self.data.table.uns + self.assertTrue("flux_variance_ratio" in self.data.table.uns) + + # Check columns are added in cell_boundaries_raster + for gene in genes: + self.assertTrue(gene in self.data.points["cell_boundaries_raster"].columns) + for i in range(10): + self.assertTrue(f"flux_embed_{i}" in self.data.points["cell_boundaries_raster"].columns) + + self.assertTrue("fluxmap" in self.data.points["cell_boundaries_raster"].columns) + def test_fluxmap(self): - bt.tl.flux(data, method="radius", radius=radius, res=res) - bt.tl.fluxmap(data, n_clusters=range(2, 4), train_size=0.2, res=res) - bt.tl.fluxmap(data, n_clusters=3, train_size=1, 
res=res) - self.assertTrue("fluxmap" in data.uns["cell_raster"]) - self.assertTrue( - [ - f in data.uns["points"].columns - for f in ["fluxmap0", "fluxmap1", "fluxmap2"] - ] - ) - self.assertTrue( - [ - f in data.obs.columns - for f in ["fluxmap0_shape", "fluxmap1_shape", "fluxmap2_shape"] - ] - ) + for i in range(1, 4): + self.assertTrue(f"fluxmap{i}_boundaries" in self.data.points["transcripts"].columns) + self.assertTrue(f"fluxmap{i}_boundaries" in self.data.shapes) + \ No newline at end of file diff --git a/tests/test_flux_enrichment.py b/tests/test_flux_enrichment.py new file mode 100644 index 0000000..3fcbcc4 --- /dev/null +++ b/tests/test_flux_enrichment.py @@ -0,0 +1,59 @@ +import unittest +import bento as bt +import spatialdata as sd + + +class TestFluxEnrichement(unittest.TestCase): + def setUp(self): + self.data = sd.read_zarr("/mnt/d/spatial_datasets/small_data.zarr") + self.data = bt.io.format_sdata( + self.data, + points_key="transcripts", + feature_key="feature_name", + instance_key="cell_boundaries", + shape_keys=["cell_boundaries", "nucleus_boundaries"], + ) + + bt.tl.flux(sdata=self.data, points_key="transcripts", instance_key="cell_boundaries", feature_key="feature_name") + bt.tl.fluxmap(sdata=self.data, points_key="transcripts", instance_key="cell_boundaries", n_clusters=3) + + self.fe_fazal2019_features = ['Cytosol', 'ER Lumen', 'ERM', 'Lamina', 'Nuclear Pore', 'Nucleolus', 'Nucleus', 'OMM'] + self.fe_xia2019_features = ['ER', 'Nucleus'] + + + def test_fe_fazal2019(self): + bt.tl.fe_fazal2019(self.data) + + # Check that cell_boundaries_raster is in self.data.points + self.assertTrue("cell_boundaries_raster" in self.data.points) + + # Check that fe_stats is in self.data.table.uns + self.assertTrue("fe_stats" in self.data.table.uns) + + # Check that fe_ngenes is in self.data.table.uns + self.assertTrue("fe_ngenes" in self.data.table.uns) + + # Check columns are in cell_boundaries_raster, fe_stats, abd fe_ngenes + for feature in self.fe_fazal2019_features: + self.assertTrue(f"flux_{feature}" in self.data.points["cell_boundaries_raster"]) + self.assertTrue(feature in self.data.table.uns["fe_stats"]) + self.assertTrue(feature in self.data.table.uns["fe_ngenes"]) + + def test_fe_xia2019(self): + bt.tl.fe_xia2019(self.data) + + # Check that cell_boundaries_raster is in self.data.points + self.assertTrue("cell_boundaries_raster" in self.data.points) + + # Check that fe_stats is in self.data.table.uns + self.assertTrue("fe_stats" in self.data.table.uns) + + # Check that fe_ngenes is in self.data.table.uns + self.assertTrue("fe_ngenes" in self.data.table.uns) + + # Check columns are in cell_boundaries_raster, fe_stats, abd fe_ngenes + for feature in self.fe_xia2019_features: + self.assertTrue(f"flux_{feature}" in self.data.points["cell_boundaries_raster"]) + self.assertTrue(feature in self.data.table.uns["fe_stats"]) + self.assertTrue(feature in self.data.table.uns["fe_ngenes"]) + \ No newline at end of file diff --git a/tests/test_geometry.py b/tests/test_geometry.py index ebbd13d..50fc695 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -1,40 +1,163 @@ import unittest import bento as bt +import spatialdata as sd +import pandas as pd +import geopandas as gpd +import dask as dd -data = bt.ds.sample_data() +class TestGeometry(unittest.TestCase): + def setUp(self): + self.data = sd.read_zarr("/mnt/d/spatial_datasets/small_data.zarr") -class TestGeometry(unittest.TestCase): - def test_crop(self): + def test_sjoin_points(self): + self.data = 
bt.geo.sjoin_points(sdata=self.data, points_key="transcripts", shape_keys=["cell_boundaries", "nucleus_boundaries"]) + self.assertTrue("cell_boundaries" in self.data.points["transcripts"].columns) + self.assertTrue("nucleus_boundaries" in self.data.points["transcripts"].columns) - # Get bounds of first cell - cell_shape = bt.geo.get_shape(data, "cell_shape") - xmin, ymin, xmax, ymax = cell_shape.bounds.iloc[0] + def test_sjoin_shapes(self): + self.data = bt.geo.sjoin_shapes(sdata=self.data, instance_key="cell_boundaries", shape_keys=["nucleus_boundaries"]) + self.assertTrue("cell_boundaries" in self.data.shapes["nucleus_boundaries"].columns) + self.assertTrue("nucleus_boundaries" in self.data.shapes["cell_boundaries"].columns) - adata_crop = bt.geo.crop(data, (xmin, xmax), (ymin, ymax), copy=True) + def test_get_points(self): + self.data = bt.io.format_sdata( + self.data, + points_key="transcripts", + feature_key="feature_name", + instance_key="cell_boundaries", + shape_keys=["cell_boundaries", "nucleus_boundaries"], + ) + + pd_sync = bt.geo.get_points(sdata=self.data, points_key="transcripts", astype="pandas", sync=True) + pd_no_sync = bt.geo.get_points(sdata=self.data, points_key="transcripts", astype="pandas", sync=False) + + self.assertTrue(type(pd_sync) == pd.DataFrame) + self.assertTrue(type(pd_no_sync) == pd.DataFrame) + self.assertTrue(len(pd_sync) != len(self.data.points["transcripts"])) + self.assertTrue(len(pd_no_sync) == len(self.data.points["transcripts"])) + + gdf_sync = bt.geo.get_points(sdata=self.data, points_key="transcripts", astype="geopandas", sync=True) + gdf_no_sync = bt.geo.get_points(sdata=self.data, points_key="transcripts", astype="geopandas", sync=False) + + self.assertTrue(type(gdf_sync) == gpd.GeoDataFrame) + self.assertTrue(type(gdf_no_sync) == gpd.GeoDataFrame) + self.assertTrue(len(gdf_sync) != len(self.data.points["transcripts"])) + self.assertTrue(len(gdf_no_sync) == len(self.data.points["transcripts"])) + + dd_sync = bt.geo.get_points(sdata=self.data, points_key="transcripts", astype="dask", sync=True) + dd_no_sync = bt.geo.get_points(sdata=self.data, points_key="transcripts", astype="dask", sync=False) + + self.assertTrue(type(dd_sync) == dd.dataframe.core.DataFrame) + self.assertTrue(type(dd_no_sync) == dd.dataframe.core.DataFrame) + self.assertTrue(len(dd_sync) != len(self.data.points["transcripts"])) + self.assertTrue(len(dd_no_sync) == len(self.data.points["transcripts"])) + + def test_get_shape(self): + self.data = bt.io.format_sdata( + self.data, + points_key="transcripts", + feature_key="feature_name", + instance_key="cell_boundaries", + shape_keys=["cell_boundaries", "nucleus_boundaries"], + ) + + sync = bt.geo.get_shape(sdata=self.data, shape_key="nucleus_boundaries", sync=True) + no_sync = bt.geo.get_shape(sdata=self.data, shape_key="nucleus_boundaries", sync=False) + + self.assertTrue(type(sync) == gpd.GeoSeries) + self.assertTrue(type(no_sync) == gpd.GeoSeries) + self.assertTrue(len(sync) != len(self.data.shapes["nucleus_boundaries"])) + self.assertTrue(len(no_sync) == len(self.data.shapes["nucleus_boundaries"])) + + def test_set_points_metadata(self): + list_metadata = [0] * len(self.data.points["transcripts"]) + series_metadata = pd.Series(list_metadata) + dataframe_metadata = pd.DataFrame({"0": list_metadata, "1": list_metadata, "2": list_metadata}) + column_names = ["list_metadata", "series_metadata", "dataframe_metadata0", "dataframe_metadata1", "dataframe_metadata2"] - # Check that cropped data only contains first cell - 
self.assertTrue(adata_crop.obs.shape[0] == 1) - self.assertTrue(adata_crop.obs.index[0] == data.obs.index[0]) + bt.geo.set_points_metadata(sdata=self.data, points_key="transcripts", metadata=list_metadata, column_names=column_names[0]) + bt.geo.set_points_metadata(sdata=self.data, points_key="transcripts", metadata=series_metadata, column_names=column_names[1]) + bt.geo.set_points_metadata( + sdata=self.data, + points_key="transcripts", + metadata=dataframe_metadata, + column_names=[column_names[2], column_names[3], column_names[4]] + ) + for column in column_names: + self.assertTrue(column in self.data.points["transcripts"]) + + def test_set_shape_metadata(self): + list_metadata = [0] * len(self.data.shapes["cell_boundaries"]) + series_metadata = pd.Series(list_metadata) + dataframe_metadata = pd.DataFrame({"0": list_metadata, "1": list_metadata, "2": list_metadata}) + column_names = ["list_metadata", "series_metadata", "dataframe_metadata0", "dataframe_metadata1", "dataframe_metadata2"] + + bt.geo.set_shape_metadata(sdata=self.data, shape_key="cell_boundaries", metadata=list_metadata, column_names=column_names[0]) + bt.geo.set_shape_metadata(sdata=self.data, shape_key="cell_boundaries", metadata=series_metadata, column_names=column_names[1]) + bt.geo.set_shape_metadata( + sdata=self.data, + shape_key="cell_boundaries", + metadata=dataframe_metadata, + column_names=[column_names[2], column_names[3], column_names[4]] + ) + for column in column_names: + self.assertTrue(column in self.data.shapes["cell_boundaries"]) + + def test_get_points_metadata(self): + list_metadata = [0] * len(self.data.points["transcripts"]) + series_metadata = pd.Series(list_metadata) + column_names = ["list_metadata", "series_metadata"] + + bt.geo.set_points_metadata(sdata=self.data, points_key="transcripts", metadata=list_metadata, column_names=column_names[0]) + bt.geo.set_points_metadata(sdata=self.data, points_key="transcripts", metadata=series_metadata, column_names=column_names[1]) - # Check that points are cropped - self.assertTrue( - adata_crop.uns["points"].shape[0] - == data.uns["points"].query("cell == @data.obs.index[0]").shape[0] + pd_metadata_single = bt.geo.get_points_metadata(sdata=self.data, points_key="transcripts", metadata_keys=column_names[0], astype="pandas") + dd_metadata_single = bt.geo.get_points_metadata(sdata=self.data, points_key="transcripts", metadata_keys=column_names[0], astype="dask") + pd_metadata = bt.geo.get_points_metadata( + sdata=self.data, + points_key="transcripts", + metadata_keys=[column_names[0], column_names[1]], + astype="pandas" ) + dd_metadata = bt.geo.get_points_metadata( + sdata=self.data, + points_key="transcripts", + metadata_keys=[column_names[0], column_names[1]], + astype="dask" + ) + + self.assertTrue(type(pd_metadata_single) == pd.DataFrame) + self.assertTrue(column_names[0] in pd_metadata_single) + + self.assertTrue(type(dd_metadata_single) == dd.dataframe.core.DataFrame) + self.assertTrue(column_names[0] in dd_metadata_single) - def test_rename_cells(self): - res=0.02 - bt.tl.flux(data, method="radius", radius=200, res=res) - bt.tl.fluxmap(data, 2, train_size=1, res=res) - bt.geo.rename_shapes( - data, - {"fluxmap1_shape": "fluxmap3_shape", "fluxmap2_shape": "fluxmap4_shape"}, - points_key=["points", "cell_raster"], - points_encoding=["onhot", "label"], + self.assertTrue(type(pd_metadata) == pd.DataFrame) + self.assertTrue(type(dd_metadata) == dd.dataframe.core.DataFrame) + self.assertTrue("list_metadata" in pd_metadata_single) + 
self.assertTrue("list_metadata" in dd_metadata_single) + for column in column_names: + self.assertTrue(column in pd_metadata) + self.assertTrue(column in dd_metadata) + + def test_get_shape_metadata(self): + list_metadata = [0] * len(self.data.shapes["cell_boundaries"]) + series_metadata = pd.Series(list_metadata) + column_names = ["list_metadata", "series_metadata"] + + bt.geo.set_shape_metadata(sdata=self.data, shape_key="cell_boundaries", metadata=list_metadata, column_names=column_names[0]) + bt.geo.set_shape_metadata(sdata=self.data, shape_key="cell_boundaries", metadata=series_metadata, column_names=column_names[1]) + metadata_single = bt.geo.get_shape_metadata(sdata=self.data, shape_key="cell_boundaries", metadata_keys=column_names[0]) + metadata = bt.geo.get_shape_metadata( + sdata=self.data, + shape_key="cell_boundaries", + metadata_keys=[column_names[0], column_names[1]], ) - new_names = ["fluxmap3_shape", "fluxmap4_shape"] - self.assertTrue([f in data.obs.columns for f in new_names]) - self.assertTrue([f in data.uns["points"].columns for f in new_names]) - self.assertTrue([f in data.uns["cell_raster"]["fluxmap"] for f in ["3", "4"]]) + self.assertTrue(type(metadata_single) == pd.DataFrame) + self.assertTrue(column_names[0] in metadata_single) + + self.assertTrue(type(metadata) == pd.DataFrame) + for column in column_names: + self.assertTrue(column in metadata) diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 0000000..7e799ad --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,43 @@ +import unittest +import bento as bt +import spatialdata as sd + + +class TestIO(unittest.TestCase): + def setUp(self): + self.data = sd.read_zarr("/mnt/d/spatial_datasets/small_data.zarr") + self.data = bt.io.format_sdata( + self.data, + points_key="transcripts", + feature_key="feature_name", + instance_key="cell_boundaries", + shape_keys=["cell_boundaries", "nucleus_boundaries"], + ) + + def test_points_indexing(self): + # Check points indexing + self.assertTrue("cell_boundaries" in self.data.points["transcripts"].columns) + self.assertTrue("nucleus_boundaries" in self.data.points["transcripts"].columns) + + def test_shapes_indexing(self): + # Check shapes indexing + self.assertTrue("cell_boundaries" in self.data.shapes["cell_boundaries"].columns) + self.assertTrue("cell_boundaries" in self.data.shapes["nucleus_boundaries"].columns) + self.assertTrue("nucleus_boundaries" in self.data.shapes["cell_boundaries"].columns) + + def test_points_attrs(self): + # Check points attrs + self.assertTrue("transform" in self.data.points["transcripts"].attrs.keys()) + self.assertTrue(self.data.points["transcripts"].attrs["spatialdata_attrs"]["feature_key"] == "feature_name") + self.assertTrue(self.data.points["transcripts"].attrs["spatialdata_attrs"]["instance_key"] == "cell_boundaries") + + def test_shapes_attrs(self): + # Check shapes attrs + self.assertTrue("transform" in self.data.shapes["cell_boundaries"].attrs.keys()) + self.assertTrue("transform" in self.data.shapes["nucleus_boundaries"].attrs.keys()) + + def test_index_dtypes(self): + # Check index dtypes + self.assertTrue(self.data.shapes["cell_boundaries"].index.dtype == "object") + self.assertTrue(self.data.shapes["nucleus_boundaries"].index.dtype == "object") + self.assertTrue(self.data.points["transcripts"].index.dtype == "int64") \ No newline at end of file diff --git a/tests/test_lp.py b/tests/test_lp.py index 59e26b1..71b27df 100644 --- a/tests/test_lp.py +++ b/tests/test_lp.py @@ -1,18 +1,75 @@ import unittest import bento as 
bt +import spatialdata as sd +import random -data = bt.ds.sample_data() +class TestLp(unittest.TestCase): + def setUp(self): + self.data = sd.read_zarr("/mnt/d/spatial_datasets/small_data.zarr") + self.data = bt.io.format_sdata( + self.data, + points_key="transcripts", + feature_key="feature_name", + instance_key="cell_boundaries", + shape_keys=["cell_boundaries", "nucleus_boundaries"], + ) + + bt.tl.lp(sdata=self.data, instance_key="cell_boundaries", nucleus_key="nucleus_boundaries", groupby="feature_name") + bt.tl.lp_stats(sdata=self.data, instance_key="cell_boundaries") -class TestPatterns(unittest.TestCase): def test_lp(self): - bt.tl.lp(data) + lp_columns = ['cell_boundaries', 'feature_name', 'cell_edge', 'cytoplasmic', 'none', 'nuclear', 'nuclear_edge'] + + # Check lp and lpp dataframes in sdata.table.uns + for column in lp_columns: + self.assertTrue(column in self.data.table.uns["lp"].columns) + self.assertTrue(column in self.data.table.uns["lpp"].columns) + + def test_lp_stats(self): + lp_stats_columns = ['cell_edge', 'cytoplasmic', 'none', 'nuclear', 'nuclear_edge'] + + # Check lp_stats index in sdata.table.uns + self.assertTrue(self.data.table.uns["lp_stats"].index.name == "feature_name") + # Check lp_stats dataframe in sdata.table.uns + for column in lp_stats_columns: + self.assertTrue(column in self.data.table.uns["lp_stats"].columns) + + def test_lp_diff_discrete(self): + lp_diff_discrete_columns = ['feature_name', 'pattern', 'phenotype', 'dy/dx', 'std_err', 'z', 'pvalue', 'ci_low', 'ci_high', 'padj', '-log10p', '-log10padj', 'log2fc'] + + # Assign random cell stage to each cell + stages = ['G0', 'G1', 'S', 'G2', 'M'] + phenotype = [] + for i in range(len(self.data.shapes['cell_boundaries'])): + phenotype.append(random.choice(stages)) + self.data.shapes['cell_boundaries']['cell_stage'] = phenotype + + bt.tl.lp_diff_discrete(sdata=self.data, instance_key="cell_boundaries", phenotype="cell_stage") + + # Check lp_diff_discrete dataframe in sdata.table.uns + for column in lp_diff_discrete_columns: + self.assertTrue(column in self.data.table.uns["diff_cell_stage"].columns) + + def test_lp_diff_discrete_error(self): + error_test = [] + for i in range(len(self.data.shapes['cell_boundaries'])): + if self.data.shapes['cell_boundaries']['cell_boundaries_area'][i] > self.data.shapes['cell_boundaries']['cell_boundaries_area'].median(): + error_test.append(1) + else: + error_test.append(0) + self.data.shapes['cell_boundaries']['error_test'] = error_test + + # Check that KeyError is raised when phenotype is numeric + with self.assertRaises(KeyError): + bt.tl.lp_diff_discrete(sdata=self.data, instance_key="cell_boundaries", phenotype="error_test") + + def test_lp_diff_continuous(self): + lp_diff_continuous_columns = ['feature_name', 'pattern', 'pearson_correlation'] - # Check if "lp" and "lpp" are in data.obsm - self.assertTrue("lp" in data.uns.keys() and "lpp" in data.uns.keys()) + bt.tl.lp_diff_continuous(self.data, phenotype="cell_boundaries_area") - def test_lp_plots(self): - bt.pl.lp_dist(data, percentage=True) - bt.pl.lp_dist(data, percentage=False) - bt.tl.lp_stats(data) - bt.pl.lp_genes(data) + # Check lp_diff_continuous dataframe in sdata.table.uns + for column in lp_diff_continuous_columns: + self.assertTrue(column in self.data.table.uns["diff_cell_boundaries_area"].columns) + \ No newline at end of file diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 8f605e2..5818e58 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -1,33 +1,105 @@ -import 
unittest -import bento as bt -import matplotlib as mpl -import matplotlib.pyplot as plt +# import unittest +# import bento as bt +# import matplotlib as mpl +# import matplotlib.pyplot as plt -adata = bt.ds.sample_data() +# adata = bt.ds.sample_data() -# Test if plotting functions run without error -class TestPlotting(unittest.TestCase): - def test_points(self): - bt.pl.points(adata) +# # Test if plotting functions run without error +# class TestPlotting(unittest.TestCase): +# def test_analyze(self): +# bt.pl.points(adata) - bt.pl.points(adata, hue="gene", legend=False) +# bt.pl.points(adata, hue="gene", legend=False) - genes = ["MALAT1", "TLN1", "SPTBN1"] - bt.pl.points(adata[:, genes], hue="gene", legend=False) +# genes = ["MALAT1", "TLN1", "SPTBN1"] +# bt.pl.points(adata[:, genes], hue="gene") - self.assertTrue(True) +# bt.pl.density(adata) - def test_density(self): - bt.pl.density(adata) +# bt.pl.density(adata, kind="kde") - self.assertTrue(True) +# bt.pl.shapes(adata) - def test_shapes(self): - bt.pl.shapes(adata) +# bt.pl.shapes(adata, color_style="fill") - bt.pl.shapes(adata, color_style="fill") +# bt.pl.shapes(adata, hue="cell", color_style="fill") - bt.pl.shapes(adata, hue="cell", color_style="fill") +# fig, ax = plt.subplots() +# bt.pl.shapes(adata, shapes="cell", linestyle="--", ax=ax) +# bt.pl.shapes( +# adata, +# shapes="nucleus", +# edgecolor="black", +# facecolor="lightseagreen", +# ax=ax, +# ) +# fig, axes = plt.subplots(1, 2, figsize=(8, 4)) - self.assertTrue(True) +# bt.pl.density(adata, ax=axes[0], title="default styling") + +# bt.pl.density( +# adata, +# ax=axes[1], +# axis_visible=True, +# frame_visible=True, +# square=True, +# title="square plot + axis", +# ) +# plt.tight_layout() +# with mpl.style.context("dark_background"): +# fig, ax = plt.subplots() +# bt.pl.shapes(adata, shapes="cell", linestyle="--", ax=ax) +# bt.pl.shapes( +# adata, +# shapes="nucleus", +# edgecolor="black", +# facecolor="lightseagreen", +# ax=ax, +# ) +# cells = adata.obs_names[:8] # get some cells +# ncells = len(cells) + +# ncols = 4 +# nrows = 2 +# ax_height = 1.5 +# fig, axes = plt.subplots( +# nrows, ncols, figsize=(ncols * ax_height, nrows * ax_height) +# ) # instantiate + +# for c, ax in zip(cells, axes.flat): +# bt.pl.density( +# adata[c], +# ax=ax, +# square=True, +# title="", +# ) + +# plt.subplots_adjust(wspace=0, hspace=0, bottom=0, top=1, left=0, right=1) +# batches = adata.obs["batch"].unique()[:6] # get 6 batches +# nbatches = len(batches) + +# ncols = 3 +# nrows = 2 +# ax_height = 3 +# fig, axes = plt.subplots( +# nrows, ncols, figsize=(ncols * ax_height, nrows * ax_height) +# ) # instantiate + +# for b, ax in zip(batches, axes.flat): +# bt.pl.density( +# adata, +# batch=b, +# ax=ax, +# square=True, +# title="", +# ) + +# # remove empty axes +# for ax in axes.flat[nbatches:]: +# ax.remove() + +# plt.subplots_adjust(wspace=0, hspace=0, bottom=0, top=1, left=0, right=1) + +# self.assertTrue(True) diff --git a/tests/test_point_features.py b/tests/test_point_features.py index ae39625..7f40243 100644 --- a/tests/test_point_features.py +++ b/tests/test_point_features.py @@ -1,39 +1,114 @@ import unittest import bento as bt +import spatialdata as sd -data = bt.ds.sample_data()[:5, :5] -bt.sync(data) -# Ad a missing shape for testing -nucleus_shapes = data.obs["nucleus_shape"] -nucleus_shapes[1] = None +class TestPointFeatures(unittest.TestCase): + def setUp(self): + self.data = sd.read_zarr("/mnt/d/spatial_datasets/small_data.zarr") + self.data = bt.io.format_sdata( + sdata=self.data, 
+            points_key="transcripts",
+            feature_key="feature_name",
+            instance_key="cell_boundaries",
+            shape_keys=["cell_boundaries", "nucleus_boundaries"],
+        )
+        self.point_features = bt.tl.list_point_features().keys()
+        self.instance_key = ["cell_boundaries"]
+        self.feature_key = ["feature_name"]
+        self.indpendent_features = [
+            'point_dispersion_norm',
+            'point_dispersion',
+            'l_max',
+            'l_max_gradient',
+            'l_min_gradient',
+            'l_monotony',
+            'l_half_radius'
+        ]
-features = list(bt.tl.list_point_features().keys())
+        feature_names = [
+            'inner_proximity',
+            'outer_proximity',
+            'inner_asymmetry',
+            'outer_asymmetry',
+            'dispersion_norm',
+            'inner_distance',
+            'outer_distance',
+            'inner_offset',
+            'outer_offset',
+            'dispersion',
+            'enrichment'
+        ]
+        self.cell_features = [f"cell_boundaries_{x}" for x in feature_names]
+        self.nucleus_features = [f"nucleus_boundaries_{x}" for x in feature_names]
-class TestPointFeatures(unittest.TestCase):
-    def test_single_feature(self):
-        # Simplest case, single parameters
-        bt.tl.analyze_points(data, "cell_shape", features[0], groupby=None)
-        self.assertTrue("cell_features" in data.uns)
-        self.assertTrue(data.uns["cell_features"].shape[0] == data.n_obs)
+    # Test case to check if features are calculated for a single shape and a single group
+    def test_single_shape_single_group(self):
+        data = bt.tl.analyze_points(
+            sdata=self.data,
+            shape_keys=["cell_boundaries"],
+            feature_names=self.point_features,
+            groupby=None,
+            recompute=False,
+            progress=True,
+        )
+
+        point_features = self.instance_key + self.indpendent_features + self.cell_features
+
+        # Check if cell_boundaries point features are calculated
+        for feature in point_features:
+            self.assertTrue(feature in data.table.uns["cell_boundaries_features"].columns)
+
-    def test_multiple_shapes(self):
-        # Multiple shapes, and features
-        bt.tl.analyze_points(
-            data, ["cell_shape", "nucleus_shape"], features, groupby=None
+    # Test case to check if features are calculated for a single shape and multiple groups
+    def test_single_shape_multiple_groups(self):
+        data = bt.tl.analyze_points(
+            sdata=self.data,
+            shape_keys=["cell_boundaries"],
+            feature_names=self.point_features,
+            groupby=["feature_name"],
+            recompute=False,
+            progress=True,
         )
-        self.assertTrue("cell_features" in data.uns)
-        self.assertTrue(data.uns["cell_features"].shape[0] == data.n_obs)
+        point_features = self.instance_key + self.feature_key + self.indpendent_features + self.cell_features
-    def test_multiple_shapes_features_groupby(self):
-        # Multiple shapes, features, and gene groupby
-        bt.tl.analyze_points(
-            data, ["cell_shape", "nucleus_shape"], features, groupby="gene"
+        # Check if cell_boundaries and gene point features are calculated
+        for feature in point_features:
+            self.assertTrue(feature in data.table.uns["cell_boundaries_feature_name_features"].columns)
+
+    # Test case to check if point features are calculated for multiple shapes and a single group
+    def test_multiple_shapes_single_group(self):
+        data = bt.tl.analyze_points(
+            sdata=self.data,
+            shape_keys=["cell_boundaries", "nucleus_boundaries"],
+            feature_names=self.point_features,
+            groupby=None,
+            recompute=False,
+            progress=True,
+        )
+
+        point_features = self.instance_key + self.indpendent_features + self.cell_features + self.nucleus_features
+
+        # Check if cell_boundaries and nucleus_boundaries point features are calculated
+        for feature in point_features:
+            self.assertTrue(feature in data.table.uns["cell_boundaries_features"].columns)
+
+    # Test case to check if point features are calculated for multiple shapes and multiple groups
+    def test_multiple_shapes_multiple_groups(self):
+        data = bt.tl.analyze_points(
+            sdata=self.data,
+            shape_keys=["cell_boundaries", "nucleus_boundaries"],
+            feature_names=self.point_features,
+            groupby=["feature_name"],
+            recompute=False,
+            progress=True,
         )
-        output_key = "cell_gene_features"
-        n_groups = data.uns["points"].groupby(["cell", "gene"], observed=True).ngroups
-        self.assertTrue(data.uns[output_key].shape[0] == n_groups)
+        point_features = self.instance_key + self.feature_key + self.indpendent_features + self.cell_features + self.nucleus_features
+
+        # Check if cell_boundaries and nucleus_boundaries point features are calculated
+        for feature in point_features:
+            self.assertTrue(feature in data.table.uns["cell_boundaries_feature_name_features"].columns)
diff --git a/tests/test_shape_features.py b/tests/test_shape_features.py
index a9a4a13..7c4c4b9 100644
--- a/tests/test_shape_features.py
+++ b/tests/test_shape_features.py
@@ -1,33 +1,134 @@
 import unittest
 import bento as bt
+import spatialdata as sd
-data = bt.ds.sample_data()[:5, :5]
-bt.sync(data)
-# Ad a missing shape for testing
-nucleus_shapes = data.obs["nucleus_shape"]
-nucleus_shapes[1] = None
+class TestShapeFeatures(unittest.TestCase):
+    def setUp(self):
+        self.data = sd.read_zarr("/mnt/d/spatial_datasets/small_data.zarr")
+        self.data = bt.io.format_sdata(
+            sdata=self.data,
+            points_key="transcripts",
+            feature_key="feature_name",
+            instance_key="cell_boundaries",
+            shape_keys=["cell_boundaries", "nucleus_boundaries"],
+        )
-features = list(bt.tl.list_shape_features().keys())
+        self.shape_features = bt.tl.list_shape_features().keys()
+        feature_names = ['area', 'aspect_ratio', 'minx', 'miny', 'maxx', 'maxy', 'density', 'open_0.5_shape', 'perimeter', 'radius', 'raster', 'moment', 'span']
+        self.cell_features = [f"cell_boundaries_{x}" for x in feature_names]
+        self.nucleus_features = [f"nucleus_boundaries_{x}" for x in feature_names]
-class TestShapeFeatures(unittest.TestCase):
-    # Simplest case, single shape and feature
+    # Simplest test to check if a single shape feature is calculated for a single shape
     def test_single_shape_single_feature(self):
-        # Test shape name with/without suffix
-        bt.tl.analyze_shapes(data, "cell", "area")
-        bt.tl.analyze_shapes(data, "cell_shape", "area")
-        self.assertTrue("cell_area" in data.obs)
-
-    def test_single_shape_multi_feature(self):
-        # Test all features
-        bt.tl.analyze_shapes(data, "cell", features)
-        feature_keys = [f"cell_{f}" for f in features]
-        self.assertTrue(f in data.obs for f in feature_keys)
-
-    def test_missing_shape(self):
-        # Test missing nucleus shapes
-        bt.tl.analyze_shapes(data, "nucleus", features)
-        feature_keys = [f"nucleus_{f}" for f in features]
-        self.assertTrue(f in data.obs for f in feature_keys)
-        self.assertTrue(data.obs[f].isna()[1] for f in feature_keys)
+        self.data = bt.tl.analyze_shapes(
+            sdata=self.data,
+            shape_keys="cell_boundaries",
+            feature_names="area",
+            progress=True
+        )
+
+        # Check if cell_boundaries shape features are calculated
+        self.assertTrue("cell_boundaries_area" in self.data.shapes["cell_boundaries"].columns)
+
+        # Check shapes attrs
+        self.assertTrue("transform" in self.data.shapes["cell_boundaries"].attrs.keys())
+
+    # Test case to check if multiple shape features are calculated for a single shape
+    def test_single_shape_multiple_features(self):
+        self.data = bt.tl.analyze_shapes(
+            sdata=self.data,
+            shape_keys="cell_boundaries",
+            feature_names=self.shape_features,
+            feature_kws={"opening": {"proportion": 0.5}},
+            progress=True
+        )
+
+        # Check if cell_boundaries shape features are calculated
+        for feature in self.cell_features:
+            self.assertTrue(feature in self.data.shapes["cell_boundaries"].columns)
+
+        # Check that raster is a points element
+        self.assertTrue("cell_boundaries_raster" in self.data.points.keys())
+
+        # Check points attrs
+        self.assertTrue("transform" in self.data.points["cell_boundaries_raster"].attrs.keys())
+        self.assertTrue(self.data.points["cell_boundaries_raster"].attrs["spatialdata_attrs"]["feature_key"] == "feature_name")
+        self.assertTrue(self.data.points["cell_boundaries_raster"].attrs["spatialdata_attrs"]["instance_key"] == "cell_boundaries")
+
+        # Check shapes attrs
+        self.assertTrue("transform" in self.data.shapes["cell_boundaries"].attrs.keys())
+
+    # Test case to check if a single shape feature is calculated for multiple shapes
+    def test_multiple_shapes_single_feature(self):
+        self.data = bt.tl.analyze_shapes(
+            sdata=self.data,
+            shape_keys=["cell_boundaries", "nucleus_boundaries"],
+            feature_names="area",
+            progress=True
+        )
+
+        # Check if cell_boundaries and nucleus_boundaries shape features are calculated
+        self.assertTrue("cell_boundaries_area" in self.data.shapes["cell_boundaries"].columns)
+        self.assertTrue("nucleus_boundaries_area" in self.data.shapes["nucleus_boundaries"].columns)
+
+        # Check shapes attrs
+        self.assertTrue("transform" in self.data.shapes["cell_boundaries"].attrs.keys())
+        self.assertTrue("transform" in self.data.shapes["nucleus_boundaries"].attrs.keys())
+
+    # Test case to check if multiple shape features are calculated for multiple shapes
+    def test_multiple_shapes_multiple_features(self):
+        self.data = bt.tl.analyze_shapes(
+            sdata=self.data,
+            shape_keys=["cell_boundaries", "nucleus_boundaries"],
+            feature_names=self.shape_features,
+            feature_kws={"opening": {"proportion": 0.5}},
+            progress=True
+        )
+
+        # Check if cell_boundaries shape features are calculated
+        for feature in self.cell_features:
+            self.assertTrue(feature in self.data.shapes["cell_boundaries"].columns)
+
+        # Check that raster is a points element
+        self.assertTrue("cell_boundaries_raster" in self.data.points.keys())
+
+        # Check points attrs
+        self.assertTrue("transform" in self.data.points["cell_boundaries_raster"].attrs.keys())
+        self.assertTrue(self.data.points["cell_boundaries_raster"].attrs["spatialdata_attrs"]["feature_key"] == "feature_name")
+        self.assertTrue(self.data.points["cell_boundaries_raster"].attrs["spatialdata_attrs"]["instance_key"] == "cell_boundaries")
+
+        # Check shapes attrs
+        self.assertTrue("transform" in self.data.shapes["cell_boundaries"].attrs.keys())
+
+        # Check if nucleus_boundaries shape features are calculated
+        for feature in self.nucleus_features:
+            self.assertTrue(feature in self.data.shapes["nucleus_boundaries"].columns)
+
+        # Check that raster is a points element
+        self.assertTrue("nucleus_boundaries_raster" in self.data.points.keys())
+
+        # Check points attrs
+        self.assertTrue("transform" in self.data.points["nucleus_boundaries_raster"].attrs.keys())
+        self.assertTrue(self.data.points["nucleus_boundaries_raster"].attrs["spatialdata_attrs"]["feature_key"] == "feature_name")
+        self.assertTrue(self.data.points["nucleus_boundaries_raster"].attrs["spatialdata_attrs"]["instance_key"] == "cell_boundaries")
+
+        # Check shapes attrs
+        self.assertTrue("transform" in self.data.shapes["nucleus_boundaries"].attrs.keys())
+
+    # Test case to check if obs_stats function calculates area, aspect_ratio and density for both cell_boundaries and nucleus_boundaries
+    def test_obs_stats(self):
+        bt.tl.obs_stats(sdata=self.data)
+
+        # Check if cell_boundaries and nucleus_boundaries shape features are calculated
+        self.assertTrue("cell_boundaries_area" in self.data.shapes["cell_boundaries"].columns)
+        self.assertTrue("cell_boundaries_aspect_ratio" in self.data.shapes["cell_boundaries"].columns)
+        self.assertTrue("cell_boundaries_density" in self.data.shapes["cell_boundaries"].columns)
+        self.assertTrue("nucleus_boundaries_area" in self.data.shapes["nucleus_boundaries"].columns)
+        self.assertTrue("nucleus_boundaries_aspect_ratio" in self.data.shapes["nucleus_boundaries"].columns)
+        self.assertTrue("nucleus_boundaries_density" in self.data.shapes["nucleus_boundaries"].columns)
+
+        # Check shapes attrs
+        self.assertTrue("transform" in self.data.shapes["cell_boundaries"].attrs.keys())
+        self.assertTrue("transform" in self.data.shapes["nucleus_boundaries"].attrs.keys())
diff --git a/tests/test_signatures.py b/tests/test_signatures.py
index e8d929e..eaae67e 100644
--- a/tests/test_signatures.py
+++ b/tests/test_signatures.py
@@ -1,11 +1,11 @@
-import unittest
-import bento
+# import unittest
+# import bento
-data = bento.datasets.sample_data()
+# data = bento.datasets.sample_data()
-class TestSignatures(unittest.TestCase):
-    def test_to_tensor(self):
-        bento.tl.to_tensor(data, [None])
-        tensor = data.uns["tensor"]
-        self.assertTrue(tensor.shape == (1, data.n_obs, data.n_vars))
+# class TestSignatures(unittest.TestCase):
+#     def test_to_tensor(self):
+#         bento.tl.to_tensor(data, [None])
+#         tensor = data.uns["tensor"]
+#         self.assertTrue(tensor.shape == (1, data.n_obs, data.n_vars))