Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ckmah committed Jun 6, 2024
1 parent 1e82c86 commit 786e414
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 23 deletions.
13 changes: 6 additions & 7 deletions bento/geometry/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def overlay(
)



@singledispatch
def labels_to_shapes(labels: np.ndarray, attrs: dict, bg_value: int = 0):
"""
Expand All @@ -80,7 +79,7 @@ def labels_to_shapes(labels: np.ndarray, attrs: dict, bg_value: int = 0):
-------
GeoPandas DataFrame
GeoPandas DataFrame containing the polygons extracted from the labeled image.
"""
import rasterio as rio
import shapely.geometry
Expand All @@ -91,8 +90,8 @@ def labels_to_shapes(labels: np.ndarray, attrs: dict, bg_value: int = 0):
shapes = gpd.GeoDataFrame(
polygons[:, 1], geometry=gpd.GeoSeries(polygons[:, 0]).T, columns=["id"]
)
shapes = shapes[shapes["id"] != bg_value] # Ignore background
shapes = shapes[shapes["id"] != bg_value] # Ignore background

# Validate for SpatialData
sd_shape = ShapesModel.parse(shapes)
sd_shape.attrs = attrs
Expand All @@ -112,13 +111,13 @@ def _(labels: SpatialImage, attrs: dict, bg_value: int = 0):
Dictionary of attributes to set for the SpatialData object.
bg_value : int, optional
Value of the background pixels, by default 0
Returns
-------
GeoPandas DataFrame
GeoPandas DataFrame containing the polygons extracted from the labeled image.
"""

# Convert spatial_image.SpatialImage to np.ndarray
labels = labels.values
return labels_to_shapes(labels, attrs, bg_value)
return labels_to_shapes(labels, attrs, bg_value)
10 changes: 7 additions & 3 deletions bento/tools/_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def comp(sdata: SpatialData, points_key: str, shape_names: list):
sdata : spatialdata.SpatialData
Updates `sdata.table.uns` with average gene compositions for each shape.
"""
points = get_points(sdata,points_key=points_key, astype="pandas")
points = get_points(sdata, points_key=points_key, astype="pandas")

instance_key = get_instance_key(sdata)
feature_key = get_feature_key(sdata)
Expand All @@ -86,7 +86,9 @@ def comp(sdata: SpatialData, points_key: str, shape_names: list):
sdata.table.uns["comp_stats"] = comp_stats


def comp_diff(sdata: SpatialData, points_key: str, shape_names: list, groupby: str, ref_group: str):
def comp_diff(
sdata: SpatialData, points_key: str, shape_names: list, groupby: str, ref_group: str
):
"""Calculate the average difference in gene composition for shapes across batches of cells. Uses the Wasserstein distance.
Parameters
Expand All @@ -109,7 +111,9 @@ def comp_diff(sdata: SpatialData, points_key: str, shape_names: list, groupby: s
# Get average gene compositions for each batch
comp_stats = dict()
for group, pt_group in points.groupby(groupby):
comp_stats[group] = _get_compositions(pt_group, shape_names, instance_key=instance_key, feature_key=feature_key)
comp_stats[group] = _get_compositions(
pt_group, shape_names, instance_key=instance_key, feature_key=feature_key
)

ref_comp = comp_stats[ref_group]

Expand Down
12 changes: 6 additions & 6 deletions bento/tools/_point_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ def analyze_points(
for shape in list(
set(obs_attrs).intersection(set([x for x in shape_keys if x != instance_key]))
):
points_df = (
points_df.join(
sdata.shapes[shape].set_index(instance_key), on=instance_key, lsuffix="", rsuffix=f"_{shape}"
)
.rename(columns={shape: f"{shape}_index", f"geometry_{shape}": shape})
)
points_df = points_df.join(
sdata.shapes[shape].set_index(instance_key),
on=instance_key,
lsuffix="",
rsuffix=f"_{shape}",
).rename(columns={shape: f"{shape}_index", f"geometry_{shape}": shape})

# Pull cell_boundaries shape features into the points dataframe
points_df = (
Expand Down
20 changes: 13 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,10 @@ dependencies = [
"xgboost>=2.0.3",
"statsmodels>=0.14.1",
"scikit-learn>=1.4.2",
"pytest>=8.2.1",
]
license = "BSD-2-Clause"
readme = "README.md"
requires-python = ">= 3.9"
include = [
"bento/datasets/datasets.csv",
"bento/models/**/*",
"bento/tools/gene_sets/*",
]

[project.optional-dependencies]
docs = [
Expand All @@ -61,7 +55,19 @@ build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = []
dev-dependencies = [
"pytest>=8.2.2",
"pytest-cov>=5.0.0",
"pytest-watcher>=0.4.2",
"pytest-mock>=3.14.0",
]

[tool.hatch.build]
include = [
"bento/datasets/datasets.csv",
"bento/models/**/*",
"bento/tools/gene_sets/*",
]

[tool.hatch.metadata]
allow-direct-references = true
Expand Down
22 changes: 22 additions & 0 deletions tests/test_composition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import bento as bt
import pandas as pd


def test_comp(small_data):
# Set up test data
shape_names = ["cell_boundaries", "nucleus_boundaries"]
points_key = "transcripts"

# Call the comp function
bt.tl.comp(sdata=small_data, points_key=points_key, shape_names=shape_names)

# Check if comp_stats is updated in sdata.table.uns
assert "comp_stats" in small_data.table.uns

# Check the type of comp_stats
assert type(small_data.table.uns["comp_stats"]) == pd.DataFrame

# Check if the shape_names are present in comp_stats
assert all(
shape_name in small_data.table.uns["comp_stats"] for shape_name in shape_names
)
31 changes: 31 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import bento as bt
import pandas as pd
import geopandas as gpd
import spatialdata as sd
import dask.dataframe as dd


def test_sample_data():
sdata = bt.ds.sample_data()

# Check if the returned object is an instance of `bento.SpatialData`
assert isinstance(sdata, sd.SpatialData)

# Check if the required keys are present in the `sdata` object
assert "transcripts" in sdata.points
assert "cell_boundaries" in sdata.shapes
assert "nucleus_boundaries" in sdata.shapes

# Check if the types of the keys are correct
assert isinstance(sdata.points["transcripts"], dd.DataFrame)
assert isinstance(sdata.shapes["cell_boundaries"], gpd.GeoDataFrame)
assert isinstance(sdata.shapes["nucleus_boundaries"], gpd.GeoDataFrame)

# Check if the `feature_name` column is present in the `transcripts` DataFrame
assert "feature_name" in sdata.points["transcripts"]

# Check if the `cell_boundaries` and `nucleus_boundaries` shapes are present
assert "cell_boundaries" in sdata.shapes
assert "nucleus_boundaries" in sdata.shapes
assert isinstance(sdata.shapes["cell_boundaries"], gpd.GeoDataFrame)
assert isinstance(sdata.shapes["nucleus_boundaries"], gpd.GeoDataFrame)
24 changes: 24 additions & 0 deletions tests/test_geometry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import bento as bt
import geopandas as gpd


def test_overlay_intersection(small_data):
s1 = "nucleus_boundaries"
s2 = "cell_boundaries"
name = "overlay_result"

# Perform overlay operation using GeoDataFrame.overlay()
shape1 = small_data.shapes[s1]
shape2 = small_data.shapes[s2]
expected_result = shape1.overlay(shape2, how="intersection", make_valid=True)

# Perform overlay operation using bento.geo.overlay()
bt.geo.overlay(small_data, s1, s2, name, how="intersection")

assert name in small_data.shapes
assert isinstance(small_data.shapes[name], gpd.GeoDataFrame)
assert (
small_data[name]
.geom_almost_equals(expected_result, decimal=1, align=True)
.all()
)

0 comments on commit 786e414

Please sign in to comment.