diff --git a/dynamo/plot/__init__.py b/dynamo/plot/__init__.py index c2f28c70c..585de3bc0 100755 --- a/dynamo/plot/__init__.py +++ b/dynamo/plot/__init__.py @@ -3,7 +3,7 @@ from .cell_cycle import cell_cycle_scores from .clustering import infomap, leiden, louvain, streamline_clusters -from .connectivity import nneighbors +from .connectivity import nneighbors, plot_connectivity from .dimension_reduction import pca, trimap, tsne, umap from .dynamics import dynamics, phase_portraits from .ezplots import ( @@ -105,6 +105,7 @@ "umap", "trimap", "nneighbors", + "plot_connectivity", "cell_wise_vectors", "cell_wise_vectors_3d", "grid_vectors", diff --git a/dynamo/plot/connectivity.py b/dynamo/plot/connectivity.py index 242864340..ca7b59025 100755 --- a/dynamo/plot/connectivity.py +++ b/dynamo/plot/connectivity.py @@ -570,9 +570,9 @@ def plot_connectivity( graph: Union[csr_matrix, csc_matrix, np.ndarray], x: int = 0, y: int = 1, - color: List[str] = ["ntr"], - basis: List[str] = ["umap"], - layer: List[str] = ["X"], + color: Union[str, List[str]] = ["ntr"], + basis: Union[str, List[str]] = ["umap"], + layer: Union[str, List[str]] = ["X"], highlights: Optional[list] = None, ncols: int = 1, edge_bundling: Optional[Literal["hammer"]] = None, @@ -696,6 +696,10 @@ def plot_connectivity( if type(x) is not int or type(y) is not int: raise TypeError("x, y have to be integers (components in the a particular embedding {}) ".format(basis)) + basis = [basis] if isinstance(basis, str) else basis + color = [color] if isinstance(color, str) else color + layer = [layer] if isinstance(layer, str) else layer + n_c, n_l, n_b = ( 0 if color is None else len(color), 0 if layer is None else len(layer),