Merge pull request #149 from ckmah/v.2.1.3
V.2.1.3
ckmah authored Sep 9, 2024
2 parents 7fdb42c + 33e4c2c commit 909ff80
Showing 18 changed files with 305 additions and 140 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -146,3 +146,5 @@ requirements.lock
requirements-dev.lock
profile.html
profile.json

notebooks/
1 change: 1 addition & 0 deletions bento/_constants.py
@@ -3,6 +3,7 @@
PATTERN_COLORS = ["#17becf", "#1f77b4", "#7f7f7f", "#ff7f0e", "#d62728"]
PATTERN_NAMES = ["cell_edge", "cytoplasmic", "none", "nuclear", "nuclear_edge"]
PATTERN_PROBS = [f"{p}_p" for p in PATTERN_NAMES]
PATTERN_THRESHOLDS_CALIB = [0.45300, 0.43400, 0.37900, 0.43700, 0.50500]


class CosMx(Enum):
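
The new PATTERN_THRESHOLDS_CALIB constant supplies one calibrated cutoff per entry of PATTERN_NAMES (and its derived PATTERN_PROBS columns). The calling code is not shown in this diff, so the following is only a sketch of how such per-pattern thresholds might be applied; the probs frame and the thresholding loop are assumptions, not this commit's implementation.

import pandas as pd

PATTERN_NAMES = ["cell_edge", "cytoplasmic", "none", "nuclear", "nuclear_edge"]
PATTERN_PROBS = [f"{p}_p" for p in PATTERN_NAMES]
PATTERN_THRESHOLDS_CALIB = [0.45300, 0.43400, 0.37900, 0.43700, 0.50500]

# Hypothetical usage: call a pattern when its probability clears the
# calibrated threshold for that pattern.
probs = pd.DataFrame([[0.50, 0.20, 0.40, 0.60, 0.10]], columns=PATTERN_PROBS)
for col, cutoff in zip(PATTERN_PROBS, PATTERN_THRESHOLDS_CALIB):
    probs[col[:-2]] = probs[col] >= cutoff  # e.g. "nuclear" from "nuclear_p"
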
9 changes: 5 additions & 4 deletions bento/_utils.py
@@ -230,7 +230,7 @@ def set_points_metadata(
metadata: Union[List, pd.Series, pd.DataFrame, np.ndarray],
columns: Union[List[str], str],
) -> None:
"""Write metadata in SpatialData points element as column(s). Aligns metadata index to shape index if present.
"""Write metadata in SpatialData points element as column(s).
Parameters
----------
@@ -314,9 +314,10 @@ def set_shape_metadata(
if "" not in metadata[col].cat.categories:
metadata[col] = metadata[col].cat.add_categories([""]).fillna("")

sdata.shapes[shape_key].loc[:, metadata.columns] = metadata.reindex(
shape_index
).fillna("")
sdata.shapes[shape_key] = sdata.shapes[shape_key].assign(
**metadata.reindex(shape_index).to_dict()
)
# sdata.shapes[shape_key].loc[:, metadata.columns] = metadata.reindex(shape_index)


def _sync_points(sdata, points_key):
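
The change above swaps a .loc column assignment for reindex-then-assign: assign returns a new frame rather than writing into a slice, and reindex aligns metadata to the shape index, leaving NaN where a shape has no metadata. A minimal pandas illustration of that pattern with toy frames (the commit passes metadata.reindex(...).to_dict() into assign; plain Series behave the same way for illustration):

import pandas as pd

shapes = pd.DataFrame({"geometry": ["poly0", "poly1", "poly2"]}, index=["c0", "c1", "c2"])
metadata = pd.DataFrame({"nucleus": ["n0", "n1"]}, index=["c0", "c1"])

aligned = metadata.reindex(shapes.index)  # the "c2" row becomes NaN
shapes = shapes.assign(**{col: aligned[col] for col in aligned.columns})
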
56 changes: 39 additions & 17 deletions bento/io/_index.py
@@ -1,5 +1,6 @@
from typing import List

import pandas as pd
import geopandas as gpd
from spatialdata._core.spatialdata import SpatialData

@@ -42,19 +43,27 @@ def _sjoin_points(

# Grab points as GeoDataFrame
points = get_points(sdata, points_key, astype="geopandas", sync=False)
points.index.name = "pt_index"

# Index points to shapes
indexed_points = {}
for shape_key, shape in query_shapes.items():
shape = query_shapes[shape_key]
shape.index.name = None
shape.index.name = None # Forces sjoin to name index "index_right"
shape.index = shape.index.astype(str)

points = points.sjoin(shape, how="left", predicate="intersects")
points = points[~points.index.duplicated(keep="last")]
points["index_right"].fillna("", inplace=True)
points.rename(columns={"index_right": shape_key}, inplace=True)
indexed_points[shape_key] = (
points.sjoin(shape, how="left", predicate="intersects")
.reset_index()
.drop_duplicates(subset="pt_index")["index_right"]
.fillna("")
.values.flatten()
)

set_points_metadata(sdata, points_key, points[shape_key], columns=shape_key)
index_points = pd.DataFrame(indexed_points)
set_points_metadata(
sdata, points_key, index_points, columns=list(indexed_points.keys())
)

return sdata

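The rewritten loop above joins points to each shape type and deduplicates on the point index instead of mutating points in place. A self-contained geopandas sketch of that core step, with made-up geometries; "index_right" is the column name sjoin assigns when the right index is unnamed:

import geopandas as gpd
from shapely.geometry import Point, box

points = gpd.GeoDataFrame(geometry=[Point(1, 1), Point(9, 9)])
points.index.name = "pt_index"
shapes = gpd.GeoDataFrame(geometry=[box(0, 0, 2, 2), box(0, 0, 10, 10)])
shapes.index = shapes.index.astype(str)

# A point inside two overlapping shapes produces two join rows, so
# duplicates on "pt_index" are dropped; unmatched points get "".
assignment = (
    points.sjoin(shapes, how="left", predicate="intersects")
    .reset_index()
    .drop_duplicates(subset="pt_index")["index_right"]
    .fillna("")
)
print(assignment.tolist())  # one shape id (or "") per point
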
@@ -97,23 +106,36 @@ def _sjoin_shapes(sdata: SpatialData, instance_key: str, shape_keys: List[str]):
# Hack for polygons that are 99% contained in parent shape or have shared boundaries
child_shape = gpd.GeoDataFrame(geometry=child_shape.buffer(-10e-6))

parent_shape = parent_shape.sjoin(child_shape, how="left", predicate="covers")
parent_shape = parent_shape[~parent_shape.index.duplicated(keep="last")]
parent_shape.loc[parent_shape["index_right"].isna(), "index_right"] = ""
parent_shape = parent_shape.astype({"index_right": "category"})
# Map child shape index to parent shape and process the result
parent_shape = (
parent_shape.sjoin(child_shape, how="left", predicate="covers")
.reset_index()
.drop_duplicates(subset="index", keep="last")
.set_index("index")
.assign(
index_right=lambda df: df.loc[
~df["index_right"].duplicated(keep="first"), "index_right"
]
.fillna("")
.astype("category")
)
.rename(columns={"index_right": shape_key})
)
parent_shape[shape_key] = parent_shape[shape_key].fillna("")

# save shape index as column in instance_key shape
parent_shape.rename(columns={"index_right": shape_key}, inplace=True)
# Save shape index as column in instance_key shape
set_shape_metadata(
sdata, shape_key=instance_key, metadata=parent_shape[shape_key]
)

# Add instance_key shape index to shape
parent_shape.index.name = "parent_index"
instance_index = parent_shape.reset_index().set_index(shape_key)["parent_index"]
instance_index.name = instance_key
instance_index.index.name = None
instance_index = instance_index[instance_index.index != ""]
instance_index = (
parent_shape.drop_duplicates(subset=shape_key)
.reset_index()
.set_index(shape_key)["index"]
.rename(instance_key)
.loc[lambda s: s.index != ""]
)

set_shape_metadata(sdata, shape_key=shape_key, metadata=instance_index)

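The negative buffer described as a hack above makes the covers predicate tolerant of child shapes that poke marginally outside their parent. A toy demonstration with made-up boxes:

import geopandas as gpd
from shapely.geometry import box

cells = gpd.GeoDataFrame(geometry=[box(0, 0, 10, 10)])
nuclei = gpd.GeoDataFrame(geometry=[box(2, 2, 10.000001, 8)])  # overhangs by 1e-6

# Shrinking the child by 10e-6 lets "covers" succeed where it would
# otherwise return no match for the overhanging nucleus.
shrunk = gpd.GeoDataFrame(geometry=nuclei.buffer(-10e-6))
joined = cells.sjoin(shrunk, how="left", predicate="covers")
print(joined["index_right"].tolist())  # [0] rather than [NaN]
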
4 changes: 3 additions & 1 deletion bento/io/_io.py
@@ -49,7 +49,7 @@ def prep(
for shape_key, shape_gdf in sdata.shapes.items():
if shape_key == instance_key:
shape_gdf[shape_key] = shape_gdf["geometry"]
shape_gdf.index = make_index_unique(shape_gdf.index)
shape_gdf.index = make_index_unique(shape_gdf.index.astype(str))

# sindex points and sjoin shapes if they have not been indexed or joined
point_sjoin = []
@@ -67,6 +67,8 @@
shape_sjoin.append(shape_key)

# Set instance key for points
if "spatialdata_attrs" not in sdata.points[points_key].attrs:
sdata.points[points_key].attrs["spatialdata_attrs"] = {}
sdata.points[points_key].attrs["spatialdata_attrs"]["instance_key"] = instance_key

pbar = tqdm(total=3)
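
The new guard initializes spatialdata_attrs before keying into it. Equivalently, dict.setdefault collapses the same check into one line; a standalone illustration (the "cell_boundaries" value is hypothetical):

attrs = {}  # stand-in for sdata.points[points_key].attrs
attrs.setdefault("spatialdata_attrs", {})["instance_key"] = "cell_boundaries"
print(attrs)  # {'spatialdata_attrs': {'instance_key': 'cell_boundaries'}}
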
9 changes: 6 additions & 3 deletions bento/plotting/_layers.py
@@ -72,11 +72,12 @@ def _polygons(sdata, shape, ax, hue=None, sync=True, **kwargs):
else:
shapes[hue] = df.reset_index()[hue].values
style_kwds["facecolor"] = sns.axes_style()["axes.edgecolor"]
style_kwds["edgecolor"] = "none" # let GeoDataFrame plot function handle facecolor
style_kwds["edgecolor"] = (
"none" # let GeoDataFrame plot function handle facecolor
)

style_kwds.update(kwargs)


patches = []
# Manually create patches for each polygon; GeoPandas plot function is slow
for poly in shapes["geometry"].values:
@@ -91,7 +92,9 @@
ax.add_collection(patches)


def _raster(sdata, res, color, points_key, alpha, cbar=False, ax=None, **kwargs):
def _raster(
sdata, res, color, points_key, alpha, pthreshold=None, cbar=False, ax=None, **kwargs
):
"""Plot gradient."""

if ax is None:
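The comment in this hunk notes that patches are built manually because GeoDataFrame.plot is slow for many polygons. A minimal matplotlib sketch of that technique, with stand-in geometries and none of the function's actual styling logic:

import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon as MplPolygon
from shapely.geometry import box

polys = [box(i, 0, i + 0.8, 1) for i in range(3)]  # stand-in geometries

fig, ax = plt.subplots()
# One PatchCollection drawn in a single pass instead of per-row plotting.
patches = [MplPolygon(list(p.exterior.coords), closed=True) for p in polys]
ax.add_collection(PatchCollection(patches, facecolor="lightgray", edgecolor="none"))
ax.autoscale_view()  # add_collection does not rescale the axes by itself
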
2 changes: 1 addition & 1 deletion bento/plotting/_lp.py
@@ -99,7 +99,7 @@ def lp_genes(
**kwargs
Options to pass to matplotlib plotting method.
"""
lp_stats(sdata, instance_key=instance_key)
lp_stats(sdata)

palette = dict(zip(PATTERN_NAMES, PATTERN_COLORS))

8 changes: 4 additions & 4 deletions bento/plotting/_multidimensional.py
@@ -97,10 +97,10 @@ def shape_stats(
nucleus_gdf = pd.DataFrame(
sdata[nucleus_key].melt(value_vars=[c for c in cols if f"{nucleus_key}_" in c])
)
stats_long = cell_gdf.append(nucleus_gdf, ignore_index=True)
stats_long["quantile"] = stats_long.groupby("variable")["value"].transform(
lambda x: quantile_transform(x.values.reshape(-1, 1), n_quantiles=100).flatten()
)
stats_long = pd.concat([cell_gdf, nucleus_gdf], ignore_index=True)
# stats_long["quantile"] = stats_long.groupby("variable")["value"].transform(
# lambda x: quantile_transform(x.values.reshape(-1, 1), n_quantiles=100).flatten()
# )

stats_long["shape"] = stats_long["variable"].apply(lambda x: x.split("_")[0])
stats_long["var"] = stats_long["variable"].apply(
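
DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, so the swap to pd.concat above is the standard migration. For reference, with toy frames:

import pandas as pd

a = pd.DataFrame({"variable": ["cell_area"], "value": [1.0]})
b = pd.DataFrame({"variable": ["nucleus_area"], "value": [0.3]})
stats_long = pd.concat([a, b], ignore_index=True)  # fresh 0..n-1 index
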
81 changes: 57 additions & 24 deletions bento/tools/_colocation.py
@@ -1,17 +1,19 @@
from typing import List

import dask
import emoji
import numpy as np
import pandas as pd
import seaborn as sns
import sparse
from spatialdata._core.spatialdata import SpatialData
from kneed import KneeLocator
from tqdm.auto import tqdm
from spatialdata._core.spatialdata import SpatialData
from tqdm.dask import TqdmCallback

from .._utils import get_points
from ._neighborhoods import _count_neighbors
from ._decomposition import decompose
from ._neighborhoods import _count_neighbors
import dask.bag as db


def colocation(
@@ -120,6 +122,7 @@ def coloc_quotient(
radius: int = 20,
min_points: int = 10,
min_cells: int = 0,
num_workers=1,
):
"""Calculate pairwise gene colocalization quotient in each cell.
@@ -141,6 +144,8 @@
Minimum number of points for sample to be considered for colocalization, default 10
min_cells : int
Minimum number of cells for gene to be considered for colocalization, default 0
num_workers : int
Number of workers to use for parallel processing
Returns
-------
@@ -160,41 +165,54 @@
# Keep genes expressed in at least min_cells cells
gene_counts = points.groupby(feature_key).size()
valid_genes = gene_counts[gene_counts >= min_cells].index

# Filter points by valid genes
points = points[points[feature_key].isin(valid_genes)]

# Partition so {chunksize} cells per partition
cells, group_loc = np.unique(
points[instance_key].astype(str),
return_index=True,
)
# Group points by cell
points_grouped = points.groupby(instance_key)
cells = list(points_grouped.groups.keys())
cells.sort()

end_loc = np.append(group_loc[1:], points.shape[0])
args = [
(
points_grouped.get_group(cell),
radius,
min_points,
feature_key,
instance_key,
)
for cell in cells
]

cell_clqs = []
for cell, start, end in tqdm(
zip(cells, group_loc, end_loc), desc=shape, total=len(cells)
):
cell_points = points.iloc[start:end]
cell_clq = _cell_clq(cell_points, radius, min_points, feature_key)
cell_clq[instance_key] = cell
bags = db.from_sequence(args).map(lambda x: _cell_clq(*x))

cell_clqs.append(cell_clq)
# Use dask.compute to execute the operations in parallel
with TqdmCallback(desc="Batches"), dask.config.set(num_workers=num_workers):
cell_clqs = bags.compute()

cell_clqs = pd.concat(cell_clqs)
cell_clqs[[instance_key, feature_key, "neighbor"]] = (
cell_clqs[[instance_key, feature_key, "neighbor"]]
.astype(str)
.astype("category")
)
cell_clqs["log_clq"] = cell_clqs["clq"].replace(0, np.nan).apply(np.log2)

# Compute log2 of clq and confidence intervals
cell_clqs["log_clq"] = cell_clqs["clq"].replace(0, np.nan).apply(np.log2)
cell_clqs["log_ci_lower"] = (
cell_clqs["ci_lower"].replace(0, np.nan).apply(np.log2)
)
cell_clqs["log_ci_upper"] = (
cell_clqs["ci_upper"].replace(0, np.nan).apply(np.log2)
)
# Save to uns['clq'] as adjacency list
all_clq[shape] = cell_clqs

sdata.tables["table"].uns["clq"] = all_clq


def _cell_clq(cell_points, radius, min_points, feature_key):
def _cell_clq(cell_points, radius, min_points, feature_key, instance_key):
# Count number of points for each gene
gene_counts = cell_points[feature_key].value_counts()

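The hunk above replaces the sequential per-cell loop with a dask.bag pipeline gated by num_workers. A self-contained sketch of that pattern, where per_cell stands in for the real _cell_clq:

import dask
import dask.bag as db
from tqdm.dask import TqdmCallback

def per_cell(x):
    return x * x  # stand-in for the per-cell CLQ computation

# One argument tuple per cell, mirroring the args list built above.
bags = db.from_sequence([(1,), (2,), (3,), (4,)]).map(lambda x: per_cell(*x))
with TqdmCallback(desc="Batches"), dask.config.set(num_workers=2):
    results = bags.compute()  # partitions run in parallel
print(results)  # [1, 4, 9, 16]
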
@@ -217,15 +235,21 @@ def _cell_clq(cell_points, radius, min_points, feature_key, instance_key):
radius=radius,
agg="binary",
).toarray()

point_neighbors = pd.DataFrame(
point_neighbors, columns=valid_points[feature_key].cat.categories
)

# Get gene-level neighbor counts for each gene
neighbor_counts = (
pd.DataFrame(point_neighbors, columns=valid_points[feature_key].cat.categories)
.groupby(valid_points[feature_key].values)
point_neighbors.groupby(valid_points[feature_key].values)
.sum()
.reset_index()
.melt(id_vars="index")
.query("value > 0")
)
neighbor_counts.columns = [feature_key, "neighbor", "count"]
neighbor_counts[instance_key] = cell_points[instance_key].iloc[0]
clq_df = _clq_statistic(neighbor_counts, gene_counts, feature_key)

return clq_df
@@ -243,7 +267,16 @@ def _clq_statistic(neighbor_counts, counts, feature_key):
Series of raw gene counts.
"""
clq_df = neighbor_counts.copy()
clq_df["clq"] = (clq_df["count"] / counts.loc[clq_df[feature_key]].values) / (
counts.loc[clq_df["neighbor"]].values / counts.sum()
)
a = clq_df["count"]
b = counts.loc[clq_df[feature_key]].values
c = counts.loc[clq_df["neighbor"]].values
d = counts.sum()

clq_df["clq"] = (a / b) / (c / d)

# Calculate two-tailed 95% confidence interval
ci_lower = clq_df["clq"] - 1.96 * np.sqrt((1 / a) + (1 / b) + (1 / c) + (1 / d))
ci_upper = clq_df["clq"] + 1.96 * np.sqrt((1 / a) + (1 / b) + (1 / c) + (1 / d))
clq_df["ci_lower"] = ci_lower
clq_df["ci_upper"] = ci_upper
return clq_df.drop("count", axis=1)
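
Written out in the code's own variable names (a, b, c, d), _clq_statistic computes the colocation quotient and a symmetric 95% interval around it:

    \mathrm{CLQ}_{g \to h} = \frac{a/b}{c/d} = \frac{C_{g \to h}/N_g}{N_h/N},
    \qquad
    \mathrm{CI}_{95\%} = \mathrm{CLQ}_{g \to h} \pm 1.96 \sqrt{\frac{1}{a} + \frac{1}{b} + \frac{1}{c} + \frac{1}{d}}

where a is the count of h neighbors around points of g, b the total count of g, c the total count of h, and d the total number of points in the cell. Note the interval is computed on the CLQ scale; zeros are mapped to NaN before the log2 transform applied upstream.
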