diff --git a/dynamo/plot/__init__.py b/dynamo/plot/__init__.py index c2f28c70c..585de3bc0 100755 --- a/dynamo/plot/__init__.py +++ b/dynamo/plot/__init__.py @@ -3,7 +3,7 @@ from .cell_cycle import cell_cycle_scores from .clustering import infomap, leiden, louvain, streamline_clusters -from .connectivity import nneighbors +from .connectivity import nneighbors, plot_connectivity from .dimension_reduction import pca, trimap, tsne, umap from .dynamics import dynamics, phase_portraits from .ezplots import ( @@ -105,6 +105,7 @@ "umap", "trimap", "nneighbors", + "plot_connectivity", "cell_wise_vectors", "cell_wise_vectors_3d", "grid_vectors", diff --git a/dynamo/plot/connectivity.py b/dynamo/plot/connectivity.py index 021f70f34..ca7b59025 100755 --- a/dynamo/plot/connectivity.py +++ b/dynamo/plot/connectivity.py @@ -28,6 +28,7 @@ from anndata import AnnData from matplotlib.axes import Axes from matplotlib.figure import Figure +from scipy.sparse import csc_matrix, csr_matrix, issparse from ..configuration import _themes from ..docrep import DocstringProcessor @@ -564,6 +565,263 @@ def nneighbors( +def plot_connectivity( + adata: AnnData, + graph: Union[csr_matrix, csc_matrix, np.ndarray], + x: int = 0, + y: int = 1, + color: Union[str, List[str]] = ["ntr"], + basis: Union[str, List[str]] = ["umap"], + layer: Union[str, List[str]] = ["X"], + highlights: Optional[list] = None, + ncols: int = 1, + edge_bundling: Optional[Literal["hammer"]] = None, + edge_cmap: str = "gray_r", + show_points: bool = True, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[ + Literal[ + "blue", + "red", + "green", + "inferno", + "fire", + "viridis", + "darkblue", + "darkgreen", + "darkred", + ] + ] = None, + cmap: str = "Blues", + color_key: Union[dict, list, None] = None, + color_key_cmap: str = "Spectral", + background: str = "black", + figsize: tuple = (6, 4), + ax: Optional[Axes] = None, + save_show_or_return: Literal["save", "show", "return"] = "return", + save_kwargs: dict = {}, +) -> Optional[Figure]: + """Plot the connectivity graph. + + A connectivity graph can be one of the followings: + 1. nneighbors: kNN graph constructed from umap/scKDTree/annoy, etc. + 2. mutual kNN shared between spliced or unspliced layer + 3. principal graph that learnt from DDRTree, L1graph or other principal graph algorithms + 4. regulatory network learnt from Scribe + 5. spatial kNN graph + 6. others + + Args: + adata: an Annodata object that include the umap embedding and simplicial graph. + graph: the matrix representing the connectivity relationship. For example `adata.obsp["connectivities"]` or + `adata.uns["neighbors"]["connectivities"]`. Notice that the matrix should have the same size as the data. + x: the first component of the embedding. Defaults to 0. + y: the second component of the embedding. Defaults to 1. + color: gene name(s) or cell annotation column(s) used for coloring the graph. Defaults to ["ntr"]. + basis: the low dimensional embedding to be used to visualize the cell. Defaults to ["umap"]. + layer: the layers of data representing the gene expression level. Defaults to ["X"]. + highlights: the list that cells will be restricted to. Defaults to None. + ncols: the number of columns to be plotted. Defaults to 1. + edge_bundling: the edge bundling method to use. Currently supported are None or 'hammer'. See the datashader + docs on graph visualization for more details. Defaults to None. + edge_cmap: the name of a matplotlib colormap to use for shading/coloring the edges of the connectivity graph. + Note that the `theme`, if specified, will override this. Defaults to "gray_r". + show_points: whether to display the points over top of the edge connectivity. Further options allow for + coloring/shading the points accordingly. Defaults to True. + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: a color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Available themes are: + * 'blue' + * 'red' + * 'green' + * 'inferno' + * 'fire' + * 'viridis' + * 'darkblue' + * 'darkred' + * 'darkgreen'. + Defaults to None. + cmap: the name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to "Blues". + color_key: a way to assign colors to categoricals. This can either be an explicit dict mapping labels to colors + (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct category being + provided in `labels`. Either way this mapping will be used to color points according to the label. Note that + if theme is passed then this value will be overridden by the corresponding option of the theme. Defaults to + None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to "Spectral". + background: the color of the background. Usually this will be either 'white' or 'black', but any color name will + work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of + the things that themes handle for you. Note that if theme is passed then this value will be overridden by + the corresponding option of the theme. Defaults to "black". + figsize: the desired size of the figure. Defaults to (6, 4). + ax: the axis on which the subplot would be shown. If set to be `None`, a new axis would be created. Defaults to + None. + save_show_or_return: whether to save, show or return the figure. Defaults to "return". + save_kwargs: a dictionary that will be passed to the save_show_ret function. By default, it is an empty dictionary + and the save_show_ret function will use the + { + "path": None, + "prefix": 'connectivity_base', + "dpi": None, + "ext": 'pdf', + "transparent": True, + "close": True, + "verbose": True + } + as its parameters. Otherwise, you can provide a dictionary that properly modify those keys according to your + needs. Defaults to {}. + + Raises: + TypeError: wrong type of `x` and `y`. + + Returns: + The matplotlib axis with the plotted connectivity graph by default. If `save_show_or_return` is set to be + `"show"` or `"save"`, nothing would be returned. + """ + + import matplotlib.pyplot as plt + import seaborn as sns + + if type(x) is not int or type(y) is not int: + raise TypeError("x, y have to be integers (components in the a particular embedding {}) ".format(basis)) + + basis = [basis] if isinstance(basis, str) else basis + color = [color] if isinstance(color, str) else color + layer = [layer] if isinstance(layer, str) else layer + + n_c, n_l, n_b = ( + 0 if color is None else len(color), + 0 if layer is None else len(layer), + 0 if basis is None else len(basis), + ) + + check_and_recompute_neighbors(adata, result_prefix="") + coo_graph = graph.tocoo() if issparse(graph) else csr_matrix(graph).tocoo() + + edge_df = pd.DataFrame( + np.vstack([coo_graph.row, coo_graph.col, coo_graph.data]).T, + columns=("source", "target", "weight"), + ) + edge_df["source"] = edge_df.source.astype(np.int32) + edge_df["target"] = edge_df.target.astype(np.int32) + + total_panels, ncols = n_c * n_l * n_b, min(n_c, ncols) + nrow, ncol = int(np.ceil(total_panels / ncols)), ncols + if figsize is None: + figsize = plt.rcParams["figsize"] + + font_color = _select_font_color(background) + if background == "black": + # https://github.com/matplotlib/matplotlib/blob/master/lib/matplotlib/mpl-data/stylelib/dark_background.mplstyle + sns.set( + rc={ + "axes.facecolor": background, + "axes.edgecolor": background, + "figure.facecolor": background, + "figure.edgecolor": background, + "axes.grid": False, + "ytick.color": font_color, + "xtick.color": font_color, + "axes.labelcolor": font_color, + "axes.edgecolor": font_color, + "savefig.facecolor": "k", + "savefig.edgecolor": "k", + "grid.color": font_color, + "text.color": font_color, + "lines.color": font_color, + "patch.edgecolor": font_color, + "figure.edgecolor": font_color, + } + ) + else: + sns.set( + rc={ + "axes.facecolor": background, + "figure.facecolor": background, + "axes.grid": False, + } + ) + + if total_panels > 1: + g = plt.figure(None, (figsize[0] * ncol, figsize[1] * nrow), facecolor=background) + gs = plt.GridSpec(nrow, ncol, wspace=0.12) + + i = 0 + for cur_b in basis: + for cur_l in layer: + prefix = cur_l + "_" + if prefix + cur_b in adata.obsm.keys(): + x_, y_ = ( + adata.obsm[prefix + cur_b][:, int(x)], + adata.obsm[prefix + cur_b][:, int(y)], + ) + else: + continue + for cur_c in color: + _color = adata.obs_vector(cur_c, layer=cur_l) + is_not_continous = _color.dtype.name == "category" + if is_not_continous: + labels = _color + if theme is None: + theme = "glasbey_dark" + else: + values = _color + if theme is None: + theme = "inferno" if cur_l != "velocity" else "div_blue_red" + + if total_panels > 1: + ax = plt.subplot(gs[i]) + i += 1 + + # if highligts is a list of lists - each list is relate to each color element + if highlights is not None: + if is_list_of_lists(highlights): + _highlights = highlights[color.index(cur_c)] + _highlights = _highlights if all([i in _color for i in _highlights]) else None + else: + _highlights = highlights if all([i in _color for i in highlights]) else None + else: + _highlights = None + + ax = connectivity_base( + x_, + y_, + edge_df, + _highlights, + edge_bundling, + edge_cmap, + show_points, + labels, + values, + theme, + cmap, + color_key, + color_key_cmap, + background, + figsize, + ax, + ) + + ax.set_xlabel( + cur_b + "_1", + ) + ax.set_ylabel(cur_b + "_2") + ax.set_title(cur_c) + + return save_show_ret("nneighbors", save_show_or_return, save_kwargs, plt.gcf()) + + def pgraph(): """Plot principal graph of cells that learnt from graph embedding algorithms.