diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index b2c7532d0..fd2902908 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -18,6 +18,7 @@ from matplotlib.lines import Line2D from matplotlib.colors import rgb2hex, to_hex from pandas.api.types import is_categorical_dtype +from scipy.sparse import issparse from ..configuration import _themes from ..docrep import DocstringProcessor @@ -48,7 +49,7 @@ @docstrings.get_sectionsf("scatters") -def scatters( +def scatters_legacy( adata: AnnData, basis: str = "umap", x: int = 0, @@ -985,7 +986,7 @@ def _map_cur_axis(cur: str) -> Tuple[np.ndarray, str]: return points, cur_title -def scatters_interactive( +def scatters_interactive_legacy( adata: AnnData, basis: str = "umap", x: Union[int, str] = 0, @@ -1330,3 +1331,1510 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: save_show_or_return=save_show_or_return, save_kwargs=save_kwargs, ) + + +def scatters_interactive( + adata: AnnData, + basis: str = "umap", + x: Union[int, str] = 0, + y: Union[int, str] = 1, + z: Union[int, str] = 2, + color: str = "ntr", + layer: str = "X", + plot_method: str = "pv", + labels: Optional[list] = None, + values: Optional[list] = None, + cmap: Optional[str] = None, + theme: Optional[str] = None, + background: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + use_smoothed: bool = True, + sym_c: bool = False, + smooth: bool = False, + save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", + save_kwargs: Dict[str, Any] = {}, + **kwargs, +): + """Plot an embedding as points with Pyvista. Currently only 3D input is supported. For 2D data, `scatters` is a + better alternative. + + The function will use the colors from matplotlib to keep consistence with other plotting functions. + + Args: + adata: an AnnData object. + basis: the reduced dimension stored in adata.obsm. The specific basis key will be constructed in the following + priority if exits: 1) specific layer input + basis 2) X_ + basis 3) basis. E.g. if basis is PCA, `scatters` + is going to look for 1) if specific layer is spliced, `spliced_pca` 2) `X_pca` (dynamo convention) 3) `pca`. + Defaults to "umap". + x: the column index of the low dimensional embedding for the x-axis. Defaults to 0. + y: the column index of the low dimensional embedding for the y-axis. Defaults to 1. + z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. + color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". + layer: the layer of data to use for the scatter plot. Defaults to "X". + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Defaults to None. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to None. + background: the color of the background. Usually this will be either 'white' or 'black', but any color name will + work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of + the things that themes handle for you. Note that if theme is passed then this value will be overridden by + the corresponding option of the theme. Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + sym_c: whether do you want to make the limits of continuous color to be symmetric, normally this should be used + for plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative + values. Defaults to False. + smooth: whether do you want to further smooth data and how much smoothing do you want. If it is `False`, no + smoothing will be applied. If `True`, smoothing based on one-step diffusion of connectivity matrix + (`.uns['moment_cnn']`) will be applied. If a number larger than 1, smoothing will be based on `smooth` steps + of diffusion. + save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at + the same time. If "all", the figure will be saved, displayed and the associated axis and other object will + be return. Defaults to "show". + save_kwargs: A dictionary that will be passed to the saving function. By default, it is an empty dictionary + and the saving function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + "title": PyVista Export, "raster": True, "painter": True} as its parameters. Otherwise, you can provide a + dictionary that properly modify those keys according to your needs. Defaults to {}. + **kwargs: any other kwargs that would be passed to `Plotter.add_points()`. + + Returns: + If `save_show_or_return` is `save`, `show` or `both`, the function will return nothing but show or save the + figure. If `save_show_or_return` is `return`, the function will return the axis object(s) that contains the + figure. + """ + + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + elif plot_method == "plotly": + try: + import plotly.express as px + import plotly.graph_objects as go + from plotly.subplots import make_subplots + except ImportError: + raise ImportError("Please install plotly first.") + else: + raise NotImplementedError("Current plot method not supported.") + + if use_smoothed: + mapper = get_mapper() + + # check color, layer, basis -> convert to list + basis, background, color, layer, x, y, z, nrow, ncol, total_panels = _validate_parameters( + adata=adata, + basis=basis, + background=background, + color=color, + layer=layer, + x=x, + y=y, + z=z, + ) + + subplot_indices = [[i, j] for i in range(nrow) for j in range(ncol)] + cur_subplot = 0 + colors_list = [] + + if total_panels == 1: + pl = pv.Plotter() if plot_method == "pv" else make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]]) + else: + pl = ( + pv.Plotter(shape=(nrow, ncol)) + if plot_method == "pv" + else + make_subplots(rows=nrow, cols=ncol, specs=[[{"type": "scatter3d"} for _ in range(ncol)] for _ in range(nrow)]) + ) + + def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: str, cur_z: str) -> None: + """A helper function for plotting a specific basis/layer data + + Args: + cur_b: current basis + cur_l: current layer + """ + nonlocal background, cur_subplot, cur_l_smoothed, cmap, sym_c, labels, values, theme, color_key, color_key_cmap + + cur_l_smoothed = cur_l_smoothed if cur_l_smoothed is not None else cur_l + main_debug("coloring scatter of cur_c: %s" % str(cur_c)) + + main_debug("handling coordinates, cur_x: %s, cur_y: %s, cur_z: %s" % (cur_x, cur_y, cur_z)) + + points, cur_title = _map_to_points( + adata, + axis_x=cur_x, + axis_y=cur_y, + axis_z=cur_z, + basis_key=basis_key, + cur_c=cur_c, + cur_b=cur_b, + cur_l_smoothed=cur_l_smoothed, + projection="3d", + ) + + ( + _color, + _labels, + _values, + _theme_, + _cmap, + _color_key_cmap, + background, + font_color, + is_not_continuous, + cur_title, + ) = _get_color_parameters( + adata=adata, + color=cur_c, + layer=cur_l, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key_cmap=color_key_cmap, + background=background, + title=cur_title, + ) + + + if smooth and not is_not_continuous: + main_debug("smooth and not continuous") + knn = adata.obsp["moments_con"] + _values = ( + calc_1nd_moment(_values, knn)[0] + if smooth in [1, True] + else calc_1nd_moment(_values, knn**smooth)[0] + ) + + colors, color_type, _ = calculate_colors( + points.values, + labels=_labels, + values=_values, + cmap=_cmap, + color_key=color_key, + color_key_cmap=_color_key_cmap, + background=background, + sym_c=sym_c, + ) + + colors_list.append(colors) + + if plot_method == "pv": + if total_panels > 1: + pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + + pvdataset = pv.PolyData(points.values) + pvdataset.point_data["colors"] = np.stack(colors) + pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, cmap=_cmap, **kwargs) + + if color_type == "labels": + type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors)} + type_color_pair = [[k, v] for k, v in type_color_dict.items()] + pl.add_legend(labels=type_color_pair) + else: + pl.add_scalar_bar() # TODO: fix the bug that scalar bar only works in the first plot + + pl.add_text(cur_title) + pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) + elif plot_method == "plotly": + + pl.add_trace( + go.Scatter3d( + x=points.iloc[:, 0], + y=points.iloc[:, 1], + z=points.iloc[:, 2], + mode="markers", + marker=dict( + color=colors, + ), + text=_labels if color_type == "labels" else _values, + **kwargs, + ), + row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + ) + + pl.update_layout( + scene=dict( + xaxis_title=points.columns[0], + yaxis_title=points.columns[1], + zaxis_title=points.columns[2] + ), + ) + + cur_subplot += 1 + + for cur_b in basis: + for cur_l in layer: + main_debug("Plotting basis:%s, layer: %s" % (str(basis), str(layer))) + main_debug("colors: %s" % (str(color))) + + if cur_l in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: + cur_l_smoothed = cur_l + cmap, sym_c = "bwr", True # TODO maybe use other divergent color map in the future + else: + if use_smoothed: + cur_l_smoothed = cur_l if cur_l.startswith("M_") | cur_l.startswith("velocity") else mapper[cur_l] + if cur_l.startswith("velocity"): + cmap, sym_c = "bwr", True + + if cur_l + "_" + cur_b in adata.obsm.keys(): + prefix = cur_l + "_" + elif ("X_" + cur_b) in adata.obsm.keys(): + prefix = "X_" + elif cur_b in adata.obsm.keys(): + # special case for spatial for compatibility with other packages + prefix = "" + else: + raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) + + basis_key = prefix + cur_b + main_info("plotting with basis key=%s" % basis_key, indent_level=2) + + for cur_c in color: + x, y, z = _standardize_dimensions(x, y, z, adata.n_obs, projection="3d") + for cur_x, cur_y, cur_z in zip(x, y, z): + _plot_basis_layer_pv(cur_b, cur_l, cur_c, cur_x, cur_y, cur_z) + + + return save_pyvista_plotter( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) if plot_method == "pv" else save_plotly_figure( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) + + +def scatters( + adata: AnnData, + basis: str = "umap", + x: int = 0, + y: int = 1, + z: int = 2, + color: str = "ntr", + layer: str = "X", + highlights: Optional[list] = None, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[ + Literal[ + "blue", + "red", + "green", + "inferno", + "fire", + "viridis", + "darkblue", + "darkred", + "darkgreen", + ] + ] = None, + cmap: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + background: Optional[str] = None, + ncols: int = 4, + pointsize: Optional[float] = None, + figsize: Tuple[float, float] = (6, 4), + show_legend: str = "on data", + use_smoothed: bool = True, + aggregate: Optional[str] = None, + show_arrowed_spines: bool = False, + ax: Optional[Axes] = None, + sort: Literal["raw", "abs", "neg"] = "raw", + save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", + save_kwargs: Dict[str, Any] = {}, + return_all: bool = False, + add_gamma_fit: bool = False, + frontier: bool = False, + contour: bool = False, + ccmap: Optional[str] = None, + alpha: float = 0.1, + calpha: float = 0.4, + sym_c: bool = False, + smooth: bool = False, + dpi: int = 100, + inset_dict: Dict[str, Any] = {}, + marker: Optional[str] = None, + group: Optional[str] = None, + add_group_gamma_fit: bool = False, + affine_transform_degree: Optional[int] = None, + affine_transform_A: Optional[float] = None, + affine_transform_b: Optional[float] = None, + stack_colors: bool = False, + stack_colors_threshold: float = 0.001, + stack_colors_title: str = "stacked colors", + stack_colors_legend_size: float = 2, + stack_colors_cmaps: Optional[List[str]] = None, + despline: bool = True, + deaxis: bool = True, + despline_sides: Optional[List[str]] = None, + projection: str = "2d", + **kwargs, +) -> Union[ + Axes, + List[Axes], + Tuple[Axes, List[str], Literal["white", "black"]], + Tuple[List[Axes], List[str], Literal["white", "black"]], + None, +]: + """Plot an embedding as points. + + Support both 2D and 3D embeddings. While there are many optional parameters to further control and tailor the + plotting, you need only pass in the trained/fit umap model to get results. This plot utility will attempt to do the + hard work of avoiding overplotting issues, and make it easy to automatically color points by a categorical labelling + or numeric values. This method is intended to be used within a Jupyter notebook with `%matplotlib inline`. + + Args: + adata: an AnnData object. + basis: the reduced dimension stored in adata.obsm. The specific basis key will be constructed in the following + priority if exits: 1) specific layer input + basis 2) X_ + basis 3) basis. E.g. if basis is PCA, `scatters` + is going to look for 1) if specific layer is spliced, `spliced_pca` 2) `X_pca` (dynamo convention) 3) `pca`. + Defaults to "umap". + x: the column index of the low dimensional embedding for the x-axis. Defaults to 0. + y: the column index of the low dimensional embedding for the y-axis. Defaults to 1. + z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. + color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". + layer: the layer of data to use for the scatter plot. Defaults to "X". + highlights: the color group that will be highlighted. If highlights is a list of lists, each list is relate to + each color element. Defaults to None. + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Available themes are: {'blue', 'red', 'green', 'inferno', 'fire', 'viridis', 'darkblue', + 'darkred', 'darkgreen'}. Defaults to None. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + background: the color of the background. Usually this will be either 'white' or 'black', but any color name will + work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of + the things that themes handle for you. Note that if theme is passed then this value will be overridden by + the corresponding option of the theme. Defaults to None. + ncols: the number of columns for the figure. Defaults to 4. + pointsize: the scale of the point size. Actual point cell size is calculated as + `500.0 / np.sqrt(adata.shape[0]) * pointsize`. Defaults to None. + figsize: the width and height of a figure. Defaults to (6, 4). + show_legend: whether to display a legend of the labels. Defaults to "on data". + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + aggregate: the column in adata.obs that will be used to aggregate data points. Defaults to None. + show_arrowed_spines: whether to show a pair of arrowed spines representing the basis of the scatter is currently + using. Defaults to False. + ax: the matplotlib axes object where new plots will be added to. Only applicable to drawing a single component. + Defaults to None. + sort: the method to reorder data so that high values points will be on top of background points. Can be one of + {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values. Defaults + to "raw". + save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at + the same time. If "all", the figure will be saved, displayed and the associated axis and other object will + be return. Defaults to "show". + save_kwargs: A dictionary that will passed to the save_show_ret function. By default it is an empty dictionary and + the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a + dictionary that properly modify those keys according to your needs. Defaults to {}. + return_all: whether to return all the scatter related variables. Defaults to False. + add_gamma_fit: whether to add the line of the gamma fitting. This will automatically turn on if `basis` points + to gene names and those genes have went through gamma fitting. Defaults to False. + frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to + show area of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib + tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from + scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. If `contour` is set to be True, + `frontier` will be ignored as `contour` also add an outlier for data points. Defaults to False. + contour: whether to add an countor on top of scatter plots. We use tricontourf to plot contour for non-gridded + data. The shapely package was used to create a polygon of the concave hull of the scatters. With the polygon + we then check if the mean of the triangulated points are within the polygon and use this as our condition to + form the mask to create the contour. We also add the polygon shape as a frontier of the data point (similar + to when setting `frontier = True`). When the color of the data points is continuous, we will use the same + cmap as for the scatter points by default, when color is categorical, no contour will be drawn but just the + polygon. cmap can be set with `ccmap` argument. See below. This has recently changed to use seaborn's + kdeplot. Defaults to False. + ccmap: the name of a matplotlib colormap to use for coloring or shading points the contour. See above. + Defaults to None. + alpha: the point's alpha (transparency) value. Defaults to 0.1. + calpha: contour alpha value passed into sns.kdeplot. The value should be inbetween [0, 1]. Defaults to 0.4. + sym_c: whether do you want to make the limits of continuous color to be symmetric, normally this should be used + for plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative + values. Defaults to False. + smooth: whether do you want to further smooth data and how much smoothing do you want. If it is `False`, no + smoothing will be applied. If `True`, smoothing based on one step diffusion of connectivity matrix + (`.uns['moment_cnn']`) will be applied. If a number larger than 1, smoothing will based on `smooth` steps of + diffusion. + dpi: the resolution of the figure in dots-per-inch. Dots per inches (dpi) determines how many pixels the figure + comprises. dpi is different from ppi or points per inches. Note that most elements like lines, markers, + texts have a size given in points so you can convert the points to inches. Matplotlib figures use Points per + inch (ppi) of 72. A line with thickness 1 point will be 1./72. inch wide. A text with fontsize 12 points + will be 12./72. inch heigh. Of course if you change the figure size in inches, points will not change, so a + larger figure in inches still has the same size of the elements.Changing the figure size is thus like taking + a piece of paper of a different size. Doing so, would of course not change the width of the line drawn with + the same pen. On the other hand, changing the dpi scales those elements. At 72 dpi, a line of 1 point size + is one pixel strong. At 144 dpi, this line is 2 pixels strong. A larger dpi will therefore act like a + magnifying glass. All elements are scaled by the magnifying power of the lens. see more details at answer 2 + by @ImportanceOfBeingErnest: + https://stackoverflow.com/questions/47633546/relationship-between-dpi-and-figure-size. Defaults to 100. + inset_dict: a dictionary of parameters in inset_ax. Example, something like {"width": "5%", "height": "50%", "loc": + 'lower left', "bbox_to_anchor": (0.85, 0.90, 0.145, 0.145), "bbox_transform": ax.transAxes, "borderpad": 0} + See more details at https://matplotlib.org/api/_as_gen/mpl_toolkits.axes_grid1.inset_locator.inset_axes.html + or https://stackoverflow.com/questions/39803385/what-does-a-4-element-tuple-argument-for-bbox-to-anchor-mean-in-matplotlib. + Defaults to {}. + marker: the marker style. marker can be either an instance of the class or the text shorthand for a particular + marker. See matplotlib.markers for more information about marker styles. Defaults to None. + group: the key in `adata.obs` corresponding to the cell group data. Defaults to None. + add_group_gamma_fit: whether to plot the cell group's gamma fit results. Defaults to False. + affine_transform_degree: transform coordinates of points according to some degree. Defaults to None. + affine_transform_A: coefficients in affine transformation Ax + b. 2D for now. Defaults to None. + affine_transform_b: bias in affine transformation Ax + b. Defaults to None. + stack_colors: whether to stack all color on the same ax passed above. Currently only support 18 sequential + matplotlib default cmaps assigning to different color groups. (#colors should be smaller than 18, reuse if + #colors > 18. TODO generate cmaps according to #colors). Defaults to False. + stack_colors_threshold: a threshold for filtering out points values < threshold when drawing each color. E.g. if + you do not want points with values < 1 showing up on axis, set threshold to be 1. Defaults to 0.001. + stack_colors_title: the title for the stack_color plot. Defaults to "stacked colors". + stack_colors_legend_size: the legend size in stack color plot. Defaults to 2. + stack_colors_cmaps: a list of cmaps that will be used to map values to color when stacking colors on the same + subplot. The order corresponds to the order of color. Defaults to None. + despline: whether to remove splines of the figure. Defaults to True. + deaxis: whether to remove axis ticks of the figure. Defaults to True. + despline_sides: which side of splines should be removed. Can be any combination of `["bottom", "right", "top", "left"]`. Defaults to None. + projection: the projection property of the matplotlib.Axes. Defaults to "2d". + **kwargs: any other kwargs that would be passed to `pyplot.scatters`. + + Raises: + ValueError: invalid adata object: lacking of required layers. + ValueError: `basis` not found in `adata.obsm`. + ValueError: invalid `x` or `y`. + ValueError: `labels` and `values` conflicted. + ValueError: invalid velocity estimation in `adata`. + ValueError: invalid velocity estimation in `adata`. + + Returns: + None would be returned by default. If `save_show_or_return` is set to be 'return' or 'all', the matplotlib axes + object of the generated plots would be returned. If `return_all` is set to be true, the list of colors used and + the font color would also be returned. + """ + import matplotlib.pyplot as plt + + if projection == "2d": + projection = None + + if calpha < 0 or calpha > 1: + main_warning( + "calpha=%f is invalid (smaller than 0 or larger than 1) and may cause potential issues. Please check." + % calpha + ) + + if stack_colors and stack_colors_cmaps is None: + main_info("using default stack colors") + stack_colors_cmaps = [ + "Greys", + "Purples", + "Blues", + "Greens", + "Oranges", + "Reds", + "YlOrBr", + "YlOrRd", + "OrRd", + "PuRd", + "RdPu", + "BuPu", + "GnBu", + "PuBu", + "YlGnBu", + "PuBuGn", + "BuGn", + "YlGn", + ] + stack_legend_handles = [] + if stack_colors: + color_key = None + + if not (affine_transform_degree is None): + affine_transform_A = gen_rotation_2d(affine_transform_degree) + affine_transform_b = 0 + + if contour: + frontier = False + + if pointsize is None: + point_size = 16000.0 / np.sqrt(adata.shape[0]) + else: + point_size = 16000.0 / np.sqrt(adata.shape[0]) * pointsize + + scatter_kwargs = dict( + alpha=alpha, + s=point_size, + edgecolor=None, + linewidth=0, + rasterized=True, + marker=marker, + ) # (0, 0, 0, 1) + if kwargs is not None: + scatter_kwargs.update(kwargs) + + if figsize is None: + figsize = plt.rcParams["figsize"] + + basis, background, color, layer, x, y, z, nrow, ncol, total_panels = _validate_parameters( + adata=adata, + basis=basis, + background=background, + color=color, + layer=layer, + x=x, + y=y, + z=z, + ncols=ncols, + ) + + if total_panels > 1 and ax is None: + figure = plt.figure( + None, + (figsize[0] * ncol, figsize[1] * nrow), + facecolor=background, + dpi=dpi, + ) + gs = plt.GridSpec(nrow, ncol, wspace=0.12) + + ax_index = 0 + axes_list, color_list = [], [] + color_out = None + + if stack_colors: + ax = None + + for cur_b in basis: + for cur_l in layer: + + basis_key, cur_l_smoothed, cmap, sym_c = _get_basis_key( + adata=adata, + basis=cur_b, + layer=cur_l, + use_smoothed=use_smoothed, + cmap=cmap, + sym_c=sym_c, + ) + + for cur_c in color: + x, y, z = _standardize_dimensions(x, y, z, adata.n_obs, projection) + + for cur_x, cur_y, cur_z in zip(x, y, z): # here x / y are arrays + + if total_panels > 1 and not stack_colors: + ax = plt.subplot(gs[ax_index], projection=projection) + ax_index += 1 + + ax, color_out, font_color, stack_legend_handles = scatters_single_input( + adata=adata, + basis=cur_b, + x=cur_x, + y=cur_y, + z=cur_z, + color=cur_c, + layer=cur_l, + highlights=highlights, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key=color_key, + color_key_cmap=color_key_cmap, + background=background, + figsize=figsize, + show_legend=show_legend, + use_smoothed=use_smoothed, + aggregate=aggregate, + show_arrowed_spines=show_arrowed_spines, + ax=ax, + sort=sort, + add_gamma_fit=add_gamma_fit, + frontier=frontier, + contour=contour, + ccmap=ccmap, + calpha=calpha, + sym_c=sym_c, + smooth=smooth, + inset_dict=inset_dict, + group=group, + add_group_gamma_fit=add_group_gamma_fit, + affine_transform_A=affine_transform_A, + affine_transform_b=affine_transform_b, + stack_colors=stack_colors, + stack_colors_threshold=stack_colors_threshold, + stack_colors_title=stack_colors_title, + stack_colors_legend_size=stack_colors_legend_size, + stack_colors_cmaps=stack_colors_cmaps, + stack_legend_handles=stack_legend_handles, + despline=despline, + deaxis=deaxis, + despline_sides=despline_sides, + projection=projection, + scatter_kwargs=scatter_kwargs, + ax_index=ax_index, + cur_l_smoothed=cur_l_smoothed, + basis_key=basis_key, + ) + + axes_list.append(ax) + color_list.append(color_out) + + main_debug("show, return or save...") + if return_all: + return_value = (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) + else: + return_value = axes_list if total_panels > 1 else ax + return save_show_ret("scatters", save_show_or_return, save_kwargs, return_value, adjust=show_legend, + background=background) + + +def scatters_single_input( + adata: AnnData, + basis: str = "umap", + x: int = 0, + y: int = 1, + z: int = 2, + color: str = "ntr", + layer: str = "X", + highlights: Optional[list] = None, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[ + Literal[ + "blue", + "red", + "green", + "inferno", + "fire", + "viridis", + "darkblue", + "darkred", + "darkgreen", + ] + ] = None, + cmap: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + background: Optional[str] = None, + figsize: Tuple[float, float] = (6, 4), + show_legend: str = "on data", + use_smoothed: bool = True, + aggregate: Optional[str] = None, + show_arrowed_spines: bool = False, + ax: Optional[Axes] = None, + sort: Literal["raw", "abs", "neg"] = "raw", + add_gamma_fit: bool = False, + frontier: bool = False, + contour: bool = False, + ccmap: Optional[str] = None, + calpha: float = 0.4, + sym_c: bool = False, + smooth: bool = False, + inset_dict: Dict[str, Any] = {}, + group: Optional[str] = None, + add_group_gamma_fit: bool = False, + affine_transform_A: Optional[float] = None, + affine_transform_b: Optional[float] = None, + stack_colors: bool = False, + stack_colors_threshold: float = 0.001, + stack_colors_title: str = "stacked colors", + stack_colors_legend_size: float = 2, + stack_colors_cmaps: Optional[List[str]] = None, + stack_legend_handles: Optional[List[Line2D]] = None, + despline: bool = True, + deaxis: bool = True, + despline_sides: Optional[List[str]] = None, + projection: str = "2d", + scatter_kwargs: Optional[Dict[str, Any]] = None, + ax_index: Optional[int] = None, + cur_l_smoothed: Optional[str] = None, + basis_key: Optional[str] = None, +) -> Union[ + Axes, + List[Axes], + Tuple[Axes, List[str], Literal["white", "black"]], + Tuple[List[Axes], List[str], Literal["white", "black"]], + None, +]: + """Generate a single scatter plot from given embedding. + + Args: + adata: an AnnData object. + basis: the reduced dimension stored in adata.obsm. The specific basis key will be constructed in the following + priority if exits: 1) specific layer input + basis 2) X_ + basis 3) basis. E.g. if basis is PCA, `scatters` + is going to look for 1) if specific layer is spliced, `spliced_pca` 2) `X_pca` (dynamo convention) 3) `pca`. + Defaults to "umap". + x: the column index of the low dimensional embedding for the x-axis. Defaults to 0. + y: the column index of the low dimensional embedding for the y-axis. Defaults to 1. + z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. + color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". + layer: the layer of data to use for the scatter plot. Defaults to "X". + highlights: the color group that will be highlighted. If highlights is a list of lists, each list is relate to + each color element. Defaults to None. + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Available themes are: {'blue', 'red', 'green', 'inferno', 'fire', 'viridis', 'darkblue', + 'darkred', 'darkgreen'}. Defaults to None. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + background: the color of the background. Usually this will be either 'white' or 'black', but any color name will + work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of + the things that themes handle for you. Note that if theme is passed then this value will be overridden by + the corresponding option of the theme. Defaults to None. + figsize: the width and height of a figure. Defaults to (6, 4). + show_legend: whether to display a legend of the labels. Defaults to "on data". + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + aggregate: the column in adata.obs that will be used to aggregate data points. Defaults to None. + show_arrowed_spines: whether to show a pair of arrowed spines representing the basis of the scatter is currently + using. Defaults to False. + ax: the matplotlib axes object where new plots will be added to. Only applicable to drawing a single component. + Defaults to None. + sort: the method to reorder data so that high values points will be on top of background points. Can be one of + {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values. Defaults + to "raw". + add_gamma_fit: whether to add the line of the gamma fitting. This will automatically turn on if `basis` points + to gene names and those genes have went through gamma fitting. Defaults to False. + frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to + show area of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib + tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from + scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. If `contour` is set to be True, + `frontier` will be ignored as `contour` also add an outlier for data points. Defaults to False. + contour: whether to add an countor on top of scatter plots. We use tricontourf to plot contour for non-gridded + data. The shapely package was used to create a polygon of the concave hull of the scatters. With the polygon + we then check if the mean of the triangulated points are within the polygon and use this as our condition to + form the mask to create the contour. We also add the polygon shape as a frontier of the data point (similar + to when setting `frontier = True`). When the color of the data points is continuous, we will use the same + cmap as for the scatter points by default, when color is categorical, no contour will be drawn but just the + polygon. cmap can be set with `ccmap` argument. See below. This has recently changed to use seaborn's + kdeplot. Defaults to False. + ccmap: the name of a matplotlib colormap to use for coloring or shading points the contour. See above. + Defaults to None. + calpha: contour alpha value passed into sns.kdeplot. The value should be inbetween [0, 1]. Defaults to 0.4. + sym_c: whether do you want to make the limits of continuous color to be symmetric, normally this should be used + for plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative + values. Defaults to False. + smooth: whether do you want to further smooth data and how much smoothing do you want. If it is `False`, no + smoothing will be applied. If `True`, smoothing based on one step diffusion of connectivity matrix + (`.uns['moment_cnn']`) will be applied. If a number larger than 1, smoothing will based on `smooth` steps of + diffusion. + inset_dict: a dictionary of parameters in inset_ax. Example, something like {"width": "5%", "height": "50%", "loc": + 'lower left', "bbox_to_anchor": (0.85, 0.90, 0.145, 0.145), "bbox_transform": ax.transAxes, "borderpad": 0} + See more details at https://matplotlib.org/api/_as_gen/mpl_toolkits.axes_grid1.inset_locator.inset_axes.html + or https://stackoverflow.com/questions/39803385/what-does-a-4-element-tuple-argument-for-bbox-to-anchor-mean-in-matplotlib. + Defaults to {}. + group: the key in `adata.obs` corresponding to the cell group data. Defaults to None. + add_group_gamma_fit: whether to plot the cell group's gamma fit results. Defaults to False. + affine_transform_A: coefficients in affine transformation Ax + b. 2D for now. Defaults to None. + affine_transform_b: bias in affine transformation Ax + b. Defaults to None. + stack_colors: whether to stack all color on the same ax passed above. Currently only support 18 sequential + matplotlib default cmaps assigning to different color groups. (#colors should be smaller than 18, reuse if + #colors > 18. TODO generate cmaps according to #colors). Defaults to False. + stack_colors_threshold: a threshold for filtering out points values < threshold when drawing each color. E.g. if + you do not want points with values < 1 showing up on axis, set threshold to be 1. Defaults to 0.001. + stack_colors_title: the title for the stack_color plot. Defaults to "stacked colors". + stack_colors_legend_size: the legend size in stack color plot. Defaults to 2. + stack_colors_cmaps: a list of cmaps that will be used to map values to color when stacking colors on the same + subplot. The order corresponds to the order of color. Defaults to None. + despline: whether to remove splines of the figure. Defaults to True. + deaxis: whether to remove axis ticks of the figure. Defaults to True. + despline_sides: which side of splines should be removed. Can be any combination of `["bottom", "right", "top", "left"]`. Defaults to None. + projection: the projection property of the matplotlib.Axes. Defaults to "2d". + + Returns: + The matplotlib axes object of the generated plot. + """ + import matplotlib.pyplot as plt + + # if #total_panel is 1, `_matplotlib_points` will create a figure. No need to create a figure here and generate a blank figure. + if ax is None: + figure, ax = plt.subplots() + + color_out = None + + basis_key = basis_key if basis_key is not None else basis + cur_l_smoothed = cur_l_smoothed if cur_l_smoothed is not None else layer + + main_info("plotting with basis key=%s" % basis_key, indent_level=2) + main_debug("coloring scatter of cur_c: %s" % str(color)) + + points, cur_title = _map_to_points( + adata, + axis_x=x, + axis_y=y, + axis_z=z, + basis_key=basis_key, + cur_c=color, + cur_b=basis, + cur_l_smoothed=cur_l_smoothed, + projection=projection, + ) + + ( + _color, + labels, + values, + _theme_, + cmap, + color_key_cmap, + background, + font_color, + is_not_continuous, + cur_title, + ) = _get_color_parameters( + adata=adata, + color=color, + layer=layer, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key_cmap=color_key_cmap, + background=background, + title=cur_title, + stack_colors=stack_colors, + stack_colors_threshold=stack_colors_threshold, + ) + + if stack_colors: + main_debug("stack colors: changing cmap") + cur_title = stack_colors_title + cmap = stack_colors_cmaps[(ax_index - 1) % len(stack_colors_cmaps)] + max_color = matplotlib.cm.get_cmap(cmap)(float("inf")) + # TODO: consider remove the legend because it is not helpful + legend_circle = Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor=max_color, + label=color, + markersize=stack_colors_legend_size, + ) + stack_legend_handles.append(legend_circle) + ax.legend(handles=stack_legend_handles, loc="upper right", prop={"size": stack_colors_legend_size}) + + if aggregate is not None: + groups, uniq_grp = ( + adata.obs[aggregate], + list(adata.obs[aggregate].unique()), + ) + group_color, group_median = ( + np.zeros((1, len(uniq_grp))).flatten() + if isinstance(_color[0], Number) + else np.zeros((1, len(uniq_grp))).astype("str").flatten(), + np.zeros((len(uniq_grp), 2)), + ) + + grp_size = adata.obs[aggregate].value_counts()[uniq_grp].values + scatter_kwargs = ( + {"s": grp_size} if scatter_kwargs is None else update_dict(scatter_kwargs, {"s": grp_size}) + ) + + for ind, cur_grp in enumerate(uniq_grp): + group_median[ind, :] = np.nanmedian( + points.iloc[np.where(groups == cur_grp)[0], :2], + 0, + ) + if isinstance(_color[0], Number): + group_color[ind] = np.nanmedian(np.array(_color)[np.where(groups == cur_grp)[0]]) + else: + group_color[ind] = pd.Series(_color)[np.where(groups == cur_grp)[0]].value_counts().index[0] + + points, _color = ( + pd.DataFrame( + group_median, + index=uniq_grp, + columns=points.columns, + ), + group_color, + ) + + if is_not_continuous: + labels = _color + else: + values = _color + + if highlights is not None: + _highlights = highlights if all([i in _color for i in highlights]) else None + + if smooth and not is_not_continuous: + main_debug("smooth and not continuous") + knn = adata.obsp["moments_con"] + values = ( + calc_1nd_moment(values, knn)[0] + if smooth in [1, True] + else calc_1nd_moment(values, knn ** smooth)[0] + ) + + if affine_transform_A is None or affine_transform_b is None: + point_coords = points.values + else: + point_coords = affine_transform(points.values, affine_transform_A, affine_transform_b) + + if points.shape[0] <= figsize[0] * figsize[1] * 100000: + main_debug("drawing with _matplotlib_points function") + ax, color_out = _matplotlib_points( + # points.values, + point_coords, + ax, + labels, + values, + highlights, + cmap, + color_key, + color_key_cmap, + background, + figsize[0], + figsize[1], + show_legend, + sort=sort, + frontier=frontier, + contour=contour, + ccmap=ccmap, + calpha=calpha, + sym_c=sym_c, + inset_dict=inset_dict, + projection=projection, + **scatter_kwargs, + ) + if labels is not None: + color_dict = {} + colors = [rgb2hex(i) for i in color_out] + for i, j in zip(labels, colors): + color_dict[i] = j + + adata.uns[cur_title + "_colors"] = color_dict + else: + main_debug("drawing with _datashade_points function") + ax = _datashade_points( + # points.values, + point_coords, + ax, + labels, + values, + highlights, + cmap, + color_key, + color_key_cmap, + background, + figsize[0], + figsize[1], + show_legend, + sort=sort, + frontier=frontier, + contour=contour, + ccmap=ccmap, + calpha=calpha, + sym_c=sym_c, + **scatter_kwargs, + ) + + if show_arrowed_spines: + arrowed_spines(ax, points.columns[:2], background) + else: + if despline: + despline_all(ax, despline_sides) + if deaxis: + deaxis_all(ax) + + ax.set_title(cur_title) + + if add_gamma_fit and basis in adata.var_names[adata.var.use_for_dynamics]: + xnew = np.linspace( + points.iloc[:, 0].min(), + points.iloc[:, 0].max() * 0.80, + ) + k_name = "gamma_k" if adata.uns["dynamics"]["experiment_type"] == "one-shot" else "gamma" + vel_params_df = get_vel_params(adata) + if k_name in vel_params_df.columns: + if not ("gamma_b" in vel_params_df.columns) or all(vel_params_df.gamma_b.isna()): + vel_params_df.loc[:, "gamma_b"] = 0 + update_vel_params(adata, params_df=vel_params_df) + ax.plot( + xnew, + xnew * adata[:, basis].var.loc[:, k_name].unique() + + adata[:, basis].var.loc[:, "gamma_b"].unique(), + dashes=[6, 2], + c=font_color, + ) + else: + raise ValueError( + "_adata does not seem to have %s column. Velocity estimation is required " + "before running this function." % k_name + ) + if group is not None and add_group_gamma_fit and basis in adata.var_names[adata.var.use_for_dynamics]: + group_colors = ["b", "g", "r", "c", "m", "y", "k", "w"] + cell_groups = adata.obs[group] + unique_groups = np.unique(cell_groups) + k_suffix = "gamma_k" if adata.uns["dynamics"]["experiment_type"] == "one-shot" else "gamma" + for group_idx, cur_group in enumerate(unique_groups): + group_k_name = group + "_" + cur_group + "_" + k_suffix + group_adata = adata[adata.obs[group] == cur_group] + group_points = points.iloc[np.array(adata.obs[group] == cur_group)] + group_b_key = group + "_" + cur_group + "_" + "gamma_b" + group_xnew = np.linspace( + group_points.iloc[:, 0].min(), + group_points.iloc[:, 0].max() * 0.90, + ) + group_ynew = ( + group_xnew * group_adata[:, basis].var.loc[:, group_k_name].unique() + + group_adata[:, basis].var.loc[:, group_b_key].unique() + ) + ax.annotate(group + "_" + cur_group, xy=(group_xnew[-1], group_ynew[-1])) + vel_params_df = get_vel_params(group_adata) + if group_k_name in vel_params_df.columns: + if not (group_b_key in vel_params_df.columns) or all(vel_params_df[group_b_key].isna()): + vel_params_df.loc[:, group_b_key] = 0 + update_vel_params(group_adata, params_df=vel_params_df) + main_info("No %s found, setting all bias terms to zero" % group_b_key) + ax.plot( + group_xnew, + group_ynew, + dashes=[6, 2], + c=group_colors[group_idx % len(group_colors)], + ) + else: + raise Exception( + "_adata does not seem to have %s column. Velocity estimation is required " + "before running this function." % group_k_name + ) + return ax, color_out, font_color, stack_legend_handles + + +def _validate_parameters( + adata: AnnData, + basis: Union[str, List[str]], + background: Optional[str], + color: Union[str, List[str]], + layer: Union[str, List[str]], + x: Union[str, List[str], int, List[int], None], + y: Union[str, List[str], int, List[int], None], + z: Union[str, List[str], int, List[int], None], + ncols: int = 4, +) -> Tuple: + """A helper function to validate parameters for scatters function.""" + if all([is_gene_name(adata, i) for i in basis]): + if x[0] not in ["M_s", "X_spliced", "M_t", "X_total", "spliced", "total"] and y[0] not in [ + "M_u", + "X_unspliced", + "M_n", + "X_new", + "unspliced", + "new", + ]: + if "M_t" in adata.layers.keys() and "M_n" in adata.layers.keys(): + x, y = ["M_t"], ["M_n"] + elif "X_total" in adata.layers.keys() and "X_new" in adata.layers.keys(): + x, y = ["X_total"], ["X_new"] + elif "M_s" in adata.layers.keys() and "M_u" in adata.layers.keys(): + x, y = ["M_s"], ["M_u"] + elif "X_spliced" in adata.layers.keys() and "X_unspliced" in adata.layers.keys(): + x, y = ["X_spliced"], ["X_unspliced"] + elif "spliced" in adata.layers.keys() and "unspliced" in adata.layers.keys(): + x, y = ["spliced"], ["unspliced"] + elif "total" in adata.layers.keys() and "new" in adata.layers.keys(): + x, y = ["total"], ["new"] + else: + raise ValueError( + "your adata object is corrupted. Please make sure it has at least one of the following " + "pair of layers:" + "'M_s', 'X_spliced', 'M_t', 'X_total', 'spliced', 'total' and " + "'M_u', 'X_unspliced', 'M_n', 'X_new', 'unspliced', 'new'. " + ) + + if background is None: + background = rcParams.get("figure.facecolor") + background = to_hex(background) if type(background) is tuple else background + # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) + + if type(x) in [int, str]: + x = [x] + if type(y) in [int, str]: + y = [y] + if type(z) in [int, str]: + z = [z] + + if type(color) is str: + color = [color] + if type(layer) is str: + layer = [layer] + if type(basis) is str: + basis = [basis] + + n_c, n_l, n_b, n_x, n_y, n_z = ( + 1 if color is None else len(color), + 1 if layer is None else len(layer), + 1 if basis is None else len(basis), + 1 if x is None else 1 if type(x) in [anndata._core.views.ArrayView, np.ndarray] else len(x), + # check whether it is an array + 1 if y is None else 1 if type(y) in [anndata._core.views.ArrayView, np.ndarray] else len(y), + # check whether it is an array + 1 if z is None else 1 if type(z) in [anndata._core.views.ArrayView, np.ndarray] else len(z), + ) + + total_panels, ncols = ( + n_c * n_l * n_b * n_x, + min(max([n_c, n_l, n_b, n_x, n_y, n_z]), ncols), + ) + nrow, ncol = int(np.ceil(total_panels / ncols)), ncols + + return basis, background, color, layer, x, y, z, nrow, ncol, total_panels + +def _get_basis_key( + adata: AnnData, + basis: str, + layer: str, + use_smoothed: bool, + cmap: Optional[str] = None, + sym_c: Optional[bool] = None, +) -> str: + """A helper function to get basis key for scatters function.""" + if use_smoothed: + mapper = get_mapper() + + cur_l_smoothed = layer + + if layer in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: + cur_l_smoothed = layer + cmap, sym_c = "bwr", True # TODO maybe use other divergent color map in the future + else: + if use_smoothed: + cur_l_smoothed = layer if layer.startswith("M_") | layer.startswith("velocity") else mapper[layer] + if layer.startswith("velocity"): + cmap, sym_c = "bwr", True + + if layer + "_" + basis in adata.obsm.keys(): + prefix = layer + "_" + elif ("X_" + basis) in adata.obsm.keys(): + prefix = "X_" + elif basis in adata.obsm.keys(): + # special case for spatial for compatibility with other packages + prefix = "" + elif is_gene_name(adata, basis): + prefix = "" + else: + raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) + + basis_key = prefix + basis + + return basis_key, cur_l_smoothed, cmap, sym_c + +def _get_color_parameters( + adata: AnnData, + color: str, + layer: str, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[str] = None, + cmap: Optional[str] = None, + color_key_cmap: Optional[str] = None, + background: Optional[str] = None, + title: str = "", + stack_colors: bool = False, + stack_colors_threshold: float = 0.001, +) -> Tuple: + """A helper function to get color parameters for scatters function.""" + + font_color = _select_font_color(background) + + def _get_theme_for_labels(): + if theme is None: + if background in ["#ffffff", "black"]: + _theme_ = "glasbey_dark" + else: + _theme_ = "glasbey_white" + else: + _theme_ = theme + + return _theme_ + + def _get_theme_for_values(): + if theme is None: + if background in ["#ffffff", "black"]: + _theme_ = "inferno" if layer != "velocity" else "div_blue_black_red" + else: + _theme_ = "viridis" if not layer.startswith("velocity") else "div_blue_red" + else: + _theme_ = theme + + return _theme_ + + if labels is not None and values is not None: + raise ValueError("Conflicting options; only one of labels or values should be set") + + if labels is not None: + main_info("labels input will disable the color by %s" % color) + _color = labels + title = "customized labels" if title == color else title + elif values is not None: + main_info("values input will disable the color by %s" % color) + _color = values + title = "customized values" if title == color else title + else: + _color = _get_adata_color_vec(adata, layer, color) + + is_numeric_color = np.issubdtype(_color.dtype, np.number) + if not is_numeric_color: + main_info( + "skip filtering %s by stack threshold when stacking color because it is not a numeric type" + % color, + indent_level=2, + ) + + if stack_colors and np.issubdtype(_color.dtype, np.number): + main_debug("Subsetting adata by stack_colors") + _stack_background_adata_indices = np.ones(len(adata), dtype=bool) + _stack_background_adata_indices = np.logical_and( + _stack_background_adata_indices, (_color < stack_colors_threshold) + ) + if values is not None: + values = values[_color > stack_colors_threshold] + _color = _color[_color > stack_colors_threshold] + if len(_color) == 0: + main_info("skipping filtering %s because no point of %s is above threshold" % (color, color)) + + is_not_continuous = not isinstance(_color[0], Number) or _color.dtype.name == "category" + + if is_not_continuous: + labels = np.asarray(_color) if is_categorical_dtype(_color) else _color + _theme_ = _get_theme_for_labels() + else: + values = _color + _theme_ = _get_theme_for_values() + + cmap = _themes[_theme_]["cmap"] if cmap is None else cmap + + color_key_cmap = _themes[_theme_]["color_key_cmap"] if color_key_cmap is None else color_key_cmap + background = _themes[_theme_]["background"] if background is None else background + + return _color, labels, values, _theme_, cmap, color_key_cmap, background, font_color, is_not_continuous, title + + +def _standardize_dimensions( + x: Union[int, str, np.ndarray, anndata._core.views.ArrayView], + y: Union[int, str, np.ndarray, anndata._core.views.ArrayView], + z: Union[int, str, np.ndarray, anndata._core.views.ArrayView], + n_obs: int, + projection: str, +) -> Tuple: + """A helper function to standardize the dimensions of x, y, z.""" + if projection == "3d": + x, y, z = ( + _standardize_dimension(x, n_obs), + _standardize_dimension(y, n_obs), + _standardize_dimension(z, n_obs), + ) + else: + x, y = _standardize_dimension(x, n_obs), _standardize_dimension(y, n_obs) + z = [np.nan] * len(x) + + assert len(x) == len(y) and len(x) == len(z), "bug: x, y, z does not have the same shape." + + return x, y, z + + +def _standardize_dimension(dim: Union[int, str, np.ndarray, anndata._core.views.ArrayView], n_obs: int) -> List: + """A helper function to standardize the dimension of a given axis.""" + if type(dim) in [anndata._core.views.ArrayView, np.ndarray] and len(dim) == n_obs: + return [dim] + elif hasattr(dim, "__len__"): + return list(dim) + + +def _map_to_points( + _adata: AnnData, + axis_x: str, + axis_y: str, + axis_z: str, + basis_key: str, + cur_c: str, + cur_b: str, + cur_l_smoothed: str, + projection: str, +) -> Tuple[pd.DataFrame, str]: + """A helper function to map the given axis to corresponding coordinates in current embedding space. + + Args: + _adata: an AnnData object. + axis_x: the column index of the low dimensional embedding for the x-axis in current space. + axis_y: the column index of the low dimensional embedding for the y-axis in current space. + axis_z: the column index of the low dimensional embedding for the z-axis in current space. + basis_key: the basis key constructed by current basis and layer. + cur_c: the current key to color the data. + cur_b: the current basis key representing the reduced dimension. + cur_l_smoothed: the smoothed layer of data to use. + + Returns: + The 3D DataFrame with coordinates of each sample and the title of the plot. + """ + gene_title = [] + anno_title = [] + + def _map_cur_axis_to_title( + cur, + _adata: AnnData, + cur_b: str, + cur_l_smoothed: str, + ) -> Tuple[np.ndarray, str]: + """A helper function to map an axis. + + Args: + cur: the current axis to map. + + Returns: + The coordinates and the column names. + """ + nonlocal gene_title, anno_title + + if is_gene_name(_adata, cur): + points_df_data = (_adata.obs_vector(k=cur, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur, layer=cur_l_smoothed)) + points_column = cur + " (" + cur_l_smoothed + ")" + gene_title.append(cur) + elif is_cell_anno_column(_adata, cur): + points_df_data = _adata.obs_vector(cur) + points_column = cur + anno_title.append(cur) + elif is_layer_keys(_adata, cur): + points_df_data = _adata[:, cur_b].layers[cur] + points_column = flatten(points_df_data) + else: + raise ValueError("Make sure your `x`, `y` and `z` are integers, gene names, column names in .obs, etc.") + + return points_df_data, points_column + + if projection == "3d": + if type(axis_x) is int and type(axis_y) is int and type(axis_z): + x_col_name = cur_b + "_0" + y_col_name = cur_b + "_1" + z_col_name = cur_b + "_2" + + points = pd.DataFrame( + { + x_col_name: _adata.obsm[basis_key][:, axis_x], + y_col_name: _adata.obsm[basis_key][:, axis_y], + z_col_name: _adata.obsm[basis_key][:, axis_z], + } + ) + points.columns = [x_col_name, y_col_name, z_col_name] + + cur_title = cur_c + + return points, cur_title + elif type(axis_x) in [anndata._core.views.ArrayView, np.ndarray] and type(axis_y) in [ + anndata._core.views.ArrayView, + np.ndarray, + ]: + points = pd.DataFrame({"x": flatten(axis_x), "y": flatten(axis_y), "z": flatten(axis_z)}) + points.columns = ["x", "y", "z"] + else: + x_points_df_data, x_points_column = _map_cur_axis_to_title(axis_x, _adata, cur_b, cur_l_smoothed) + y_points_df_data, y_points_column = _map_cur_axis_to_title(axis_y, _adata, cur_b, cur_l_smoothed) + z_points_df_data, z_points_column = _map_cur_axis_to_title(axis_z, _adata, cur_b, cur_l_smoothed) + points = pd.DataFrame({ + axis_x: x_points_df_data, + axis_y: y_points_df_data, + axis_z: z_points_df_data, + }) + points.columns = [x_points_column, y_points_column, z_points_column] + + if len(gene_title) != 0: + cur_title = " VS ".join(gene_title) + elif len(anno_title) == 3: + cur_title = " VS ".join(anno_title) + else: + cur_title = cur_b + else: + if type(axis_x) is int and type(axis_y) is int: + x_col_name = cur_b + "_0" + y_col_name = cur_b + "_1" + + points = pd.DataFrame( + { + x_col_name: _adata.obsm[basis_key][:, axis_x], + y_col_name: _adata.obsm[basis_key][:, axis_y], + } + ) + points.columns = [x_col_name, y_col_name] + + cur_title = cur_c + + return points, cur_title + elif type(axis_x) in [anndata._core.views.ArrayView, np.ndarray] and type(axis_y) in [ + anndata._core.views.ArrayView, + np.ndarray, + ]: + points = pd.DataFrame({"x": flatten(axis_x), "y": flatten(axis_y)}) + points.columns = ["x", "y"] + else: + x_points_df_data, x_points_column = _map_cur_axis_to_title(axis_x, _adata, cur_b, cur_l_smoothed) + y_points_df_data, y_points_column = _map_cur_axis_to_title(axis_y, _adata, cur_b, cur_l_smoothed) + x_points_df_data = x_points_df_data.A.flatten() if issparse(x_points_df_data) else x_points_df_data + y_points_df_data = y_points_df_data.A.flatten() if issparse(y_points_df_data) else y_points_df_data + points = pd.DataFrame({ + axis_x: x_points_df_data, + axis_y: y_points_df_data, + }, index=_adata.obs_names) + points.columns = [axis_x, axis_y] + + if len(gene_title) != 0: + cur_title = " VS ".join(gene_title) + elif len(anno_title) == 2: + cur_title = " VS ".join(anno_title) + else: + cur_title = cur_b + + return points, cur_title diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index a2537d685..1404c4bdb 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -882,8 +882,8 @@ def func(x): "Xss": Xss, "ftype": ftype, "confidence": confidence, - "NCx": {str(index): array for index, array in enumerate(NCx)}, - "NCy": {str(index): array for index, array in enumerate(NCy)}, + "NCx": {str(index): array for index, array in enumerate(NCx)} if NCx is not None else None, + "NCy": {str(index): array for index, array in enumerate(NCy)} if NCy is not None else None, "separatrices": None, "fp_ind": fp_ind, } @@ -897,8 +897,8 @@ def func(x): "Xss": Xss, "ftype": ftype, "confidence": confidence, - "NCx": {str(index): array for index, array in enumerate(NCx)}, - "NCy": {str(index): array for index, array in enumerate(NCy)}, + "NCx": {str(index): array for index, array in enumerate(NCx)} if NCx is not None else None, + "NCy": {str(index): array for index, array in enumerate(NCy)} if NCy is not None else None, "separatrices": None, "fp_ind": fp_ind, }