diff --git a/bento/geometry/_ops.py b/bento/geometry/_ops.py
index 1579790..6eec55c 100644
--- a/bento/geometry/_ops.py
+++ b/bento/geometry/_ops.py
@@ -61,7 +61,6 @@ def overlay(
     )
 
 
-
 @singledispatch
 def labels_to_shapes(labels: np.ndarray, attrs: dict, bg_value: int = 0):
     """
@@ -80,7 +79,7 @@ def labels_to_shapes(labels: np.ndarray, attrs: dict, bg_value: int = 0):
     -------
     GeoPandas DataFrame
         GeoPandas DataFrame containing the polygons extracted from the labeled image.
-    
+
     """
     import rasterio as rio
     import shapely.geometry
@@ -91,8 +90,8 @@ def labels_to_shapes(labels: np.ndarray, attrs: dict, bg_value: int = 0):
     shapes = gpd.GeoDataFrame(
         polygons[:, 1], geometry=gpd.GeoSeries(polygons[:, 0]).T, columns=["id"]
     )
-    shapes = shapes[shapes["id"] != bg_value] # Ignore background
-    
+    shapes = shapes[shapes["id"] != bg_value]  # Ignore background
+
     # Validate for SpatialData
     sd_shape = ShapesModel.parse(shapes)
     sd_shape.attrs = attrs
@@ -112,13 +111,13 @@ def _(labels: SpatialImage, attrs: dict, bg_value: int = 0):
         Dictionary of attributes to set for the SpatialData object.
     bg_value : int, optional
         Value of the background pixels, by default 0
-    
+
     Returns
     -------
     GeoPandas DataFrame
         GeoPandas DataFrame containing the polygons extracted from the labeled image.
     """
-    
+
     # Convert spatial_image.SpatialImage to np.ndarray
     labels = labels.values
-    return labels_to_shapes(labels, attrs, bg_value)
\ No newline at end of file
+    return labels_to_shapes(labels, attrs, bg_value)
diff --git a/bento/tools/_composition.py b/bento/tools/_composition.py
index b5e5b91..116573f 100644
--- a/bento/tools/_composition.py
+++ b/bento/tools/_composition.py
@@ -73,7 +73,7 @@ def comp(sdata: SpatialData, points_key: str, shape_names: list):
     sdata : spatialdata.SpatialData
         Updates `sdata.table.uns` with average gene compositions for each shape.
     """
-    points = get_points(sdata,points_key=points_key, astype="pandas")
+    points = get_points(sdata, points_key=points_key, astype="pandas")
 
     instance_key = get_instance_key(sdata)
     feature_key = get_feature_key(sdata)
@@ -86,7 +86,9 @@ def comp(sdata: SpatialData, points_key: str, shape_names: list):
     sdata.table.uns["comp_stats"] = comp_stats
 
 
-def comp_diff(sdata: SpatialData, points_key: str, shape_names: list, groupby: str, ref_group: str):
+def comp_diff(
+    sdata: SpatialData, points_key: str, shape_names: list, groupby: str, ref_group: str
+):
     """Calculate the average difference in gene composition for shapes across batches of cells. Uses the Wasserstein distance.
 
     Parameters
@@ -109,7 +111,9 @@ def comp_diff(sdata: SpatialData, points_key: str, shape_names: list, groupby: s
     # Get average gene compositions for each batch
     comp_stats = dict()
     for group, pt_group in points.groupby(groupby):
-        comp_stats[group] = _get_compositions(pt_group, shape_names, instance_key=instance_key, feature_key=feature_key)
+        comp_stats[group] = _get_compositions(
+            pt_group, shape_names, instance_key=instance_key, feature_key=feature_key
+        )
 
     ref_comp = comp_stats[ref_group]
 
diff --git a/bento/tools/_point_features.py b/bento/tools/_point_features.py
index 99ab170..c485653 100644
--- a/bento/tools/_point_features.py
+++ b/bento/tools/_point_features.py
@@ -133,12 +133,12 @@ def analyze_points(
     for shape in list(
         set(obs_attrs).intersection(set([x for x in shape_keys if x != instance_key]))
     ):
-        points_df = (
-            points_df.join(
-                sdata.shapes[shape].set_index(instance_key), on=instance_key, lsuffix="", rsuffix=f"_{shape}"
-            )
-            .rename(columns={shape: f"{shape}_index", f"geometry_{shape}": shape})
-        )
+        points_df = points_df.join(
+            sdata.shapes[shape].set_index(instance_key),
+            on=instance_key,
+            lsuffix="",
+            rsuffix=f"_{shape}",
+        ).rename(columns={shape: f"{shape}_index", f"geometry_{shape}": shape})
 
     # Pull cell_boundaries shape features into the points dataframe
     points_df = (
diff --git a/pyproject.toml b/pyproject.toml
index 3b7f14a..5b854f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,16 +27,10 @@ dependencies = [
     "xgboost>=2.0.3",
     "statsmodels>=0.14.1",
     "scikit-learn>=1.4.2",
-    "pytest>=8.2.1",
 ]
 license = "BSD-2-Clause"
 readme = "README.md"
 requires-python = ">= 3.9"
-include = [
-    "bento/datasets/datasets.csv",
-    "bento/models/**/*",
-    "bento/tools/gene_sets/*",
-]
 
 [project.optional-dependencies]
 docs = [
@@ -61,7 +55,19 @@ build-backend = "hatchling.build"
 
 [tool.rye]
 managed = true
-dev-dependencies = []
+dev-dependencies = [
+    "pytest>=8.2.2",
+    "pytest-cov>=5.0.0",
+    "pytest-watcher>=0.4.2",
+    "pytest-mock>=3.14.0",
+]
+
+[tool.hatch.build]
+include = [
+    "bento/datasets/datasets.csv",
+    "bento/models/**/*",
+    "bento/tools/gene_sets/*",
+]
 
 [tool.hatch.metadata]
 allow-direct-references = true
diff --git a/tests/test_composition.py b/tests/test_composition.py
new file mode 100644
index 0000000..cf50ff1
--- /dev/null
+++ b/tests/test_composition.py
@@ -0,0 +1,22 @@
+import bento as bt
+import pandas as pd
+
+
+def test_comp(small_data):
+    # Set up test data
+    shape_names = ["cell_boundaries", "nucleus_boundaries"]
+    points_key = "transcripts"
+
+    # Call the comp function
+    bt.tl.comp(sdata=small_data, points_key=points_key, shape_names=shape_names)
+
+    # Check if comp_stats is updated in sdata.table.uns
+    assert "comp_stats" in small_data.table.uns
+
+    # Check the type of comp_stats
+    assert type(small_data.table.uns["comp_stats"]) == pd.DataFrame
+
+    # Check if the shape_names are present in comp_stats
+    assert all(
+        shape_name in small_data.table.uns["comp_stats"] for shape_name in shape_names
+    )
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
new file mode 100644
index 0000000..2de3813
--- /dev/null
+++ b/tests/test_datasets.py
@@ -0,0 +1,31 @@
+import bento as bt
+import pandas as pd
+import geopandas as gpd
+import spatialdata as sd
+import dask.dataframe as dd
+
+
+def test_sample_data():
+    sdata = bt.ds.sample_data()
+
+    # Check if the returned object is an instance of `bento.SpatialData`
+    assert isinstance(sdata, sd.SpatialData)
+
+    # Check if the required keys are present in the `sdata` object
+    assert "transcripts" in sdata.points
+    assert "cell_boundaries" in sdata.shapes
+    assert "nucleus_boundaries" in sdata.shapes
+
+    # Check if the types of the keys are correct
+    assert isinstance(sdata.points["transcripts"], dd.DataFrame)
+    assert isinstance(sdata.shapes["cell_boundaries"], gpd.GeoDataFrame)
+    assert isinstance(sdata.shapes["nucleus_boundaries"], gpd.GeoDataFrame)
+
+    # Check if the `feature_name` column is present in the `transcripts` DataFrame
+    assert "feature_name" in sdata.points["transcripts"]
+
+    # Check if the `cell_boundaries` and `nucleus_boundaries` shapes are present
+    assert "cell_boundaries" in sdata.shapes
+    assert "nucleus_boundaries" in sdata.shapes
+    assert isinstance(sdata.shapes["cell_boundaries"], gpd.GeoDataFrame)
+    assert isinstance(sdata.shapes["nucleus_boundaries"], gpd.GeoDataFrame)
diff --git a/tests/test_geometry.py b/tests/test_geometry.py
new file mode 100644
index 0000000..8f3eaf9
--- /dev/null
+++ b/tests/test_geometry.py
@@ -0,0 +1,24 @@
+import bento as bt
+import geopandas as gpd
+
+
+def test_overlay_intersection(small_data):
+    s1 = "nucleus_boundaries"
+    s2 = "cell_boundaries"
+    name = "overlay_result"
+
+    # Perform overlay operation using GeoDataFrame.overlay()
+    shape1 = small_data.shapes[s1]
+    shape2 = small_data.shapes[s2]
+    expected_result = shape1.overlay(shape2, how="intersection", make_valid=True)
+
+    # Perform overlay operation using bento.geo.overlay()
+    bt.geo.overlay(small_data, s1, s2, name, how="intersection")
+
+    assert name in small_data.shapes
+    assert isinstance(small_data.shapes[name], gpd.GeoDataFrame)
+    assert (
+        small_data[name]
+        .geom_almost_equals(expected_result, decimal=1, align=True)
+        .all()
+    )
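Reviewer note (not part of the patch): the sketch below strings together the same public calls the new tests exercise, as a quick way to try them outside pytest. It assumes bento is importable as bt and that bt.ds.sample_data() (used in tests/test_datasets.py) provides the "transcripts" points and "cell_boundaries"/"nucleus_boundaries" shapes that bt.tl.comp expects; the small_data fixture used by the tests is presumably defined in a shared conftest.py that is not part of this diff.

    # Illustrative only -- mirrors tests/test_composition.py and tests/test_datasets.py.
    import bento as bt

    # Bundled demo dataset, same loader exercised by test_sample_data().
    sdata = bt.ds.sample_data()

    # Average gene composition per shape; results are written to sdata.table.uns.
    bt.tl.comp(
        sdata=sdata,
        points_key="transcripts",
        shape_names=["cell_boundaries", "nucleus_boundaries"],
    )

    # Per the test assertions, comp_stats is a pandas DataFrame that contains
    # each requested shape name.
    comp_stats = sdata.table.uns["comp_stats"]
    print(comp_stats.head())

Since pytest moves from the runtime dependencies to the [tool.rye] dev-dependencies, the suite is no longer pulled in by a plain install; syncing the dev environment (e.g. rye sync) is needed before running pytest.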