diff --git a/.gitignore b/.gitignore index acbd6fc..90d129f 100644 --- a/.gitignore +++ b/.gitignore @@ -23,7 +23,6 @@ lib/ lib64/ parts/ sdist/ -var/ wheels/ pip-wheel-metadata/ share/python-wheels/ diff --git a/bento/__init__.py b/bento/__init__.py index d3a09aa..e3baf20 100644 --- a/bento/__init__.py +++ b/bento/__init__.py @@ -1,3 +1,5 @@ +from ._version import __version__ + from . import _utils as ut from . import geometry as geo from . import plotting as pl diff --git a/bento/_utils.py b/bento/_utils.py index 720f116..3e8edde 100644 --- a/bento/_utils.py +++ b/bento/_utils.py @@ -17,27 +17,26 @@ def filter_by_gene( min_count: int = 10, points_key: str = "transcripts", feature_key: str = "feature_name", -): - """ - Filters out genes with low expression from the spatial data object. +) -> SpatialData: + """Filter out genes with low expression. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. - threshold : int - Minimum number of counts for a gene to be considered expressed. - Keep genes where at least {threshold} molecules are detected in at least one cell. - points_key : str - key for points element that holds transcript coordinates - feature_key : str - Key for gene instances + Input SpatialData object + min_count : int, default 10 + Minimum number of molecules required per gene + points_key : str, default "transcripts" + Key for points in sdata.points + feature_key : str, default "feature_name" + Column name containing gene identifiers Returns ------- - sdata : SpatialData - .points[points_key] is updated to remove genes with low expression. - .tables["table"] is updated to remove genes with low expression. + SpatialData + Updated object with filtered: + - points[points_key]: Only points from expressed genes + - tables["table"]: Only expressed genes """ gene_filter = (sdata.tables["table"].X >= min_count).sum(axis=0) > 0 filtered_table = sdata.tables["table"][:, gene_filter] @@ -71,23 +70,28 @@ def get_points( astype: str = "pandas", sync: bool = True, ) -> Union[pd.DataFrame, dd.DataFrame, gpd.GeoDataFrame]: - """Get points DataFrame synced to AnnData object. + """Get points data synchronized with cell boundaries. Parameters ---------- - data : SpatialData - Spatial formatted SpatialData object - key : str, optional - Key for `data.points` to use, by default "transcripts" - astype : str, optional - Whether to return a 'pandas' DataFrame, 'dask' DataFrame, or 'geopandas' GeoDataFrame, by default "pandas" - sync : bool, optional - Whether to set and retrieve points synced to instance_key shape. Default True. + sdata : SpatialData + Input SpatialData object + points_key : str, default "transcripts" + Key for points in sdata.points + astype : str, default "pandas" + Return type: 'pandas', 'dask', or 'geopandas' + sync : bool, default True + Whether to sync points with instance_key shapes Returns ------- - DataFrame or GeoDataFrame - Returns `data.points[key]` as a `[Geo]DataFrame` or 'Dask DataFrame' + Union[pd.DataFrame, dd.DataFrame, gpd.GeoDataFrame] + Points data in requested format + + Raises + ------ + ValueError + If points_key not found or invalid astype """ if points_key not in sdata.points.keys(): raise ValueError(f"Points key {points_key} not found in sdata.points") @@ -114,22 +118,31 @@ def get_points( ) -def get_shape(sdata: SpatialData, shape_key: str, sync: bool = True) -> gpd.GeoSeries: - """Get a GeoSeries of Polygon objects from an SpatialData object. 
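Reviewer note: a quick usage sketch for the two utilities documented above. The `bt.ut` accessor (from `bento/__init__.py`) and a prepared `sdata` with a "transcripts" points element and a "table" counts table are assumptions, not part of this diff.

```python
# Hedged sketch: drop lowly expressed genes, then pull the molecule table.
import bento as bt

sdata = bt.ut.filter_by_gene(sdata, min_count=10)  # keep genes with >=10 molecules in some cell
points = bt.ut.get_points(sdata, points_key="transcripts", astype="pandas")
```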
+def get_shape( + sdata: SpatialData, + shape_key: str, + sync: bool = True +) -> gpd.GeoSeries: + """Get shape geometries synchronized with cell boundaries. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object + Input SpatialData object shape_key : str - Name of shape column in sdata.shapes - sync : bool - Whether to set and retrieve shapes synced to cell shape. Default True. + Key for shapes in sdata.shapes + sync : bool, default True + Whether to sync shapes with instance_key shapes Returns ------- - GeoSeries - GeoSeries of Polygon objects + gpd.GeoSeries + Shape geometries + + Raises + ------ + ValueError + If shape_key not found in sdata.shapes """ instance_key = sdata.tables["table"].uns["spatialdata_attrs"]["instance_key"] @@ -152,23 +165,28 @@ def get_points_metadata( points_key: str, astype: str = "pandas", ) -> Union[pd.DataFrame, dd.DataFrame]: - """Get points metadata. + """Get metadata columns from points data. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object + Input SpatialData object metadata_keys : str or list of str - Key(s) for `sdata.points[points_key][key]` to use - points_key : str, optional - Key for `sdata.points` to use, by default "transcripts" - astype : str, optional - Whether to return a 'pandas' Series or 'dask' DataFrame, by default "pandas" + Column name(s) to retrieve + points_key : str + Key for points in sdata.points + astype : str, default "pandas" + Return type: 'pandas' or 'dask' Returns ------- - pd.DataFrame or dd.DataFrame - Returns `sdata.points[points_key][metadata_keys]` as a `pd.DataFrame` or `dd.DataFrame` + Union[pd.DataFrame, dd.DataFrame] + Requested metadata columns + + Raises + ------ + ValueError + If points_key or metadata_keys not found """ if points_key not in sdata.points.keys(): raise ValueError(f"Points key {points_key} not found in sdata.points") @@ -195,21 +213,26 @@ def get_shape_metadata( metadata_keys: Union[List[str], str], shape_key: str, ) -> pd.DataFrame: - """Get shape metadata. + """Get metadata columns from shapes data. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object + Input SpatialData object metadata_keys : str or list of str - Key(s) for `sdata.shapes[shape_key][key]` to use + Column name(s) to retrieve shape_key : str - Key for `sdata.shapes` to use, by default "transcripts" + Key for shapes in sdata.shapes Returns ------- pd.DataFrame - Returns `sdata.shapes[shape_key][metadata_keys]` as a `pd.DataFrame` + Requested metadata columns + + Raises + ------ + ValueError + If shape_key or metadata_keys not found """ if shape_key not in sdata.shapes.keys(): raise ValueError(f"Shape key {shape_key} not found in sdata.shapes") @@ -230,18 +253,23 @@ def set_points_metadata( metadata: Union[List, pd.Series, pd.DataFrame, np.ndarray], columns: Union[List[str], str], ) -> None: - """Write metadata in SpatialData points element as column(s). + """Add metadata columns to points data. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object + Input SpatialData object points_key : str - Name of element in sdata.points - metadata : pd.Series, pd.DataFrame, np.ndarray - Metadata to set for points. Assumes input is already aligned to points index. - column_names : str or list of str, optional - Name of column(s) to set. 
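For symmetry with `get_points`, a one-line sketch of the reworked `get_shape` (same assumptions as the sketch above):

```python
# Returns a gpd.GeoSeries of nucleus polygons, synced to the instance (cell) index.
nuclei = bt.ut.get_shape(sdata, shape_key="nucleus_boundaries", sync=True)
```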
If None, use metadata column name(s), by default None + Key for points in sdata.points + metadata : array-like + Data to add as new columns + columns : str or list of str + Names for new columns + + Raises + ------ + ValueError + If points_key not found """ if points_key not in sdata.points.keys(): raise ValueError(f"{points_key} not found in sdata.points") @@ -275,18 +303,23 @@ def set_shape_metadata( metadata: Union[List, pd.Series, pd.DataFrame, np.ndarray], column_names: Union[List[str], str] = None, ) -> None: - """Write metadata in SpatialData shapes element as column(s). Aligns metadata index to shape index. + """Add metadata columns to shapes data. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object + Input SpatialData object shape_key : str - Name of element in sdata.shapes - metadata : pd.Series, pd.DataFrame - Metadata to set for shape. Index must be a (sub)set of shape index. + Key for shapes in sdata.shapes + metadata : array-like + Data to add as new columns column_names : str or list of str, optional - Name of column(s) to set. If None, use metadata column name(s), by default None + Names for new columns. If None, use metadata column names + + Raises + ------ + ValueError + If shape_key not found """ if shape_key not in sdata.shapes.keys(): raise ValueError(f"Shape {shape_key} not found in sdata.shapes") @@ -320,21 +353,18 @@ def set_shape_metadata( # sdata.shapes[shape_key].loc[:, metadata.columns] = metadata.reindex(shape_index) -def _sync_points(sdata, points_key): - """ - Check if points are synced to instance_key shape in a SpatialData object. +def _sync_points(sdata: SpatialData, points_key: str) -> None: + """Synchronize points with cell boundaries. + + Updates sdata.points[points_key] to only include points within cells. Parameters ---------- sdata : SpatialData - The SpatialData object to check. + Input SpatialData object points_key : str - The name of the points to check. + Key for points in sdata.points - Raises - ------ - ValueError - If the points are not synced to instance_key shape. """ points = sdata.points[points_key].compute() instance_key = get_instance_key(sdata) @@ -354,23 +384,20 @@ def _sync_points(sdata, points_key): sdata.points[points_key] = points_valid -def _sync_shapes(sdata, shape_key, instance_key): - """ - Check if a shape is synced to instance_key shape in a SpatialData object. +def _sync_shapes(sdata: SpatialData, shape_key: str, instance_key: str) -> None: + """Synchronize shapes with cell boundaries. + + Updates sdata.shapes[shape_key] to only include shapes within cells. Parameters ---------- sdata : SpatialData - The SpatialData object to check. + Input SpatialData object shape_key : str - The name of the shape to check. + Key for shapes to sync instance_key : str - The instance key of the shape to check. + Key for cell boundaries - Raises - ------ - ValueError - If the shape is not synced to instance_key shape. """ shapes = sdata.shapes[shape_key] instance_shapes = sdata.shapes[instance_key] @@ -388,19 +415,23 @@ def _sync_shapes(sdata, shape_key, instance_key): sdata.shapes[shape_key] = shapes_valid -def get_instance_key(sdata: SpatialData): - """ - Returns the instance key for the spatial data object. +def get_instance_key(sdata: SpatialData) -> str: + """Get key for cell boundaries. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. + Input SpatialData object Returns ------- - instance_key : str - Key for the shape that will be used as the instance for all indexing. 
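And the setter side, following the `set_shape_metadata` signature in this hunk; `cell_types` is a hypothetical `pd.Series` aligned to the cell-boundary index:

```python
import bento as bt

bt.ut.set_shape_metadata(
    sdata,
    shape_key="cell_boundaries",
    metadata=cell_types,       # hypothetical per-cell labels
    column_names="cell_type",
)
```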
Usually the cell shape. + str + Key for cell boundaries in sdata.shapes + + Raises + ------ + KeyError + If instance key attribute not found """ try: return sdata.points["transcripts"].attrs["spatialdata_attrs"]["instance_key"] @@ -410,19 +441,23 @@ def get_instance_key(sdata: SpatialData): ) -def get_feature_key(sdata: SpatialData): - """ - Returns the feature key for the spatial data object. +def get_feature_key(sdata: SpatialData) -> str: + """Get key for gene identifiers. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. + Input SpatialData object Returns ------- - feature_key : str - Key for the feature name in the points DataFrame + str + Column name containing gene identifiers + + Raises + ------ + KeyError + If feature key attribute not found """ try: return sdata.points["transcripts"].attrs["spatialdata_attrs"]["feature_key"] diff --git a/bento/_version.py b/bento/_version.py new file mode 100644 index 0000000..5baf2b2 --- /dev/null +++ b/bento/_version.py @@ -0,0 +1,15 @@ +import tomli + +def get_version(): + import os + import pathlib + + package_dir = pathlib.Path(__file__).parent.parent + pyproject_path = os.path.join(package_dir, "pyproject.toml") + + with open(pyproject_path, "rb") as f: + pyproject = tomli.load(f) + return pyproject["project"]["version"] + +__version__ = get_version() + diff --git a/bento/geometry/__init__.py b/bento/geometry/__init__.py index a9e6098..430053a 100644 --- a/bento/geometry/__init__.py +++ b/bento/geometry/__init__.py @@ -1 +1 @@ -from ._ops import overlay, labels_to_shapes \ No newline at end of file +from ._ops import overlay \ No newline at end of file diff --git a/bento/geometry/_ops.py b/bento/geometry/_ops.py index 6eec55c..247e0a1 100644 --- a/bento/geometry/_ops.py +++ b/bento/geometry/_ops.py @@ -20,33 +20,38 @@ def overlay( name: str, how: str = "intersection", make_valid: bool = True, + instance_map_type: str = "1to1", ): """Overlay two shape elements in a SpatialData object and store the result as a new shape element. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object + SpatialData object containing the shape elements s1 : str Name of the first shape element s2 : str Name of the second shape element - how : str + name : str + Name of the new shape element to be created + how : str, optional Type of overlay operation to perform. Options are "intersection", "union", "difference", "symmetric_difference", by default "intersection" - make_valid : bool + make_valid : bool, optional If True, correct invalid geometries with GeoPandas, by default True - instance_key : str - Name of the shape element to use as the instance for indexing, by default "cell_boundaries". If None, no indexing is performed. + instance_map_type : str, optional + Type of instance mapping to use. Options are "1to1", "1tomany", by default "1to1". + Returns ------- - SpatialData - A new SpatialData object with the resulting shapes from the overlay operation. 
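One note on the new `bento/_version.py` above: it resolves `pyproject.toml` relative to the source tree, which only exists in editable/dev checkouts. A hedged alternative sketch that prefers installed metadata; the distribution name "bento-tools" is an assumption:

```python
from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("bento-tools")  # assumed distribution name
except PackageNotFoundError:
    # Source checkout: fall back to reading pyproject.toml directly
    import pathlib
    import tomli

    _pyproject = pathlib.Path(__file__).resolve().parent.parent / "pyproject.toml"
    with open(_pyproject, "rb") as f:
        __version__ = tomli.load(f)["project"]["version"]
```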
+ None + The function modifies the input SpatialData object in-place """ shape1 = sdata[s1] shape2 = sdata[s2] - new_shape = shape1.overlay(shape2, how=how, make_valid=make_valid) + new_shape = shape1.overlay(shape2, how=how, make_valid=make_valid)[["geometry"]] + new_shape.index = new_shape.index.astype(str) new_shape.attrs = {} transform = shape1.attrs @@ -58,66 +63,5 @@ def overlay( shape_keys=[name], instance_key=get_instance_key(sdata), feature_key=get_feature_key(sdata), + instance_map_type=instance_map_type, ) - - -@singledispatch -def labels_to_shapes(labels: np.ndarray, attrs: dict, bg_value: int = 0): - """ - Given a labeled 2D image, convert encoded pixels as Polygons and return a SpatialData verified GeoPandas DataFrame. - - Parameters - ---------- - labels : np.ndarray - Labeled 2D image where each pixel is encoded with an integer value. - attrs : dict - Dictionary of attributes to set for the SpatialData object. - bg_value : int, optional - Value of the background pixels, by default 0 - - Returns - ------- - GeoPandas DataFrame - GeoPandas DataFrame containing the polygons extracted from the labeled image. - - """ - import rasterio as rio - import shapely.geometry - - # Extract polygons from labeled image - contours = rio.features.shapes(labels) - polygons = np.array([(shapely.geometry.shape(p), v) for p, v in contours]) - shapes = gpd.GeoDataFrame( - polygons[:, 1], geometry=gpd.GeoSeries(polygons[:, 0]).T, columns=["id"] - ) - shapes = shapes[shapes["id"] != bg_value] # Ignore background - - # Validate for SpatialData - sd_shape = ShapesModel.parse(shapes) - sd_shape.attrs = attrs - return sd_shape - - -@labels_to_shapes.register(SpatialImage) -def _(labels: SpatialImage, attrs: dict, bg_value: int = 0): - """ - Given a labeled 2D image, convert encoded pixels as Polygons and return a SpatialData verified GeoPandas DataFrame. - - Parameters - ---------- - labels : SpatialImage - Labeled 2D image where each pixel is encoded with an integer value. - attrs : dict - Dictionary of attributes to set for the SpatialData object. - bg_value : int, optional - Value of the background pixels, by default 0 - - Returns - ------- - GeoPandas DataFrame - GeoPandas DataFrame containing the polygons extracted from the labeled image. - """ - - # Convert spatial_image.SpatialImage to np.ndarray - labels = labels.values - return labels_to_shapes(labels, attrs, bg_value) diff --git a/bento/io/_index.py b/bento/io/_index.py index d7b78da..7044ae2 100644 --- a/bento/io/_index.py +++ b/bento/io/_index.py @@ -1,9 +1,9 @@ -from typing import List +from typing import List, Union import pandas as pd import geopandas as gpd from spatialdata._core.spatialdata import SpatialData - +from spatialdata.models import ShapesModel from .._utils import ( get_points, set_points_metadata, @@ -16,21 +16,21 @@ def _sjoin_points( points_key: str, shape_keys: List[str], ): - """Index points to shapes and add as columns to `data.points[points_key]`. Only supports 2D points for now. + """Index points to shapes and add as columns to `sdata.points[points_key]`. Only supports 2D points for now. 
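Usage sketch for the updated `overlay()` above; the `bt.geo` accessor comes from the package `__init__`, and the shape names are illustrative:

```python
import bento as bt

bt.geo.overlay(
    sdata,
    s1="cell_boundaries",
    s2="nucleus_boundaries",
    name="cytoplasm",           # new shape element written back to sdata.shapes
    how="difference",           # cell minus nucleus
    instance_map_type="1to1",   # each result shape maps to at most one cell
)
```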
     Parameters
     ----------
     sdata : SpatialData
-        Spatial formatted SpatialData object
+        SpatialData object
     points_key : str
         Key for points DataFrame in `sdata.points`
-    shape_keys : str, list
+    shape_keys : List[str]
         List of shape names to index points to

     Returns
     -------
-    sdata : SpatialData
-        .points[points_key]: Updated points DataFrame with string index for each shape
+    SpatialData
+        Updated SpatialData object with `sdata.points[points_key]` containing new columns for each shape index
     """

     if isinstance(shape_keys, str):
@@ -68,22 +68,30 @@ def _sjoin_points(

     return sdata


-def _sjoin_shapes(sdata: SpatialData, instance_key: str, shape_keys: List[str]):
+
+def _sjoin_shapes(
+    sdata: SpatialData,
+    instance_key: str,
+    shape_keys: List[str],
+    instance_map_type: Union[str, dict],
+):
     """Adds polygon indexes to sdata.shapes[instance_key][shape_key] for point feature analysis.

     Parameters
     ----------
     sdata : SpatialData
-        Spatially formatted SpatialData
+        SpatialData object
     instance_key : str
         Key for the shape that will be used as the instance for all indexing. Usually the cell shape.
-    shape_keys : str or list of str
+    shape_keys : List[str]
         Names of the shapes to add.
+    instance_map_type : str or dict
+        Type of instance mapping to use: "1to1" or "1tomany".

     Returns
     -------
-    sdata : SpatialData
-        .shapes[cell_shape_key][shape_key]
+    SpatialData
+        Updated SpatialData object with `sdata.shapes[instance_key]` containing new columns for each shape index
     """

     # Cast to list if not already
@@ -98,21 +106,40 @@ def _sjoin_shapes(sdata: SpatialData, instance_key: str, shape_keys: List[str]):
     if len(shape_keys) == 0:
         return sdata

-    parent_shape = gpd.GeoDataFrame(sdata.shapes[instance_key])
+    parent_shape = gpd.GeoDataFrame(geometry=sdata.shapes[instance_key].geometry)

     # sjoin shapes to instance_key shape
     for shape_key in shape_keys:
-        child_shape = sdata.shapes[shape_key]
+        child_shape = sdata.shapes[shape_key].copy()
+        child_attrs = child_shape.attrs
         # Hack for polygons that are 99% contained in parent shape or have shared boundaries
         child_shape = gpd.GeoDataFrame(geometry=child_shape.buffer(-10e-6))

         # Map child shape index to parent shape and process the result
+
+        if instance_map_type == "1tomany":
+            child_shape = (
+                child_shape.sjoin(
+                    parent_shape.reset_index(drop=True),
+                    how="left",
+                    predicate="covered_by",
+                )
+                .dissolve(by="index_right", observed=True, dropna=False)
+                .reset_index(drop=True)[["geometry"]]
+            )
+            child_shape.index = child_shape.index.astype(str)
+            child_shape = ShapesModel.parse(child_shape)
+            child_shape.attrs = child_attrs
+            sdata.shapes[shape_key] = child_shape
+
         parent_shape = (
             parent_shape.sjoin(child_shape, how="left", predicate="covers")
-            .reset_index()
-            .drop_duplicates(subset="index", keep="last")
+            .reset_index()  # ignore any user defined index name
+            .drop_duplicates(
+                subset="index", keep="last"
+            )  # Remove multiple child shapes mapped to same parent shape
             .set_index("index")
-            .assign(
+            .assign(  # can this just be fillna on index_right?
index_right=lambda df: df.loc[ ~df["index_right"].duplicated(keep="first"), "index_right" ] @@ -121,6 +148,13 @@ def _sjoin_shapes(sdata: SpatialData, instance_key: str, shape_keys: List[str]): ) .rename(columns={"index_right": shape_key}) ) + + # Add empty category to shape_key if not already present + if ( + parent_shape[shape_key].dtype == "category" + and "" not in parent_shape[shape_key].cat.categories + ): + parent_shape[shape_key] = parent_shape[shape_key].cat.add_categories([""]) parent_shape[shape_key] = parent_shape[shape_key].fillna("") # Save shape index as column in instance_key shape @@ -128,7 +162,7 @@ def _sjoin_shapes(sdata: SpatialData, instance_key: str, shape_keys: List[str]): sdata, shape_key=instance_key, metadata=parent_shape[shape_key] ) - # Add instance_key shape index to shape + # Add instance_key shape index to child shape instance_index = ( parent_shape.drop_duplicates(subset=shape_key) .reset_index() diff --git a/bento/io/_io.py b/bento/io/_io.py index bc5cd1a..ede400e 100644 --- a/bento/io/_io.py +++ b/bento/io/_io.py @@ -1,7 +1,8 @@ import warnings -from typing import List +from typing import List, Union import emoji +import spatialdata as sd from anndata.utils import make_index_unique from spatialdata import SpatialData from spatialdata.models import TableModel @@ -19,29 +20,38 @@ def prep( feature_key: str = "feature_name", instance_key: str = "cell_boundaries", shape_keys: List[str] = ["cell_boundaries", "nucleus_boundaries"], + instance_map_type: Union[dict, str] = "1to1", ) -> SpatialData: """Computes spatial indices for elements in SpatialData to enable usage of bento-tools. - Specifically, this function indexes points to shapes and joins shapes to the instance shape. It also computes a count table for the points. + This function indexes points to shapes, joins shapes to the instance shape, and computes a count table for the points. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object - points_key : str + SpatialData object + points_key : str, default "transcripts" Key for points DataFrame in `sdata.points` - feature_key : str + feature_key : str, default "feature_name" Key for the feature name in the points DataFrame - instance_key : str + instance_key : str, default "cell_boundaries" Key for the shape that will be used as the instance for all indexing. Usually the cell shape. - shape_keys : str, list + shape_keys : List[str], default ["cell_boundaries", "nucleus_boundaries"] List of shape names to index points to + instance_map_type : str, dict + Type of mapping to use for the instance shape. If "1to1", each instance shape will be mapped to a single shape at most. + If "1tomany", each instance shape will be mapped to one or more shapes; + multiple shapes mapped to the same instance shape will be merged into a single MultiPolygon. + Use a dict to specify different mapping types for each shape. 
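A minimal call sketch for `prep()` with the new parameter, using the documented defaults; the `bt.io` accessor is an assumption:

```python
import bento as bt

sdata = bt.io.prep(
    sdata,
    points_key="transcripts",
    feature_key="feature_name",
    instance_key="cell_boundaries",
    shape_keys=["cell_boundaries", "nucleus_boundaries"],
    instance_map_type="1tomany",  # merge fragmented nuclei into one MultiPolygon per cell
)
```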
     Returns
     -------
     SpatialData
-        .shapes[shape_key]: Updated shapes GeoDataFrame with string index
-        .points[points_key]: Updated points DataFrame with string index for each shape
+        Updated SpatialData object with:
+        - Updated shapes in `sdata.shapes[shape_key]` with string index
+        - Updated points in `sdata.points[points_key]` with string index for each shape
+        - New count table in `sdata.tables["table"]`
+        - Updated attributes for instance_key and feature_key
     """

     # Renames geometry column of shape element to match shape name
@@ -51,6 +61,24 @@ def prep(
         shape_gdf[shape_key] = shape_gdf["geometry"]
         shape_gdf.index = make_index_unique(shape_gdf.index.astype(str))

+    transform = {
+        "global": sd.transformations.get_transformation(sdata.points[points_key])
+    }
+    if "global" in sdata.points[points_key].attrs["transform"]:
+        # Force points to 2D for Xenium data
+        if isinstance(transform["global"], sd.transformations.Scale):
+            transform = {
+                "global": sd.transformations.Scale(
+                    scale=transform["global"].to_scale_vector(["x", "y"]), axes=["x", "y"]
+                )
+            }
+    sdata.points[points_key] = sd.models.PointsModel.parse(
+        sdata.points[points_key].compute().reset_index(drop=True),
+        coordinates={"x": "x", "y": "y"},
+        feature_key=feature_key,
+        transformations=transform,
+    )
+
     # sindex points and sjoin shapes if they have not been indexed or joined
     point_sjoin = []
     shape_sjoin = []
@@ -72,6 +100,19 @@ def prep(
     sdata.points[points_key].attrs["spatialdata_attrs"]["instance_key"] = instance_key

     pbar = tqdm(total=3)
+    if len(shape_sjoin) > 0:
+        pbar.set_description(
+            "Mapping shapes"
+        )  # Map shapes must happen first; 1tomany mapping resets shape index
+        sdata = _sjoin_shapes(
+            sdata=sdata,
+            instance_key=instance_key,
+            shape_keys=shape_sjoin,
+            instance_map_type=instance_map_type,
+        )
+
+        pbar.update()
+
     if len(point_sjoin) > 0:
         pbar.set_description("Mapping points")
         sdata = _sjoin_points(
@@ -82,14 +123,6 @@ def prep(

         pbar.update()

-    if len(shape_sjoin) > 0:
-        pbar.set_description("Mapping shapes")
-        sdata = _sjoin_shapes(
-            sdata=sdata, instance_key=instance_key, shape_keys=shape_sjoin
-        )
-
-        pbar.update()
-
     # Only keep points within instance_key shape
     _sync_points(sdata, points_key)

diff --git a/bento/tools/_colocation.py b/bento/tools/_colocation.py
index a614ca6..16b27fd 100644
--- a/bento/tools/_colocation.py
+++ b/bento/tools/_colocation.py
@@ -29,23 +29,24 @@ def colocation(
     Parameters
     ----------
     sdata : SpatialData
-        Spatial formatted SpatialData object.
+        SpatialData object.
-    ranks : list
+    ranks : List[int]
         List of ranks to decompose the tensor.
-    instance_key : str
+    instance_key : str, default "cell_boundaries"
         Key that specifies cell_boundaries instance in sdata.
-    feature_key : str
+    feature_key : str, default "feature_name"
         Key that specifies genes in sdata.
-    iterations : int
+    iterations : int, default 3
         Number of iterations to run the decomposition.
-    plot_error : bool
+    plot_error : bool, default True
         Whether to plot the error of the decomposition.

     Returns
     -------
-    sdata : SpatialData
-        .tables["table"].uns['factors']: Decomposed tensor factors.
-        .tables["table"].uns['factors_error']: Decomposition error.
+    SpatialData
+        Updated SpatialData object with:
+        - .tables["table"].uns['factors']: Decomposed tensor factors.
+        - .tables["table"].uns['factors_error']: Decomposition error.
     """

     print("Preparing tensor...")
@@ -76,7 +77,7 @@ def _colocation_tensor(sdata: SpatialData, instance_key: str, feature_key: str):

     Parameters
     ----------
     sdata : SpatialData
-        Spatial formatted SpatialData object.
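Sketch of the decomposition entry point documented above; the `bt.tl` accessor and the choice of ranks are illustrative:

```python
import bento as bt

bt.tl.colocation(sdata, ranks=list(range(2, 6)), iterations=3)
factors = sdata.tables["table"].uns["factors"]       # per-rank factor matrices
errors = sdata.tables["table"].uns["factors_error"]  # decomposition error
```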
+ SpatialData object. instance_key : str Key that specifies cell_boundaries instance in sdata. feature_key : str @@ -122,35 +123,36 @@ def coloc_quotient( radius: int = 20, min_points: int = 10, min_cells: int = 0, - num_workers=1, + num_workers: int = 1, ): """Calculate pairwise gene colocalization quotient in each cell. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. - points_key: str + SpatialData object. + points_key: str, default "transcripts" Key that specifies transcript points in sdata. - instance_key : str + instance_key : str, default "cell_boundaries" Key that specifies cell_boundaries instance in sdata. - feature_key : str + feature_key : str, default "feature_name" Key that specifies genes in sdata. - shapes : list + shapes : List[str], default ["cell_boundaries"] Specify which shapes to compute colocalization separately. - radius : int - Unit distance to count neighbors, default 20 - min_points : int - Minimum number of points for sample to be considered for colocalization, default 10 - min_cells : int - Minimum number of cells for gene to be considered for colocalization, default 0 - num_workers : int - Number of workers to use for parallel processing + radius : int, default 20 + Unit distance to count neighbors. + min_points : int, default 10 + Minimum number of points for sample to be considered for colocalization. + min_cells : int, default 0 + Minimum number of cells for gene to be considered for colocalization. + num_workers : int, default 1 + Number of workers to use for parallel processing. Returns ------- - sdata : SpatialData - .tables["table"].uns['clq']: Pairwise gene colocalization similarity within each cell formatted as a long dataframe. + SpatialData + Updated SpatialData object with: + - .tables["table"].uns['clq']: Pairwise gene colocalization similarity within each cell formatted as a long dataframe. """ all_clq = dict() diff --git a/bento/tools/_composition.py b/bento/tools/_composition.py index 9639cf8..2c25116 100644 --- a/bento/tools/_composition.py +++ b/bento/tools/_composition.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +from typing import List from scipy.stats import wasserstein_distance from sklearn.metrics.pairwise import paired_distances from spatialdata._core.spatialdata import SpatialData @@ -8,26 +9,26 @@ def _get_compositions( - points: pd.DataFrame, shape_names: list, instance_key: str, feature_key: str + points: pd.DataFrame, shape_names: List[str], instance_key: str, feature_key: str ) -> pd.DataFrame: """Compute the mean composition of each gene across shapes. Parameters ---------- - points : pandas.DataFrame + points : pd.DataFrame Points indexed to shape_names denoted by boolean columns. - shape_names : list of str + shape_names : List[str] Names of shapes to calculate compositions for. instance_key : str - Key for + Key for identifying unique instances (e.g., cells). + feature_key : str + Key for identifying features (e.g., genes). Returns ------- - comp_data : DataFrame - For every gene return the composition of each shape, mean log(counts) and cell fraction. - + pd.DataFrame + For every gene, returns the composition of each shape, mean log(counts) and cell fraction. 
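Companion sketch for `coloc_quotient()` above, using its documented defaults and result location:

```python
import bento as bt

bt.tl.coloc_quotient(sdata, radius=20, min_points=10, num_workers=4)
clq = sdata.tables["table"].uns["clq"]  # long-format pairwise CLQ per cell
```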
""" - n_cells = points[instance_key].nunique() points_grouped = points.groupby([instance_key, feature_key], observed=True) counts = points_grouped.apply(lambda x: (x[shape_names] != "").sum()) @@ -58,20 +59,23 @@ def _get_compositions( return comp_stats -def comp(sdata: SpatialData, points_key: str, shape_names: list): +def comp(sdata: SpatialData, points_key: str, shape_names: List[str]) -> SpatialData: """Calculate the average gene composition for shapes across all cells. Parameters ---------- - sdata : spatialdata.SpatialData - Spatial formatted SpatialData object. - shape_names : list of str + sdata : SpatialData + SpatialData object. + points_key : str + Key for points DataFrame in `sdata.points`. + shape_names : List[str] Names of shapes to calculate compositions for. Returns ------- - sdata : spatialdata.SpatialData - Updates `sdata.tables["table"].uns` with average gene compositions for each shape. + SpatialData + Updated SpatialData object with average gene compositions for each shape + stored in `sdata.tables["table"].uns["comp_stats"]`. """ points = get_points(sdata, points_key=points_key, astype="pandas") @@ -87,21 +91,30 @@ def comp(sdata: SpatialData, points_key: str, shape_names: list): def comp_diff( - sdata: SpatialData, points_key: str, shape_names: list, groupby: str, ref_group: str -): - """Calculate the average difference in gene composition for shapes across batches of cells. Uses the Wasserstein distance. + sdata: SpatialData, points_key: str, shape_names: List[str], groupby: str, ref_group: str +) -> SpatialData: + """Calculate the average difference in gene composition for shapes across batches of cells. + + Uses the Wasserstein distance to compare compositions. Parameters ---------- - sdata : spatialdata.SpatialData - Spatial formatted SpatialData object. - shape_names : list of str + sdata : SpatialData + SpatialData object. + points_key : str + Key for points DataFrame in `sdata.points`. + shape_names : List[str] Names of shapes to calculate compositions for. groupby : str - Key in `sdata.points['transcripts]` to group cells by. + Key in `sdata.points[points_key]` to group cells by. ref_group : str Reference group to compare other groups to. + Returns + ------- + SpatialData + Updated SpatialData object with composition statistics for each group + stored in `sdata.tables["table"].uns["{groupby}_comp_stats"]`. """ points = get_points(sdata, points_key=points_key, astype="pandas") diff --git a/bento/tools/_decomposition.py b/bento/tools/_decomposition.py index fa0a6b5..fac0fdf 100644 --- a/bento/tools/_decomposition.py +++ b/bento/tools/_decomposition.py @@ -19,27 +19,27 @@ def decompose( random_state: int = 11, ): """ - Perform tensor decomposition on an input tensor, optionally automatically selecting the best rank across a list of ranks. + Perform tensor decomposition on an input tensor using non-negative PARAFAC. Parameters ---------- tensor : np.ndarray - numpy array - ranks : int or list of int - Rank(s) to perform decomposition. - iterations : int, 3 by default - Number of times to run decomposition to compute confidence interval at each rank. Only the best iteration for each rank is saved. - device : str, optional - Device to use for decomposition. If "auto", will use GPU if available. By default "auto". - random_state : int, optional - Random state for decomposition. By default 11. + Input tensor for decomposition. + ranks : List[int] + List of ranks to perform decomposition for. 
+ iterations : int, default 3 + Number of times to run decomposition for each rank to compute confidence interval. The best iteration for each rank is saved. + device : Literal["auto", "cpu", "cuda"], default "auto" + Device to use for decomposition. If "auto", will use GPU if available. + random_state : int, default 11 + Random state for reproducibility. Returns ------- - factors_per_rank : dict - Dictionary of factors for each rank. - errors : pd.DataFrame - Dataframe of errors for each rank. + Tuple[Dict[int, List[np.ndarray]], pd.DataFrame] + A tuple containing: + - factors_per_rank: Dictionary mapping each rank to a list of factor matrices. + - errors: DataFrame with columns 'rmse' and 'rank', containing the error for each rank. """ # Replace nans with 0 for decomposition tensor_mask = ~np.isnan(tensor) @@ -109,5 +109,20 @@ def decompose( return factors_per_rank, errors -def rmse(tensor, tensor_mu): +def rmse(tensor: np.ndarray, tensor_mu: np.ndarray) -> float: + """ + Calculate the Root Mean Square Error (RMSE) between two tensors, ignoring zero values. + + Parameters + ---------- + tensor : np.ndarray + Original tensor. + tensor_mu : np.ndarray + Reconstructed tensor. + + Returns + ------- + float + RMSE between the non-zero elements of the original and reconstructed tensors. + """ return np.sqrt((tensor[tensor != 0] - tensor_mu[tensor != 0]) ** 2).mean() diff --git a/bento/tools/_flux.py b/bento/tools/_flux.py index 74b8ab4..1c0408d 100644 --- a/bento/tools/_flux.py +++ b/bento/tools/_flux.py @@ -1,4 +1,4 @@ -from typing import Iterable, Literal, Optional, Union +from typing import Iterable, Literal, Optional, Union, List import dask import dask.delayed @@ -46,44 +46,50 @@ def flux( train_size: Optional[float] = 1, random_state: int = 11, recompute: bool = False, - num_workers=1, -): + num_workers: int = 1, +) -> SpatialData: """ Compute RNAflux embeddings of each pixel as local composition normalized by cell composition. - For k-nearest neighborhoods or "knn", method, specify n_neighbors. For radius neighborhoods, specify radius. - The default method is "radius" with radius = 1/3 of cell radius. RNAflux requires a minimum of 4 genes per cell to compute all embeddings properly. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. - points_key : str - key for points element that holds transcript coordinates - instance_key : str - Key for cell_boundaries instances - feature_key : str - Key for gene instances - method: str - Method to use for local neighborhood. Either 'knn' or 'radius'. - n_neighbors : int - Number of neighbors to use for local neighborhood. - radius : float - Fraction of mean cell radius to use for local neighborhood. - res : float + SpatialData object. + points_key : str, default "transcripts" + Key for points element that holds transcript coordinates. + instance_key : str, default "cell_boundaries" + Key for cell_boundaries instances. + feature_key : str, default "feature_name" + Key for gene instances. + method : Literal["knn", "radius"], default "radius" + Method to use for local neighborhood. + n_neighbors : Optional[int], default None + Number of neighbors to use for local neighborhood if method is "knn". + radius : Optional[float], default None + Fraction of mean cell radius to use for local neighborhood if method is "radius". + If None, defaults to 1/3 of average cell radius. + res : Optional[float], default 1 Resolution to use for rendering embedding. 
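One nit on the `rmse` helper documented above: `np.sqrt((a - b) ** 2).mean()` takes the root elementwise before averaging, which yields a mean absolute error over non-zero entries rather than an RMSE. A corrected sketch matching the new docstring:

```python
import numpy as np

def rmse(tensor: np.ndarray, tensor_mu: np.ndarray) -> float:
    """RMSE over non-zero entries: root taken after the mean."""
    mask = tensor != 0
    return float(np.sqrt(((tensor[mask] - tensor_mu[mask]) ** 2).mean()))
```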
+ train_size : Optional[float], default 1 + Fraction of data to use for training. + random_state : int, default 11 + Random state for reproducibility. + recompute : bool, default False + If True, recompute flux even if it already exists. + num_workers : int, default 1 + Number of workers to use for parallel processing. Returns ------- - sdata : SpatialData - .points["{instance_key}_raster"]: pd.DataFrame - Length pixels DataFrame containing all computed flux values, embeddings, and colors as columns in a single DataFrame. - flux values: for each gene used in embedding. - embeddings: flux_embed_ for each component of the embedding. - colors: hex color codes for each pixel. - .tables["table"].uns["flux_genes"] : list - List of genes used for embedding. - .tables["table"].uns["flux_variance_ratio"] : np.ndarray - [components] array of explained variance ratio for each component. + SpatialData + Updated SpatialData object with: + - .points["{instance_key}_raster"]: pd.DataFrame containing flux values, embeddings, and colors. + - .tables["table"].uns["flux_genes"]: List of genes used for embedding. + - .tables["table"].uns["flux_variance_ratio"]: Array of explained variance ratio for each component. + + Notes + ----- + RNAflux requires a minimum of 4 genes per cell to compute all embeddings properly. """ if ( @@ -285,8 +291,28 @@ def vec2color( ] = "hex", vmin: float = 0, vmax: float = 1, -): - """Convert vector to color.""" +) -> Union[np.ndarray, List[str]]: + """ + Convert vector to color. + + Parameters + ---------- + vec : np.ndarray + Input vector to convert to color. + alpha_vec : Optional[np.ndarray], default None + Vector of alpha values. + fmt : Literal["rgb", "hex"], default "hex" + Output format for colors. + vmin : float, default 0 + Minimum value for color scaling. + vmax : float, default 1 + Maximum value for color scaling. + + Returns + ------- + Union[np.ndarray, List[str]] + Array of RGB values or list of hex color codes. + """ # Grab the first 3 channels color = vec[:, :3] @@ -330,37 +356,39 @@ def fluxmap( random_state: int = 11, plot_error: bool = False, ): - """Cluster flux embeddings using self-organizing maps (SOMs) and vectorize clusters as Polygon shapes. + """ + Cluster flux embeddings using self-organizing maps (SOMs) and vectorize clusters as Polygon shapes. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. - points_key : str - key for points element that holds transcript coordinates - instance_key : str - Key for cell_boundaries instances - n_clusters : int or list - Number of clusters to use. If list, will pick best number of clusters + SpatialData object. + points_key : str, default "transcripts" + Key for points element that holds transcript coordinates. + instance_key : str, default "cell_boundaries" + Key for cell_boundaries instances. + n_clusters : Union[Iterable[int], int], default range(2, 9) + Number of clusters to use. If iterable, will pick best number of clusters using the elbow heuristic evaluated on the quantization error. - num_iterations : int + num_iterations : int, default 1000 Number of iterations to use for SOM training. - train_size : float - Fraction of cells to use for SOM training. Default 0.2. - res : float - Resolution used for rendering embedding. Default 0.05. - random_state : int - Random state to use for SOM training. Default 11. - plot_error : bool - Whether to plot quantization error. Default True. + min_count : int, default 50 + Minimum count for a point to be included in clustering. 
+ train_size : float, default 1 + Fraction of cells to use for SOM training. + res : float, default 1 + Resolution used for rendering embedding. + random_state : int, default 11 + Random state to use for SOM training. + plot_error : bool, default False + Whether to plot quantization error. Returns ------- - sdata : SpatialData - .points["points"] : DataFrame - Adds "fluxmap" column denoting cluster membership. - .shapes["fluxmap#"] : GeoSeries - Adds "fluxmap#" columns for each cluster rendered as (Multi)Polygon shapes. + SpatialData + Updated SpatialData object with: + - .points[f"{instance_key}_raster"]: Added "fluxmap" column denoting cluster membership. + - .shapes["fluxmap#"]: Added "fluxmap#" columns for each cluster rendered as (Multi)Polygon shapes. """ raster_points = get_points( @@ -562,7 +590,7 @@ def fluxmap( sdata.points[points_key] = sdata.points[points_key].drop(old_cols, axis=1) _sjoin_points(sdata=sdata, shape_keys=fluxmap_names, points_key=points_key) - _sjoin_shapes(sdata=sdata, instance_key=instance_key, shape_keys=fluxmap_names) + _sjoin_shapes(sdata=sdata, instance_key=instance_key, shape_keys=fluxmap_names, instance_map_type="1to1") pbar.update() pbar.set_description("Done") diff --git a/bento/tools/_flux_enrichment.py b/bento/tools/_flux_enrichment.py index bd8ba6d..1c63e6e 100644 --- a/bento/tools/_flux_enrichment.py +++ b/bento/tools/_flux_enrichment.py @@ -12,17 +12,18 @@ def fe_fazal2019(sdata: SpatialData, **kwargs): """Compute enrichment scores from subcellular compartment gene sets from Fazal et al. 2019 (APEX-seq). - See `bento.tl.fe` docs for parameter details. Parameters ---------- - data : SpatialData - Spatial formatted SpatialData object. + sdata : SpatialData + SpatialData object. + **kwargs + Additional keyword arguments passed to `bento.tl.fe()` function. Returns ------- - DataFrame - Enrichment scores for each gene set. + SpatialData + Updated SpatialData object with enrichment scores added to points metadata. """ gene_sets = load_gene_sets("fazal2019") @@ -31,17 +32,18 @@ def fe_fazal2019(sdata: SpatialData, **kwargs): def fe_xia2019(sdata: SpatialData, **kwargs): """Compute enrichment scores from subcellular compartment gene sets from Xia et al. 2019 (MERFISH 10k U2-OS). - See `bento.tl.fe` docs for parameters details. Parameters ---------- - data : SpatialData - Spatial formatted SpatialData object. + sdata : SpatialData + SpatialData object. + **kwargs + Additional keyword arguments passed to `bento.tl.fe()` function. Returns ------- - DataFrame - Enrichment scores for each gene set. + SpatialData + Updated SpatialData object with enrichment scores added to points metadata. """ gene_sets = load_gene_sets("xia2019") @@ -51,38 +53,41 @@ def fe_xia2019(sdata: SpatialData, **kwargs): def fe( sdata: SpatialData, net: pd.DataFrame, - instance_key: Optional[str] = "cell_boundaries", - source: Optional[str] = "source", - target: Optional[str] = "target", - weight: Optional[str] = "weight", + instance_key: str = "cell_boundaries", + source: str = "source", + target: str = "target", + weight: str = "weight", batch_size: int = 10000, min_n: int = 0, -): +) -> SpatialData: """ - Perform functional enrichment of RNAflux embeddings. Uses decoupler wsum function. + Perform functional enrichment of RNAflux embeddings using decoupler's wsum function. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. - net : DataFrame - DataFrame with columns "source", "target", and "weight". See decoupler API for more details. 
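End-to-end sketch of the RNAflux workflow these docstrings describe; the `bt.tl` accessor is an assumption:

```python
import bento as bt

bt.tl.flux(sdata, points_key="transcripts", instance_key="cell_boundaries")
bt.tl.fluxmap(sdata, n_clusters=range(2, 9))  # SOM clustering of flux embeddings
bt.tl.fe_fazal2019(sdata)                     # APEX-seq compartment enrichment
```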
- source : str, optional - Column name for source nodes in `net`. Default "source". - target : str, optional - Column name for target nodes in `net`. Default "target". - weight : str, optional - Column name for weights in `net`. Default "weight". - batch_size : int - Number of points to process in each batch. Default 10000. - min_n : int + SpatialData object. + net : pd.DataFrame + DataFrame with columns for source, target, and weight. See decoupler API for more details. + instance_key : str, default "cell_boundaries" + Key for the instance in sdata.points. + source : str, default "source" + Column name for source nodes in `net`. + target : str, default "target" + Column name for target nodes in `net`. + weight : str, default "weight" + Column name for weights in `net`. + batch_size : int, default 10000 + Number of points to process in each batch. + min_n : int, default 0 Minimum number of targets per source. If less, sources are removed. Returns ------- - sdata : SpatialData - .points["cell_boundaries_raster"]["flux_fe"] : DataFrame - Enrichment scores for each gene set. + SpatialData + Updated SpatialData object with: + - Enrichment scores added to `sdata.points[f"{instance_key}_raster"]` + - Enrichment statistics added to `sdata.tables["table"].uns["fe_stats"]` and `sdata.tables["table"].uns["fe_ngenes"]` """ # Make sure embedding is run first if "flux_genes" in sdata.tables["table"].uns: @@ -160,18 +165,23 @@ def _fe_stats( ) -def load_gene_sets(name): - """Load a gene set; list available ones with `bento.tl.gene_sets`. +def load_gene_sets(name: str) -> pd.DataFrame: + """Load a gene set from the predefined collection. Parameters ---------- name : str - Name of gene set to load. + Name of gene set to load. Available options can be listed with `bento.tl.gene_sets`. Returns ------- - DataFrame - Gene set. + pd.DataFrame + Gene set as a DataFrame. + + Raises + ------ + KeyError + If the specified gene set name is not found in the collection. """ from importlib.resources import files, as_file diff --git a/bento/tools/_lp.py b/bento/tools/_lp.py index f9d594f..d285199 100644 --- a/bento/tools/_lp.py +++ b/bento/tools/_lp.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Union +from typing import List, Union +import os import pickle import warnings @@ -29,23 +30,32 @@ def lp( recompute=False, ): """Predict transcript subcellular localization patterns. - Patterns include: cell edge, cytoplasmic, nuclear edge, nuclear, none + + Predicts patterns including: cell edge, cytoplasmic, nuclear edge, nuclear, none. + Computes required features if they don't exist. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object - - groupby : str or list of str - Key in `sdata.points[points_key] to groupby, by default None. Always treats each cell separately + Input SpatialData object + instance_key : str, default "cell_boundaries" + Key for cell boundaries in sdata.shapes + nucleus_key : str, default "nucleus_boundaries" + Key for nucleus boundaries in sdata.shapes + groupby : str or list of str, default "feature_name" + Column(s) in sdata.points to group transcripts by + num_workers : int, default 1 + Number of parallel workers for feature computation + recompute : bool, default False + Whether to recompute existing features Returns ------- - sdata : SpatialData - .tables["table"].uns['lp'] - Localization pattern indicator matrix. - .tables["table"].uns['lpp'] - Localization pattern probabilities. 
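For a custom gene set, `fe()` takes a long-format net as documented above; the compartments and gene names below are made up, and `bt.tl.flux()` must have been run first:

```python
import pandas as pd
import bento as bt

net = pd.DataFrame({
    "source": ["nucleolus", "nucleolus", "er"],  # hypothetical gene sets
    "target": ["GENEA", "GENEB", "GENEC"],       # hypothetical genes
    "weight": [1.0, 1.0, 1.0],
})
bt.tl.fe(sdata, net=net, min_n=0)
```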
+ None + Modifies sdata.tables["table"].uns with: + - 'lp': DataFrame of binary pattern indicators + - 'lpp': DataFrame of pattern probabilities + Also computes pattern statistics via lp_stats() """ if isinstance(groupby, str): @@ -100,7 +110,7 @@ def lp( invalid_samples = X_df.isna().any(axis=1) # Load trained model - model_dir = "/".join(bento.__file__.split("/")[:-1]) + "/models" + model_dir = os.path.join(os.path.dirname(bento.__file__), "models") model = pickle.load(open(f"{model_dir}/rf_calib_20220514.pkl", "rb")) # Compatibility with newer versions of scikit-learn @@ -148,23 +158,23 @@ def lp( def lp_stats(sdata: SpatialData): - """Computes frequencies of localization patterns across cells and genes. + """Compute frequencies of localization patterns across cells and genes. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. - instance_key : str - cell boundaries instance key + Input SpatialData object with localization pattern results Returns ------- - sdata : SpatialData - .tables["table"].uns['lp_stats']: DataFrame of localization pattern frequencies. + None + Modifies sdata with: + - tables["table"].uns['lp_stats']: Pattern frequencies per group + - points["transcripts"]: Adds 'pattern' column with top pattern """ instance_key = get_instance_key(sdata) feature_key = get_feature_key(sdata) - lp = sdata.table.uns["lp"] + lp = sdata["table"].uns["lp"] cols = lp.columns groupby = list(cols[~cols.isin(PATTERN_NAMES)]) @@ -173,7 +183,7 @@ def lp_stats(sdata: SpatialData): g_pattern_counts = lp.groupby(groupby, observed=True).apply( lambda df: df[PATTERN_NAMES].sum().astype(int) ) - sdata.table.uns["lp_stats"] = g_pattern_counts + sdata["table"].uns["lp_stats"] = g_pattern_counts lpp = sdata["table"].uns["lpp"] top_pattern = lpp[[instance_key, feature_key]] @@ -191,22 +201,26 @@ def lp_stats(sdata: SpatialData): set_points_metadata(sdata, "transcripts", top_pattern_long, "pattern") -def _lp_logfc(sdata, instance_key, phenotype=None): - """Compute pairwise log2 fold change of patterns between groups in phenotype. +def _lp_logfc(sdata, instance_key, phenotype): + """Compute pairwise log2 fold change of patterns between phenotype groups. Parameters ---------- - data : SpatialData - Spatial formatted SpatialData object. - instance_key: str - cell boundaries instance key + sdata : SpatialData + Input SpatialData object containing localization pattern results + instance_key : str + Key for cell boundaries in sdata.shapes phenotype : str - Variable grouping cells for differential analysis. Must be in sdata.shapes["cell_boundaries"].columns. + Column in sdata.shapes[instance_key] containing group labels Returns ------- - gene_fc_stats : DataFrame - log2 fold change of patterns between groups in phenotype. + pd.DataFrame + Log2 fold changes between groups with columns: + - feature_name: Feature identifier + - log2fc: Log2 fold change between groups + - phenotype: Group identifier + - pattern: Pattern name """ stats = sdata.tables["table"].uns["lp_stats"] @@ -231,13 +245,19 @@ def _lp_logfc(sdata, instance_key, phenotype=None): ) def log2fc(group_col): - """ - Return - ------ - log2fc : int - log2fc of group_count / rest, pseudocount of 1 - group_count : int - rest_mean_count : int + """Calculate log2 fold change between one group and mean of other groups. 
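Usage sketch for `lp()` and where its outputs land, per the docstrings above:

```python
import bento as bt

bt.tl.lp(sdata, groupby="feature_name")
lp_binary = sdata.tables["table"].uns["lp"]  # pattern indicator matrix
lp_probs = sdata.tables["table"].uns["lpp"]  # pattern probabilities
```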
+ + Parameters + ---------- + group_col : pd.Series + Pattern frequencies for one phenotype group + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - log2fc: log2 fold change of group vs mean of other groups (with pseudocount of 1) + - phenotype: name of the group """ group_name = group_col.name rest_cols = group_freq.columns[group_freq.columns != group_name] @@ -267,21 +287,29 @@ def log2fc(group_col): def _lp_diff_gene(cell_by_pattern, phenotype_series, instance_key): - """Perform pairwise comparison between groupby and every class. + """Test differential pattern usage between phenotype groups using logistic regression. Parameters ---------- - cell_by_pattern : DataFrame - Cell by pattern matrix. - phenotype_series : Series - Series of cell groupings. + cell_by_pattern : pd.DataFrame + Binary matrix of cells x patterns (0/1 indicators) + phenotype_series : pd.Series + Cell phenotype labels indexed by instance_key instance_key : str - cell boundaries instance key + Key identifying cells in cell_by_pattern index Returns ------- - DataFrame - Differential localization test results. [# of patterns, ] + pd.DataFrame + Statistical test results with columns: + - pattern: Pattern name + - dy/dx: Marginal effect size + - std_err: Standard error + - z: Z-score statistic + - pvalue: Raw p-value + - ci_low: Lower confidence interval + - ci_high: Upper confidence interval + - phenotype: Group identifier """ cell_by_pattern = cell_by_pattern.dropna().reset_index(drop=True) @@ -349,29 +377,28 @@ def _lp_diff_gene(cell_by_pattern, phenotype_series, instance_key): def lp_diff_discrete( sdata: SpatialData, instance_key: str = "cell_boundaries", phenotype: str = None ): - """Gene-wise test for differential localization across phenotype of interest. + """Test for differential localization patterns between discrete phenotype groups. - Scenarios: - Missing patterns within phenotype groupings - Solution: - - Warn user about missing patterns - - Remove missing patterns from analysis - - Return results with missing patterns removed + Performs pairwise statistical testing between phenotype groups for each pattern + and gene combination. Missing patterns are excluded from analysis. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. - instance_key : str - cell boundaries instance key. + Input SpatialData object with localization pattern results + instance_key : str, default "cell_boundaries" + Key for cell boundaries in sdata.shapes phenotype : str - Variable grouping cells for differential analysis. Must be in sdata.shape["cell_boundaries].columns. + Column in sdata.shapes[instance_key] containing group labels Returns ------- - sdata : SpatialData - .tables["table"].uns['diff_{phenotype}'] - Long DataFrame with differential localization test results across phenotype groups. + None + Modifies sdata.tables["table"].uns[f'diff_{phenotype}'] with: + - Statistical results (p-values, z-scores) + - Effect sizes (dy/dx) + - Log2 fold changes between groups + - Multiple testing corrected p-values """ lp_df = sdata.tables["table"].uns["lp"] @@ -450,22 +477,24 @@ def lp_diff_discrete( def lp_diff_continuous( sdata: SpatialData, instance_key: str = "cell_boundaries", phenotype: str = None ): - """Gene-wise test for differential localization across phenotype of interest. + """Test correlation between localization patterns and continuous phenotype values. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData object. 
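A toy numeric check of the pseudocount-of-1 fold change described in the `log2fc` docstring above (values are made up):

```python
import numpy as np

group_count = np.array([10.0, 0.0, 5.0])  # pattern frequency in one group
rest_mean = np.array([2.0, 1.0, 5.0])     # mean frequency across other groups
log2fc = np.log2((group_count + 1) / (rest_mean + 1))
# -> array([ 1.874, -1.   ,  0.   ])
```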
- instance_key : str - cell boundaries instance key. + Input SpatialData object with localization pattern results + instance_key : str, default "cell_boundaries" + Key for cell boundaries in sdata.shapes phenotype : str - Variable grouping cells for differential analysis. Must be in sdata.shape["cell_boundaries].columns. + Column in sdata.shapes[instance_key] containing continuous values Returns ------- - sdata : SpatialData - .tables["table"].uns['diff_{phenotype}'] - Long DataFrame with differential localization test results across phenotype groups. + None + Modifies sdata.tables["table"].uns[f'diff_{phenotype}'] with: + - feature_name: Feature identifier + - pattern: Pattern name + - pearson_correlation: Correlation coefficient with phenotype """ stats = sdata.tables["table"].uns["lp_stats"] lpp = sdata.tables["table"].uns["lpp"] diff --git a/bento/tools/_neighborhoods.py b/bento/tools/_neighborhoods.py index 08d69a0..d85d98c 100644 --- a/bento/tools/_neighborhoods.py +++ b/bento/tools/_neighborhoods.py @@ -1,3 +1,4 @@ +from typing import Optional, Union, List import numpy as np import pandas as pd from scipy.sparse import csr_matrix @@ -5,29 +6,45 @@ def _count_neighbors( - points, n_genes, query_points=None, n_neighbors=None, radius=None, agg="feature_name" -): - """Build nearest neighbor index for points. + points: pd.DataFrame, + n_genes: int, + query_points: Optional[pd.DataFrame] = None, + n_neighbors: Optional[int] = None, + radius: Optional[float] = None, + agg: Optional[str] = "feature_name" +) -> Union[pd.DataFrame, csr_matrix]: + """Build nearest neighbor index and count neighbors for points. Parameters ---------- points : pd.DataFrame - Points dataframe. Must have columns "x", "y", and "gene". + Points dataframe containing columns "x", "y", and "feature_name" n_genes : int - Number of genes in overall dataset. Used to initialize unique gene counts. + Total number of unique genes in dataset query_points : pd.DataFrame, optional - Points to query. If None, use points_df. Default None. - n_neighbors : int - Number of nearest neighbors to consider per gene. - agg : "gene", "binary", None - Whether to aggregate nearest neighbors counts. "Gene" aggregates counts by gene, whereas "binary" counts neighbors only once per point. If None, return neighbor counts for each point. - Default "gene". + Points to query. If None, uses points dataframe + n_neighbors : int, optional + Number of nearest neighbors to find per point + radius : float, optional + Radius within which to find neighbors + agg : str, optional + How to aggregate neighbor counts: + - "feature_name": aggregate by gene + - "binary": count neighbors once per point + - None: return raw neighbor counts per point + Returns ------- - DataFrame or dict of dicts - If agg is True, returns a DataFrame with columns "gene", "neighbor", and "count". - If agg is False, returns a list of dicts, one for each point. Dict keys are gene names, values are counts. 
- + Union[pd.DataFrame, csr_matrix] + If agg="feature_name": + DataFrame with columns ["feature_name", "neighbor", "count"] + If agg="binary" or None: + Sparse matrix of shape (n_points, n_genes) containing neighbor counts + + Raises + ------ + ValueError + If neither n_neighbors nor radius is specified, or if both are specified """ if n_neighbors and radius: raise ValueError("Only specify one of n_neighbors or radius, not both.") diff --git a/bento/tools/_point_features.py b/bento/tools/_point_features.py index cd28f7d..629f6d3 100644 --- a/bento/tools/_point_features.py +++ b/bento/tools/_point_features.py @@ -9,7 +9,7 @@ import re from abc import ABCMeta, abstractmethod from math import isnan -from typing import List, Optional, Union +from typing import List, Optional, Union, Type, Dict import numpy as np import pandas as pd @@ -25,39 +25,50 @@ def analyze_points( sdata: SpatialData, - shape_keys: List[str], - feature_names: List[str], + shape_keys: Union[str, List[str]], + feature_names: Union[str, List[str]], points_key: str = "transcripts", instance_key: str = "cell_boundaries", groupby: Optional[Union[str, List[str]]] = None, - recompute=False, - progress=False, + recompute: bool = False, + progress: bool = False, num_workers: int = 1, -): - """Calculate features for each point group. Groups are always within each cell. +) -> None: + """Calculate features for point groups within cells. - When creating the points_df, it first grabs sdata.points[points_key] and joins shape polygons from sdata.shapes[shape_keys]. - The second join is to sdata.shapes[instance_key] to pull in cell polygons and cell features. - The shape indices in the points object are renamed to have _index as a suffix to avoid conflicts. - The joined polygons are named with it's respective shape_key. + Efficiently avoids recomputing cell-level features by compiling and computing + once the set of required cell-level features and attributes for each feature. Parameters ---------- sdata : SpatialData - Spatially formatted SpatialData + Input SpatialData object shape_keys : str or list of str - Names of the shapes to analyze. + Names of shapes to analyze in sdata.shapes feature_names : str or list of str - Names of the features to analyze. + Names of features to compute; list available features with `bt.tl.list_point_features()` + points_key : str, default "transcripts" + Key for points in sdata.points + instance_key : str, default "cell_boundaries" + Key for cell boundaries in sdata.shapes groupby : str or list of str, optional - Key(s) in `data.points['points'] to groupby, by default None. Always treats each cell separately. + Column(s) in sdata.points[points_key] to group by + recompute : bool, default False + Whether to force recomputation of features + progress : bool, default False + Whether to show progress bars + num_workers : int, default 1 + Number of parallel workers Returns ------- - sdata : spatialdata.SpatialData - table.uns["{instance_key}_{groupby}_features"] - See the output of each :class:`PointFeature` in `features` for keys added. + None + Updates sdata.tables["table"].uns["{instance_key}_{groupby}_features"] with computed features + Raises + ------ + KeyError + If required shape keys or groupby columns are not found """ # Cast to list if not already @@ -194,60 +205,76 @@ def process_partition(bag): class PointFeature(metaclass=ABCMeta): - """Abstract class for calculating sample features. A sample is defined as the set of - molecules corresponding to a single cell-gene pair. 
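Call sketch for `analyze_points()`; the feature names below are placeholders — list the registered ones with `bt.tl.list_point_features()` as the docstring suggests:

```python
import bento as bt

bt.tl.analyze_points(
    sdata,
    shape_keys=["cell_boundaries", "nucleus_boundaries"],
    feature_names=["proximity", "asymmetry"],  # hypothetical names
    groupby="feature_name",
    num_workers=2,
)
```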
+ """Base class for point feature calculations. + + Parameters + ---------- + instance_key : str + Key for cell boundaries in sdata.shapes + shape_key : str, optional + Key for shape to analyze relative to Attributes ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features. - attributes : int - Names (keys) used to store computed cell-level features. + cell_features : set + Required cell-level features + attributes : set + Required shape attributes """ - def __init__(self, instance_key, shape_key): + def __init__(self, instance_key: str, shape_key: Optional[str] = None): self.cell_features = set() self.attributes = set() self.instance_key = instance_key - + if shape_key: self.attributes.add(shape_key) self.shape_key = shape_key @abstractmethod - def extract(self, df): - """Calculates this feature for a given sample. + def extract(self, df: pd.DataFrame) -> Dict[str, float]: + """Calculate features for a group of points. Parameters ---------- - df : DataFrame - Assumes each row is a molecule and that columns `x`, `y`, `cell`, and `gene` are present. + df : pd.DataFrame + Points data with required columns + + Returns + ------- + Dict[str, float] + Computed feature values """ return df class ShapeProximity(PointFeature): - """For a set of points, computes the proximity of points within `shape_key` - as well as the proximity of points outside `shape_key`. Proximity is defined as - the average absolute distance to the specified `shape_key` normalized by cell - radius. Values closer to 0 denote farther from the `shape_key`, values closer - to 1 denote closer to the `shape_key`. + """Compute proximity of points relative to a shape boundary. + + Parameters + ---------- + instance_key : str + Key for cell boundaries in sdata.shapes + shape_key : str + Key for shape to analyze relative to Attributes ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - attributes : int - Names (keys) used to store computed cell-level features + cell_features : set + Required cell-level features + attributes : set + Required shape attributes Returns ------- - dict - `"{shape_key}_inner_proximity"`: proximity of points inside `shape_key` - `"{shape_key}_outer_proximity"`: proximity of points outside `shape_key` + Dict[str, float] + Dictionary containing: + - {shape_key}_inner_proximity: Proximity of points inside shape (0-1) + - {shape_key}_outer_proximity: Proximity of points outside shape (0-1) + Values closer to 0 indicate farther from boundary, closer to 1 indicate nearer """ - def __init__(self, instance_key, shape_key): + def __init__(self, instance_key: str, shape_key: str): super().__init__(instance_key, shape_key) self.cell_features.add("radius") self.attributes.add(f"{self.instance_key}_radius") @@ -311,29 +338,33 @@ def extract(self, df): class ShapeAsymmetry(PointFeature): - """For a set of points, computes the asymmetry of points within `shape_key` - as well as the asymmetry of points outside `shape_key`. Asymmetry is defined as - the offset between the centroid of points to the centroid of the specified - `shape_key`, normalized by cell radius. Values closer to 0 denote symmetry, - values closer to 1 denote asymmetry. + """Compute asymmetry of points relative to a shape centroid. 
- Attributes + Parameters ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - cell_attributes : int - Names (keys) used to store computed cell-level features + instance_key : str + Key for cell boundaries in sdata.shapes shape_key : str - Name of shape to use, must be column name in input DataFrame + Key for shape to analyze relative to + + + Attributes + ---------- + cell_features : set + Required cell-level features + attributes : set + Required shape attributes Returns ------- - dict - `"{shape_key}_inner_asymmetry"`: asymmetry of points inside `shape_key` - `"{shape_key}_outer_asymmetry"`: asymmetry of points outside `shape_key` + Dict[str, float] + Dictionary containing: + - {shape_key}_inner_asymmetry: Asymmetry of points inside shape (0-1) + - {shape_key}_outer_asymmetry: Asymmetry of points outside shape (0-1) + Values closer to 0 indicate symmetry, closer to 1 indicate asymmetry """ - def __init__(self, instance_key, shape_key): + def __init__(self, instance_key: str, shape_key: str): super().__init__(instance_key, shape_key) self.cell_features.add("radius") self.attributes.add(f"{self.instance_key}_radius") @@ -397,24 +428,31 @@ def extract(self, df): class PointDispersionNorm(PointFeature): - """For a set of points, calculates the second moment of all points in a cell - relative to the centroid of the total RNA signal. This value is normalized by - the second moment of a uniform distribution within the cell boundary. + """Compute normalized dispersion of points relative to RNA signal centroid. + + Parameters + ---------- + instance_key : str + Key for cell boundaries in sdata.shapes + shape_key : str + Key for shape to analyze relative to Attributes ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - cell_attributes : int - Names (keys) used to store computed cell-level features + cell_features : set + Required cell-level features + attributes : set + Required shape attributes Returns ------- - dict - `"point_dispersion"`: measure of point dispersion + Dict[str, float] + Dictionary containing: + - point_dispersion_norm: Second moment of points normalized by + second moment of uniform distribution """ - def __init__(self, instance_key, shape_key): + def __init__(self, instance_key: str, shape_key: str): super().__init__(instance_key, shape_key) self.cell_features.add("raster") self.attributes.add(f"{self.instance_key}_raster") @@ -441,26 +479,32 @@ def extract(self, df): class ShapeDispersionNorm(PointFeature): - """For a set of points, calculates the second moment of all points in a cell relative to the - centroid of `shape_key`. This value is normalized by the second moment of a uniform - distribution within the cell boundary. + """Compute normalized dispersion of points relative to a shape centroid. 
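+
+    Calculated as the second moment of all points in a cell relative to the
+    centroid of the given shape, normalized by the second moment of a uniform
+    distribution within the cell boundary.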
+ + Parameters + ---------- + instance_key : str + Key for cell boundaries in sdata.shapes + shape_key : str + Key for shape to analyze relative to Attributes ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - cell_attributes : int - Names (keys) used to store computed cell-level features + cell_features : set + Required cell-level features + attributes : set + Required shape attributes Returns ------- - dict - `"{shape_key}_dispersion"`: measure of point dispersion relative to `shape_key` + Dict[str, float] + Dictionary containing: + - {shape_key}_dispersion_norm: Second moment of points normalized by + second moment of uniform distribution """ - def __init__(self, instance_key, shape_key): + def __init__(self, instance_key: str, shape_key: str): super().__init__(instance_key, shape_key) - self.cell_features.add("raster") self.attributes.add(f"{self.instance_key}_raster") @@ -494,25 +538,24 @@ def extract(self, df): class ShapeDistance(PointFeature): - """For a set of points, computes the distance of points within `shape_key` - as well as the distance of points outside `shape_key`. + """Compute absolute distances between points and a shape boundary. - Attributes + Parameters ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - attributes : int - Names (keys) used to store computed cell-level features + instance_key : str + Key for cell boundaries in sdata.shapes + shape_key : str + Key for shape to analyze relative to Returns ------- - dict - `"{shape_key}_inner_distance"`: distance of points inside `shape_key` - `"{shape_key}_outer_distance"`: distance of points outside `shape_key` + Dict[str, float] + Dictionary containing: + - {shape_key}_inner_distance: Mean distance of points inside shape to boundary + - {shape_key}_outer_distance: Mean distance of points outside shape to boundary """ - # Cell-level features needed for computing sample-level features - def __init__(self, instance_key, shape_key): + def __init__(self, instance_key: str, shape_key: str): super().__init__(instance_key, shape_key) def extract(self, df): @@ -564,28 +607,24 @@ def extract(self, df): class ShapeOffset(PointFeature): - """For a set of points, computes the offset of points within `shape_key` - as well as the offset of points outside `shape_key`. Offset is defined as - the offset between the centroid of points to the centroid of the specified - `shape_key`. + """Compute distances between point centroids and a shape centroid. 
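+
+    Offset is defined as the distance between the centroid of the points and
+    the centroid of the specified shape, computed separately for points
+    inside and outside the shape.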
- Attributes + Parameters ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - attributes : int - Names (keys) used to store computed cell-level features + instance_key : str + Key for cell boundaries in sdata.shapes shape_key : str - Name of shape to use, must be column name in input DataFrame + Key for shape to analyze relative to Returns ------- - dict - `"{shape_key}_inner_offset"`: offset of points inside `shape_key` - `"{shape_key}_outer_offset"`: offset of points outside `shape_key` + Dict[str, float] + Dictionary containing: + - {shape_key}_inner_offset: Mean distance from inner points to shape centroid + - {shape_key}_outer_offset: Mean distance from outer points to shape centroid """ - def __init__(self, instance_key, shape_key): + def __init__(self, instance_key: str, shape_key: str): super().__init__(instance_key, shape_key) def extract(self, df): @@ -637,24 +676,30 @@ def extract(self, df): class PointDispersion(PointFeature): - """For a set of points, calculates the second moment of all points in a cell - relative to the centroid of the total RNA signal. + """Compute second moment of points relative to RNA signal centroid. + + Parameters + ---------- + instance_key : str + Key for cell boundaries in sdata.shapes + shape_key : Optional[str] + Not used, included for API consistency Attributes ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - attributes : int - Names (keys) used to store computed cell-level features + cell_features : set + Required cell-level features + attributes : set + Required shape attributes Returns ------- - dict - `"point_dispersion"`: measure of point dispersion + Dict[str, float] + Dictionary containing: + - point_dispersion: Second moment of points relative to RNA centroid """ - # shape_key set to None to follow the same convention as other shape features - def __init__(self, instance_key, shape_key=None): + def __init__(self, instance_key: str, shape_key: Optional[str] = None): super().__init__(instance_key, shape_key) def extract(self, df): @@ -670,23 +715,30 @@ def extract(self, df): class ShapeDispersion(PointFeature): - """For a set of points, calculates the second moment of all points in a cell relative to the - centroid of `shape_key`. + """Compute second moment of points relative to a shape centroid. + + Parameters + ---------- + instance_key : str + Key for cell boundaries in sdata.shapes + shape_key : str + Key for shape to analyze relative to Attributes ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - attributes : int - Names (keys) used to store computed cell-level features + cell_features : set + Required cell-level features + attributes : set + Required shape attributes Returns ------- - dict - `"{shape_key}_dispersion"`: measure of point dispersion relative to `shape_key` + Dict[str, float] + Dictionary containing: + - {shape_key}_dispersion: Second moment of points relative to shape centroid """ - def __init__(self, instance_key, shape_key): + def __init__(self, instance_key: str, shape_key: str): super().__init__(instance_key, shape_key) def extract(self, df): @@ -712,41 +764,41 @@ def extract(self, df): class RipleyStats(PointFeature): - """For a set of points, calculates properties of the L-function. The L-function - measures spatial clustering of a point pattern over the area of the cell. + """Compute Ripley's L-function statistics for point patterns. 
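+
+    The L-function measures spatial clustering of a point pattern over the
+    area of the cell.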
+ + The L-function is evaluated at r=[1,d], where d is half the cell's maximum diameter. + + Parameters + ---------- + instance_key : str + Key for cell boundaries in sdata.shapes + shape_key : Optional[str] + Not used, included for API consistency Attributes ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - attributes : int - Names (keys) used to store computed cell-level features + cell_features : set + Required cell-level features + attributes : set + Required shape attributes Returns ------- - dict - `"l_max": The max value of the L-function evaluated at r=[1,d], where d is half the cell’s maximum diameter. - `"l_max_gradient"`: The max value of the gradient of the above L-function. - `"l_min_gradient"`: The min value of the gradient of the above L-function. - `"l_monotony"`: The correlation of the L-function and r=[1,d]. - `"l_half_radius"`: The value of the L-function evaluated at 1/4 of the maximum cell diameter. - + Dict[str, float] + Dictionary containing: + - l_max: Maximum value of L-function + - l_max_gradient: Maximum gradient of L-function + - l_min_gradient: Minimum gradient of L-function + - l_monotony: Spearman correlation between L-function and radius + - l_half_radius: L-function value at quarter cell diameter """ - def __init__(self, instance_key, shape_key=None): + def __init__(self, instance_key: str, shape_key: Optional[str] = None): super().__init__(instance_key, shape_key) self.cell_features.update(["span", "bounds", "area"]) - - self.attributes.update( - [ - f"{instance_key}_span", - f"{instance_key}_minx", - f"{instance_key}_miny", - f"{instance_key}_maxx", - f"{instance_key}_maxy", - f"{instance_key}_area", - ] - ) + self.attributes.update([f"{instance_key}_span", f"{instance_key}_minx", + f"{instance_key}_miny", f"{instance_key}_maxx", + f"{instance_key}_maxy", f"{instance_key}_area"]) def extract(self, df): df = super().extract(df) @@ -812,25 +864,23 @@ def extract(self, df): class ShapeEnrichment(PointFeature): - """For a set of points, calculates the fraction of points within `shape_key` - out of all points in the cell. + """Compute fraction of points within a shape boundary. - Attributes + Parameters ---------- - cell_features : int - Set of cell-level features needed for computing sample-level features - attributes : int - Names (keys) used to store computed cell-level features + instance_key : str + Key for cell boundaries in sdata.shapes shape_key : str - Name of shape to use, must be column name in input DataFrame + Key for shape to analyze relative to Returns ------- - dict - `"{shape_key}_enrichment"`: enrichment fraction of points in `shape_key` + Dict[str, float] + Dictionary containing: + - {shape_key}_enrichment: Fraction of points inside shape (0-1) """ - def __init__(self, instance_key, shape_key): + def __init__(self, instance_key: str, shape_key: str): super().__init__(instance_key, shape_key) def extract(self, df): @@ -849,16 +899,22 @@ def extract(self, df): return {f"{self.shape_key}_enrichment": enrichment} -def _second_moment(centroid, pts): - """ - Calculate second moment of points with centroid as reference. +def _second_moment(centroid: np.ndarray, pts: np.ndarray) -> float: + """Calculate second moment of points relative to a centroid. 
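+
+    The second moment here is the mean squared distance from each point to
+    the centroid, i.e. sum(||p_i - c||^2) / n over the n points.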
     Parameters
     ----------
-    centroid : [1 x 2] float
-    pts : [n x 2] float
+    centroid : np.ndarray
+        Reference point coordinates, shape (1, 2); shapely point geometries
+        are accepted and converted via their coords attribute
+    pts : np.ndarray
+        Point coordinates, shape (n, 2)
+
+    Returns
+    -------
+    float
+        Second moment value
     """
-    if type(centroid) != np.ndarray:
+    if not isinstance(centroid, np.ndarray):
         centroid = centroid.coords
     centroid = np.array(centroid).reshape(1, 2)
     radii = distance.cdist(centroid, pts)
@@ -866,13 +922,13 @@
     return second_moment


-def list_point_features():
-    """Return a DataFrame of available point features. Pulls descriptions from function docstrings.
+def list_point_features() -> pd.DataFrame:
+    """List available point feature calculations.

     Returns
     -------
-    list
-        List of available point features.
+    pd.DataFrame
+        DataFrame with feature names as index and descriptions from docstrings
     """

     # Get point feature descriptions from docstrings
@@ -899,15 +955,20 @@
     )


-def register_point_feature(name: str, FeatureClass: PointFeature):
-    """Register a new point feature function.
+def register_point_feature(name: str, FeatureClass: Type[PointFeature]) -> None:
+    """Register a new point feature calculation class.

     Parameters
     ----------
     name : str
-        Name of feature function
-    func : class
-        Class that extends PointFeature. Needs to override abstract functions.
+        Name to register the feature as
+    FeatureClass : Type[PointFeature]
+        Class that extends PointFeature base class
+
+    Returns
+    -------
+    None
+        Updates global point_features dictionary
     """

     point_features[name] = FeatureClass
diff --git a/bento/tools/_shape_features.py b/bento/tools/_shape_features.py
index aa115dc..f0f3fd4 100644
--- a/bento/tools/_shape_features.py
+++ b/bento/tools/_shape_features.py
@@ -6,7 +6,7 @@

 warnings.filterwarnings("ignore")

-from typing import Callable, Dict, List, Union
+from typing import Callable, Dict, List, Union, Optional

 import matplotlib.path as mplPath
 import numpy as np
@@ -21,24 +21,23 @@
 from .._utils import get_points, get_shape, set_shape_metadata


-def area(sdata: SpatialData, shape_key: str, recompute: bool = False):
-    """
-    Compute the area of each shape.
+def area(sdata: SpatialData, shape_key: str, recompute: bool = False) -> None:
+    """Compute the area of each shape.

     Parameters
     ----------
     sdata : SpatialData
-        Spatial formatted SpatialData
+        Input SpatialData object
     shape_key : str
-        Key in `sdata.shapes[shape_key]` that contains the shape information.
-    recompute : bool, optional
-        If True, forces the computation of the area even if it already exists in the shape metadata.
-        If False (default), the computation is skipped if the area already exists.
+        Key in sdata.shapes containing shape geometries
+    recompute : bool, default False
+        Whether to force recomputation if feature exists

     Returns
-    ------
-    .shapes[shape_key]['{shape}_area'] : float
-        Area of each polygon
+    -------
+    None
+        Updates sdata.shapes[shape_key] with:
+        - '{shape_key}_area': Area of each polygon
     """

     feature_key = f"{shape_key}_area"
@@ -52,8 +51,20 @@
     )


-def _poly_aspect_ratio(poly):
-    """Compute the aspect ratio of the minimum rotated rectangle that contains a polygon."""
+def _poly_aspect_ratio(poly: Union[MultiPolygon, None]) -> float:
+    """Compute aspect ratio of minimum rotated rectangle containing a polygon.
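+
+    The ratio is taken as the longest side of the minimum rotated bounding
+    rectangle divided by its shortest side.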
+ + Parameters + ---------- + poly : MultiPolygon or None + Input polygon geometry + + Returns + ------- + float + Ratio of longest to shortest side of minimum bounding rectangle, + or np.nan if polygon is None + """ if not poly: return np.nan @@ -74,20 +85,26 @@ def _poly_aspect_ratio(poly): return length / width -def aspect_ratio(sdata: SpatialData, shape_key: str, recompute: bool = False): - """Compute the aspect ratio of the minimum rotated rectangle that contains each shape. +def aspect_ratio(sdata: SpatialData, shape_key: str, recompute: bool = False) -> None: + """Compute aspect ratio of minimum rotated rectangle containing each shape. + + The aspect ratio is defined as the ratio of the longest to shortest side + of the minimum rotated rectangle that contains the shape. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData + Input SpatialData object shape_key : str - Key in `sdata.shapes[shape_key]` that contains the shape information. + Key in sdata.shapes containing shape geometries + recompute : bool, default False + Whether to force recomputation if feature exists - Fields - ------ - .shapes[shape_key]['{shape}_aspect_ratio'] : float - Ratio of major to minor axis for each polygon + Returns + ------- + None + Updates sdata.shapes[shape_key] with: + - '{shape_key}_aspect_ratio': Ratio of major to minor axis """ feature_key = f"{shape_key}_aspect_ratio" @@ -100,26 +117,26 @@ def aspect_ratio(sdata: SpatialData, shape_key: str, recompute: bool = False): ) -def bounds(sdata: SpatialData, shape_key: str, recompute: bool = False): - """Compute the minimum and maximum coordinate values that bound each shape. +def bounds(sdata: SpatialData, shape_key: str, recompute: bool = False) -> None: + """Compute bounding box coordinates for each shape. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData + Input SpatialData object shape_key : str - Key in `sdata.shapes[shape_key]` that contains the shape information. + Key in sdata.shapes containing shape geometries + recompute : bool, default False + Whether to force recomputation if feature exists Returns - ------ - .shapes[shape_key]['{shape}_minx'] : float - x-axis lower bound of each polygon - .shapes[shape_key]['{shape}_miny'] : float - y-axis lower bound of each polygon - .shapes[shape_key]['{shape}_maxx'] : float - x-axis upper bound of each polygon - .shapes[shape_key]['{shape}_maxy'] : float - y-axis upper bound of each polygon + ------- + None + Updates sdata.shapes[shape_key] with: + - '{shape_key}_minx': x-axis lower bound + - '{shape_key}_miny': y-axis lower bound + - '{shape_key}_maxx': x-axis upper bound + - '{shape_key}_maxy': y-axis upper bound """ feat_names = ["minx", "miny", "maxx", "maxy"] @@ -140,20 +157,23 @@ def bounds(sdata: SpatialData, shape_key: str, recompute: bool = False): ) -def density(sdata: SpatialData, shape_key: str, recompute: bool = False): - """Compute the RNA density of each shape. +def density(sdata: SpatialData, shape_key: str, recompute: bool = False) -> None: + """Compute RNA density (molecules per area) for each shape. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData + Input SpatialData object shape_key : str - Key in `sdata.shapes[shape_key]` that contains the shape information. 
+ Key in sdata.shapes containing shape geometries + recompute : bool, default False + Whether to force recomputation if feature exists Returns - ------ - .shapes[shape_key]['{shape}_density'] : float - Density (molecules / shape area) of each polygon + ------- + None + Updates sdata.shapes[shape_key] with: + - '{shape_key}_density': Number of molecules divided by shape area """ feature_key = f"{shape_key}_density" @@ -165,6 +185,7 @@ def density(sdata: SpatialData, shape_key: str, recompute: bool = False): .query(f"{shape_key} != 'None'")[shape_key] .value_counts() .compute() + .reindex_like(sdata.shapes[shape_key]) ) area(sdata, shape_key) @@ -178,18 +199,28 @@ def density(sdata: SpatialData, shape_key: str, recompute: bool = False): def opening( sdata: SpatialData, shape_key: str, proportion: float, recompute: bool = False -): - """Compute the opening (morphological) of distance d for each cell. +) -> None: + """Compute morphological opening of each shape. + + The opening operation erodes the shape by distance d and then dilates by d, + where d = proportion * shape radius. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData + Input SpatialData object + shape_key : str + Key in sdata.shapes containing shape geometries + proportion : float + Fraction of shape radius to use as opening distance + recompute : bool, default False + Whether to force recomputation if feature exists Returns ------- - .shapes[shape_key]['cell_open_{d}_shape'] : Polygons - Ratio of long / short axis for each polygon in `.shapes[shape_key]['cell_boundaries']` + None + Updates sdata.shapes[shape_key] with: + - '{shape_key}_open_{proportion}_shape': Opened shape geometries """ feature_key = f"{shape_key}_open_{proportion}_shape" @@ -209,14 +240,20 @@ def opening( ) -def _second_moment_polygon(centroid, pts): - """ - Calculate second moment of points with centroid as reference. +def _second_moment_polygon(centroid: Point, pts: np.ndarray) -> Optional[float]: + """Calculate second moment of points relative to a centroid. Parameters ---------- - centroid : 2D Point object - pts : [n x 2] float + centroid : Point + Reference point for moment calculation + pts : np.ndarray + Array of point coordinates, shape (n, 2) + + Returns + ------- + float or None + Second moment value, or None if inputs are invalid """ if not centroid or not isinstance(pts, np.ndarray): @@ -227,18 +264,25 @@ def _second_moment_polygon(centroid, pts): return second_moment -def second_moment(sdata: SpatialData, shape_key: str, recompute: bool = False): - """Compute the second moment of each shape. +def second_moment(sdata: SpatialData, shape_key: str, recompute: bool = False) -> None: + """Compute second moment of each shape relative to its centroid. + + The second moment measures the spread of points in the shape around its center. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData + Input SpatialData object + shape_key : str + Key in sdata.shapes containing shape geometries + recompute : bool, default False + Whether to force recomputation if feature exists Returns ------- - .shapes[shape_key]['{shape}_moment'] : float - The second moment for each polygon + None + Updates sdata.shapes[shape_key] with: + - '{shape_key}_moment': Second moment value for each shape """ feature_key = f"{shape_key}_moment" @@ -260,11 +304,23 @@ def second_moment(sdata: SpatialData, shape_key: str, recompute: bool = False): ) -def _raster_polygon(poly, step=1): - """ - Generate a grid of points contained within the poly. 
The points lie on - a 2D grid, with vertices spaced step distance apart. +def _raster_polygon(poly: Union[MultiPolygon, None], step: int = 1) -> Optional[np.ndarray]: + """Generate grid of points contained within a polygon. + + Parameters + ---------- + poly : MultiPolygon or None + Input polygon geometry + step : int, default 1 + Grid spacing between points + + Returns + ------- + np.ndarray or None + Array of grid point coordinates, shape (n, 2), + or None if polygon is invalid """ + if not poly: return minx, miny, maxx, maxy = [int(i) for i in poly.bounds] @@ -300,19 +356,31 @@ def raster( points_key: str = "transcripts", step: int = 1, recompute: bool = False, -): - """Generate a grid of points contained within each shape. The points lie on - a 2D grid, with vertices spaced `step` distance apart. +) -> None: + """Generate grid of points within each shape. + + Creates a regular grid of points with spacing 'step' that covers each shape. + Points outside the shape are excluded. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData + Input SpatialData object + shape_key : str + Key in sdata.shapes containing shape geometries + points_key : str, default "transcripts" + Key for points in sdata.points + step : int, default 1 + Grid spacing between points + recompute : bool, default False + Whether to force recomputation if feature exists Returns ------- - .shapes[shape_key]['{shape}_raster'] : np.array - Long DataFrame of points annotated by shape from `.shapes[shape_key]['{shape_key}']` + None + Updates: + - sdata.shapes[shape_key]['{shape_key}_raster']: Array of grid points per shape + - sdata.points['{shape_key}_raster']: All grid points as point cloud """ shape_feature_key = f"{shape_key}_raster" @@ -350,18 +418,23 @@ def raster( sdata.points[shape_feature_key].attrs = transform -def perimeter(sdata: SpatialData, shape_key: str, recompute: bool = False): - """Compute the perimeter of each shape. +def perimeter(sdata: SpatialData, shape_key: str, recompute: bool = False) -> None: + """Compute perimeter length of each shape. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData + Input SpatialData object + shape_key : str + Key in sdata.shapes containing shape geometries + recompute : bool, default False + Whether to force recomputation if feature exists Returns ------- - `.shapes[shape_key]['{shape}_perimeter']` : np.array - Perimeter of each polygon + None + Updates sdata.shapes[shape_key] with: + - '{shape_key}_perimeter': Perimeter length of each shape """ feature_key = f"{shape_key}_perimeter" @@ -376,18 +449,26 @@ def perimeter(sdata: SpatialData, shape_key: str, recompute: bool = False): ) -def radius(sdata: SpatialData, shape_key: str, recompute: bool = False): - """Compute the radius of each cell. +def radius(sdata: SpatialData, shape_key: str, recompute: bool = False) -> None: + """Compute average radius of each shape. + + The radius is calculated as the mean distance from the shape's centroid + to points on its boundary. 
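+    For the cell boundary shape, this radius is the value used to normalize
+    point features such as proximity and asymmetry.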
Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData + Input SpatialData object + shape_key : str + Key in sdata.shapes containing shape geometries + recompute : bool, default False + Whether to force recomputation if feature exists Returns ------- - .shapes[shape_key]['{shape}_radius'] : np.array - Radius of each polygon in `obs['cell_shape']` + None + Updates sdata.shapes[shape_key] with: + - '{shape_key}_radius': Average radius of each shape """ feature_key = f"{shape_key}_radius" @@ -406,7 +487,22 @@ def radius(sdata: SpatialData, shape_key: str, recompute: bool = False): ) -def _shape_radius(poly): +def _shape_radius(poly: Union[MultiPolygon, None]) -> float: + """Compute average radius of a polygon. + + Calculates mean distance from centroid to boundary points. + + Parameters + ---------- + poly : MultiPolygon or None + Input polygon geometry + + Returns + ------- + float + Average radius, or np.nan if polygon is None + """ + if not poly: return np.nan @@ -415,18 +511,25 @@ def _shape_radius(poly): ).mean() -def span(sdata: SpatialData, shape_key: str, recompute: bool = False): - """Compute the length of the longest diagonal of each shape. +def span(sdata: SpatialData, shape_key: str, recompute: bool = False) -> None: + """Compute maximum diameter of each shape. + + The span is the length of the longest line segment that fits within the shape. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData + Input SpatialData object + shape_key : str + Key in sdata.shapes containing shape geometries + recompute : bool, default False + Whether to force recomputation if feature exists Returns ------- - .shapes[shape_key]['{shape}_span'] : float - Length of longest diagonal for each polygon + None + Updates sdata.shapes[shape_key] with: + - '{shape_key}_span': Maximum diameter of each shape """ feature_key = f"{shape_key}_span" @@ -447,13 +550,13 @@ def get_span(poly): ) -def list_shape_features(): - """Return a dictionary of available shape features and their descriptions. +def list_shape_features() -> Dict[str, str]: + """List available shape feature calculations. Returns ------- - dict - A dictionary where keys are shape feature names and values are their corresponding descriptions. + Dict[str, str] + Dictionary mapping feature names to their descriptions """ # Get shape feature descriptions from docstrings @@ -481,22 +584,23 @@ def list_shape_features(): def shape_stats( sdata: SpatialData, - feature_names: List[str] = ["area", "aspect_ratio", "density"], -): - """Compute descriptive stats for cells. Convenient wrapper for `bento.tl.shape_features`. - See list of available features in `bento.tl.shape_features`. + feature_names: List[str] = ["area", "aspect_ratio", "density"] +) -> None: + """Compute common shape statistics. + + Wrapper around analyze_shapes() for frequently used features. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData - feature_names : list - List of features to compute. See list of available features in `bento.tl.shape_features`. 
+ Input SpatialData object + feature_names : List[str], default ["area", "aspect_ratio", "density"] + Features to compute Returns ------- - .shapes['cell_boundaries']['cell_boundaries_{feature}'] : np.array - Feature of each polygon + None + Updates sdata.shapes with computed features """ # Compute features @@ -509,28 +613,31 @@ def analyze_shapes( sdata: SpatialData, shape_keys: Union[str, List[str]], feature_names: Union[str, List[str]], - feature_kws: Dict[str, Dict] = None, + feature_kws: Optional[Dict[str, Dict]] = None, recompute: bool = False, progress: bool = True, -): - """Analyze features of shapes. +) -> None: + """Compute multiple shape features. Parameters ---------- sdata : SpatialData - Spatial formatted SpatialData - shape_keys : list of str - List of shapes to analyze. - feature_names : list of str - List of features to analyze. - feature_kws : dict, optional (default: None) - Keyword arguments for each feature. + Input SpatialData object + shape_keys : str or list of str + Keys in sdata.shapes to analyze + feature_names : str or list of str + Names of features to compute + feature_kws : dict, optional + Additional keyword arguments for each feature + recompute : bool, default False + Whether to force recomputation if features exist + progress : bool, default True + Whether to show progress bar Returns ------- - sdata : SpatialData - See specific feature function docs for fields added. - + None + Updates sdata.shapes with computed features """ # Cast to list if not already @@ -557,17 +664,21 @@ def analyze_shapes( shape_features[feature](sdata, shape, **kws) -def register_shape_feature(name: str, func: Callable): - """Register a shape feature function. The function should take an SpatialData object and a shape name as input. - The function should add the feature to the SpatialData object as a column in SpatialData.tables["table"].obs. - This should be done in place and not return anything. +def register_shape_feature(name: str, func: Callable[[SpatialData, str], None]) -> None: + """Register a new shape feature calculation function. Parameters ---------- name : str - Name of the feature function. - func : function - Function that takes a SpatialData object and a shape name as arguments. + Name to register the feature as + func : Callable[[SpatialData, str], None] + Function that takes SpatialData and shape_key as arguments + and modifies SpatialData in-place + + Returns + ------- + None + Updates global shape_features dictionary """ shape_features[name] = func diff --git a/bento/tools/gene_sets/boyle2023.zip b/bento/tools/gene_sets/boyle2023.zip new file mode 100644 index 0000000..a0ebfe2 Binary files /dev/null and b/bento/tools/gene_sets/boyle2023.zip differ diff --git a/docs/source/api.md b/docs/source/api.md index 2f2e242..2c42c2d 100644 --- a/docs/source/api.md +++ b/docs/source/api.md @@ -83,7 +83,7 @@ TODO ### Shape Features -Compute spatial properties of shape features e.g. area, aspect ratio, etc. of the cell, nucleus, or other region of interest. The set of available shape features is described in the Shape Feature Catalog. Use the function `bt.analyze_points()` to compute features and add your own custom calculation. See the [tutorial](https://bento-tools.github.io/bento/tutorials/TBD.html) for more information. +Compute spatial properties of shape features e.g. area, aspect ratio, etc. of the cell, nucleus, or other region of interest. The set of available shape features is described in the Shape Feature Catalog. 
Use the function `bt.analyze_shapes()` to compute features and `bt.register_shape_feature()` to add your own custom calculation. See the [tutorial](tutorial_gallery/Spatial_Features.html) for more information.

 ```{eval-rst}
 .. currentmodule:: bento.tl

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 474152e..1f8549f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -15,7 +15,14 @@

 from sphinxawesome_theme import LinkIcon, ThemeOptions

-sys.path.insert(0, os.path.abspath(".."))  # Source code dir relative to this file
+sys.path.insert(0, os.path.abspath("../.."))  # Go up two levels to reach project root
+
+# Debug imports
+try:
+    import bento
+    print(f"Successfully imported bento from {bento.__file__}")
+except ImportError as e:
+    print(f"Failed to import bento: {e}")

 # -- Project information -----------------------------------------------------

@@ -57,7 +64,7 @@
 intersphinx_mapping = {
     "geopandas": ("https://geopandas.org/en/stable/", None),
     "shapely": ("https://shapely.readthedocs.io/en/stable/", None),
-    "spatialdata": ("https://spatialdata.readthedocs.io/en/stable/", None),
+    "spatialdata": ("https://spatialdata.scverse.org/en/stable/", None),
     "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None),
     "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
 }
diff --git a/docs/source/howitworks.md b/docs/source/howitworks.md
index 70daa85..57d321d 100644
--- a/docs/source/howitworks.md
+++ b/docs/source/howitworks.md
@@ -6,7 +6,7 @@

 ## Data Format

-Under the hood, we use the [SpatialData](https://spatialdata.scverse.org/en/latest/) framework to manage `SpatialData` objects in Python, allowing us to store and manipulate spatial data in a standardized format. Briefly, `SpatialData` objects are stored on-disk in the Zarr storage format. We aim to be fully compatible with SpatialData, so you can use the same objects in both Bento and SpatialData.
+Under the hood, we use the {doc}`SpatialData <https://spatialdata.scverse.org/en/latest/>` framework to manage `SpatialData` objects in Python, allowing us to store and manipulate spatial data in a standardized format. Briefly, `SpatialData` objects are stored on-disk in the Zarr storage format. We aim to be fully compatible with SpatialData, so you can use the same objects in both Bento and SpatialData.

 To enable scalable and performant operation with Bento, we perform spatial indexing on the data upfront and store these indices as metadata. This allows us to quickly query points within shapes, and shapes that contain points. Bento adopts a cell-centric approach, where each cell is treated as an independent unit of analysis. This allows us to perform subcellular spatial analysis within individual cells, and aggregate results across cells.

@@ -47,7 +47,7 @@
 The `SpatialData` object is a container for the following elements:

 - `Shapes`: boundaries, circles, polygons
 - `Tables`: annotations, count matrices

-See the [Data Prep Guide](tutorial_gallery/Data_Prep_Guide.html) for more information on how to prepare `SpatialData` objects for Bento and official [SpatialData documentation](https://spatialdata.scverse.org) for more info.
+See the {doc}`Data Prep Guide <tutorial_gallery/Data_Prep_Guide>` for more information on how to prepare `SpatialData` objects for Bento and the official {doc}`SpatialData docs <https://spatialdata.scverse.org>` for more info.
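+
+As a minimal sketch (the Zarr path below is a placeholder), reading a `SpatialData` object and computing a shape feature with Bento looks like:
+
+```python
+import spatialdata as sd
+import bento as bt
+
+# Read a SpatialData object from its on-disk Zarr store
+sdata = sd.read_zarr("path/to/sample.zarr")
+
+# Compute the area of each cell boundary polygon
+bt.tl.analyze_shapes(sdata, "cell_boundaries", "area")
+```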
## RNAflux diff --git a/pyproject.toml b/pyproject.toml index f35b494..6dc48de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,5 @@ [project] name = "bento-tools" -version = "2.1.3" description = "A toolkit for subcellular analysis of spatial transcriptomics data" authors = [{ name = "ckmah", email = "clarence.k.mah@gmail.com" }] dependencies = [ @@ -24,12 +23,15 @@ dependencies = [ "upsetplot>=0.9.0", "xgboost>=2.0.3", "statsmodels>=0.14.1", - "scikit-learn>=1.4.2", + "scikit-learn<1.6.0", "ipywidgets>=8.1.5", + "tomli>=2.2.1", ] license = "BSD-2-Clause" readme = "README.md" requires-python = ">= 3.10" +version = "2.1.4" + [project.optional-dependencies] docs = [ @@ -49,6 +51,7 @@ build-backend = "hatchling.build" [tool.rye] managed = true +version.source = "pyproject" dev-dependencies = [ "pytest>=8.2.2", "pytest-cov>=5.0.0", @@ -71,3 +74,6 @@ include = ["bento"] [tool.hatch.build.targets.wheel] packages = ["bento"] + +[tool.hatch.version] +path = "bento/_version.py" diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..a73d2e8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +import os +import sys + +# Add the parent directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) diff --git a/tests/test_flux.py b/tests/test_flux.py index b88e0fb..e30887f 100644 --- a/tests/test_flux.py +++ b/tests/test_flux.py @@ -2,7 +2,7 @@ import bento as bt -from . import conftest +from tests import conftest @pytest.fixture(scope="module") diff --git a/tests/test_io.py b/tests/test_io.py index 2ae99f3..b2d5f2b 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,5 +1,5 @@ -from . import conftest +from tests import conftest def test_points_indexing(small_data): diff --git a/tests/test_lp.py b/tests/test_lp.py index 1abcb17..5092ab0 100644 --- a/tests/test_lp.py +++ b/tests/test_lp.py @@ -3,7 +3,7 @@ import bento as bt -from . import conftest +from tests import conftest @pytest.fixture(scope="module") diff --git a/tests/test_point_features.py b/tests/test_point_features.py index da89830..fd45297 100644 --- a/tests/test_point_features.py +++ b/tests/test_point_features.py @@ -1,7 +1,7 @@ import pytest import bento as bt -from . import conftest +from tests import conftest @pytest.fixture(scope="module") diff --git a/tests/test_shape_features.py b/tests/test_shape_features.py index 11ccf69..9c049b5 100644 --- a/tests/test_shape_features.py +++ b/tests/test_shape_features.py @@ -1,7 +1,7 @@ import pytest import bento as bt -from . import conftest +from tests import conftest @pytest.fixture(scope="module")