Skip to content

Commit

Permalink
Merge pull request #119 from ckmah/plotting
Browse files Browse the repository at this point in the history
Plotting and metadata setters
  • Loading branch information
ckmah authored Apr 15, 2024
2 parents 34f7dce + af9afd0 commit ba81649
Show file tree
Hide file tree
Showing 22 changed files with 869 additions and 572 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ dmypy.json

# Pyre type checker
.pyre/

tests/img
docs/build.zip
tests/data/processed/data.pt
tests/data/processed/pre_filter.pt
Expand Down
49 changes: 49 additions & 0 deletions bento/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from spatialdata._core.spatialdata import SpatialData
from spatialdata.models import TableModel
from dask import dataframe as dd

from .geometry._geometry import get_points

def filter_by_gene(
    sdata: SpatialData,
    threshold: int = 10,
    points_key: str = "transcripts",
    feature_key: str = "feature_name",
):
    """
    Filters out genes with low expression from the spatial data object.

    Parameters
    ----------
    sdata : SpatialData
        Spatial formatted SpatialData object.
    threshold : int
        Minimum number of counts for a gene to be considered expressed.
        Keep genes where at least {threshold} molecules are detected in at least one cell.
    points_key : str
        Key for the points element that holds transcript coordinates.
    feature_key : str
        Key for gene instances. The column must be categorical, since unused
        categories are dropped from it after filtering.

    Returns
    -------
    sdata : SpatialData
        .points[points_key] is updated to remove genes with low expression.
        .table is updated to remove genes with low expression.
        The same object that was passed in is modified and returned.
    """
    # A gene (column of X) passes if at least one row reaches `threshold` counts.
    gene_filter = (sdata.table.X >= threshold).sum(axis=0) > 0
    filtered_table = sdata.table[:, gene_filter]

    # Genes present in the original table but absent after filtering.
    filtered_genes = list(sdata.table.var_names.difference(filtered_table.var_names))
    points = get_points(sdata, points_key=points_key, astype="pandas", sync=False)
    points = points[~points[feature_key].isin(filtered_genes)]
    points[feature_key] = points[feature_key].cat.remove_unused_categories()

    # Carry over the attrs of the element being replaced. Read them from
    # `points_key` (not a hard-coded "transcripts") so a custom points element
    # keeps its own attrs and a missing "transcripts" element cannot raise.
    points_attrs = sdata.points[points_key].attrs
    sdata.points[points_key] = dd.from_pandas(points, npartitions=1)
    sdata.points[points_key].attrs = points_attrs

    # Replace the table: delete the existing one first, then assign the
    # re-parsed filtered table.
    del sdata.table
    sdata.table = TableModel.parse(filtered_table)

    return sdata
107 changes: 59 additions & 48 deletions bento/geometry/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,12 @@ def sjoin_points(
points.loc[points["index_right"].isna(), "index_right"] = ""
points.rename(columns={"index_right": shape_key}, inplace=True)

set_points_metadata(sdata, points_key, points[shape_key])
set_points_metadata(sdata, points_key, points[shape_key], columns=shape_key)

return sdata


def sjoin_shapes(
sdata: SpatialData,
instance_key: str,
shape_keys: List[str]
):
def sjoin_shapes(sdata: SpatialData, instance_key: str, shape_keys: List[str]):
"""Adds polygon indexes to sdata.shapes[instance_key][shape_key] for point feature analysis
Parameters
Expand Down Expand Up @@ -92,19 +88,29 @@ def sjoin_shapes(
if len(shape_keys) == 0:
return sdata

parent_shape = sdata.shapes[instance_key]
parent_shape = gpd.GeoDataFrame(sdata.shapes[instance_key])

# sjoin shapes to instance_key shape
for shape_key in shape_keys:
child_shape = gpd.GeoDataFrame(geometry=sdata.shapes[shape_key]["geometry"])
child_shape = sdata.shapes[shape_key]["geometry"]

# Hack for polygons that are 99% contained in parent shape
child_shape = gpd.GeoDataFrame(
geometry=child_shape.buffer(
child_shape.minimum_bounding_radius().mean() * -0.05
)
)

parent_shape = parent_shape.sjoin(child_shape, how="left", predicate="contains")
parent_shape = parent_shape[~parent_shape.index.duplicated(keep="last")]
parent_shape.loc[parent_shape["index_right"].isna(), "index_right"] = ""
parent_shape = parent_shape.astype({"index_right": "category"})

# save shape index as column in instance_key shape
parent_shape.rename(columns={"index_right": shape_key}, inplace=True)
set_shape_metadata(sdata, shape_key=instance_key, metadata=parent_shape[shape_key])
set_shape_metadata(
sdata, shape_key=instance_key, metadata=parent_shape[shape_key]
)

# Add instance_key shape index to shape
parent_shape.index.name = "parent_index"
Expand All @@ -117,6 +123,7 @@ def sjoin_shapes(

return sdata


def get_points(
sdata: SpatialData,
points_key: str = "transcripts",
Expand Down Expand Up @@ -167,7 +174,8 @@ def get_points(
return gpd.GeoDataFrame(
points, geometry=gpd.points_from_xy(points.x, points.y), copy=True
)



def get_shape(sdata: SpatialData, shape_key: str, sync: bool = True) -> gpd.GeoSeries:
"""Get a GeoSeries of Polygon objects from an SpatialData object.
Expand Down Expand Up @@ -199,10 +207,11 @@ def get_shape(sdata: SpatialData, shape_key: str, sync: bool = True) -> gpd.GeoS

return sdata.shapes[shape_key].geometry


def get_points_metadata(
sdata: SpatialData,
metadata_keys: Union[str, List[str]],
points_key: str = "transcripts",
points_key: str,
astype="pandas",
):
"""Get points metadata.
Expand All @@ -226,26 +235,27 @@ def get_points_metadata(
if points_key not in sdata.points.keys():
raise ValueError(f"Points key {points_key} not found in sdata.points")
if astype not in ["pandas", "dask"]:
raise ValueError(
f"astype must be one of ['dask', 'pandas'], not {astype}"
)
raise ValueError(f"astype must be one of ['dask', 'pandas'], not {astype}")
if isinstance(metadata_keys, str):
metadata_keys = [metadata_keys]
for key in metadata_keys:
if key not in sdata.points[points_key].columns:
raise ValueError(f"Metadata key {key} not found in sdata.points[{points_key}]")
raise ValueError(
f"Metadata key {key} not found in sdata.points[{points_key}]"
)

metadata = sdata.points[points_key][metadata_keys]

if astype == "pandas":
return metadata.compute()
elif astype == "dask":
return metadata



def get_shape_metadata(
sdata: SpatialData,
metadata_keys: Union[str, List[str]],
shape_key: str = "transcripts",
shape_key: str,
):
"""Get shape metadata.
Expand All @@ -269,54 +279,52 @@ def get_shape_metadata(
metadata_keys = [metadata_keys]
for key in metadata_keys:
if key not in sdata.shapes[shape_key].columns:
raise ValueError(f"Metadata key {key} not found in sdata.shapes[{shape_key}]")
raise ValueError(
f"Metadata key {key} not found in sdata.shapes[{shape_key}]"
)

return sdata.shapes[shape_key][metadata_keys]


def set_points_metadata(
sdata: SpatialData,
points_key: str,
metadata: Union[List, pd.Series, pd.DataFrame],
column_names: Optional[Union[str, List[str]]] = None,
metadata: Union[List, pd.Series, pd.DataFrame, np.ndarray],
columns: Union[str, List[str]],
):
"""Write metadata in SpatialData points element as column(s). Aligns metadata index to shape index.
"""Write metadata in SpatialData points element as column(s). Aligns metadata index to shape index if present.
Parameters
----------
sdata : SpatialData
Spatial formatted SpatialData object
points_key : str
Name of element in sdata.points
metadata : pd.Series, pd.DataFrame
Metadata to set for points. Index must be a (sub)set of points index.
metadata : pd.Series, pd.DataFrame, np.ndarray
Metadata to set for points. Assumes input is already aligned to points index.
column_names : str or list of str, optional
Name of column(s) to set. If None, use metadata column name(s), by default None
"""
if points_key not in sdata.points.keys():
raise ValueError(f"{points_key} not found in sdata.points")

if isinstance(metadata, list):
metadata = pd.Series(metadata, index=sdata.points[points_key].index)

if isinstance(metadata, pd.Series):
metadata = pd.DataFrame(metadata)
columns = [columns] if isinstance(columns, str) else columns

metadata = pd.DataFrame(np.array(metadata), columns=columns)

if column_names is not None:
if isinstance(column_names, str):
column_names = [column_names]
for i in range(len(column_names)):
metadata = metadata.rename(columns={metadata.columns[i]: column_names[i]})

sdata.points[points_key] = sdata.points[points_key].reset_index(drop=True)
for name, series in metadata.iteritems():
series = series.fillna("")
metadata_series = dd.from_pandas(series, npartitions=sdata.points[points_key].npartitions).reset_index(drop=True)
sdata.points[points_key][name] = metadata_series
for name, series in metadata.items():
series = series.fillna("") if series.dtype == object else series
series = dd.from_pandas(
series, npartitions=sdata.points[points_key].npartitions
).reset_index(drop=True)
sdata.points[points_key] = sdata.points[points_key].assign(**{name: series})


def set_shape_metadata(
sdata: SpatialData,
shape_key: str,
metadata: Union[List, pd.Series, pd.DataFrame],
metadata: Union[List, pd.Series, pd.DataFrame, np.ndarray],
column_names: Optional[Union[str, List[str]]] = None,
):
"""Write metadata in SpatialData shapes element as column(s). Aligns metadata index to shape index.
Expand All @@ -334,23 +342,25 @@ def set_shape_metadata(
"""
if shape_key not in sdata.shapes.keys():
raise ValueError(f"Shape {shape_key} not found in sdata.shapes")


shape_index = sdata.shapes[shape_key].index

if isinstance(metadata, list):
metadata = pd.Series(metadata, index=sdata.shapes[shape_key].index)
metadata = pd.Series(metadata, index=shape_index)

if isinstance(metadata, pd.Series):
if isinstance(metadata, pd.Series) or isinstance(metadata, np.ndarray):
metadata = pd.DataFrame(metadata)

if column_names is not None:
if isinstance(column_names, str):
column_names = [column_names]
for i in range(len(column_names)):
metadata = metadata.rename(columns={metadata.columns[i]: column_names[i]})
metadata.columns = (
[column_names] if isinstance(column_names, str) else column_names
)

sdata.shapes[shape_key].loc[:, metadata.columns] = metadata.reindex(
sdata.shapes[shape_key].index
shape_index
).fillna("")


def _check_points_sync(sdata, points_key):
"""
Check if points are synced to instance_key shape in a SpatialData object.
Expand All @@ -373,6 +383,7 @@ def _check_points_sync(sdata, points_key):
f"Points {points_key} not synced to instance_key shape element. Run bento.io.format_sdata() to setup SpatialData object for bento-tools."
)


def _check_shape_sync(sdata, shape_key, instance_key):
"""
Check if a shape is synced to instance_key shape in a SpatialData object.
Expand Down
2 changes: 1 addition & 1 deletion bento/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._multidimensional import flux_summary, obs_stats

from ._lp import lp_diff_discrete, lp_dist, lp_gene_dist, lp_genes
from ._lp import lp_diff_discrete, lp_dist, lp_genes
from ._plotting import points, density, shapes, flux, fluxmap, fe
from ._signatures import colocation, factor
Loading

0 comments on commit ba81649

Please sign in to comment.