From 95dc72fc903b1dfd33c65a53ccf85e110a5d8059 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 3 Oct 2023 14:16:44 -0400 Subject: [PATCH] update plotting functions --- docs/source/submodules/intrinsic_plotter.rst | 2 +- src/spikeanalysis/spike_plotter.py | 85 +++++++++++--------- 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/docs/source/submodules/intrinsic_plotter.rst b/docs/source/submodules/intrinsic_plotter.rst index a372280..b6d8d24 100644 --- a/docs/source/submodules/intrinsic_plotter.rst +++ b/docs/source/submodules/intrinsic_plotter.rst @@ -71,7 +71,7 @@ lamina of the spinal cord has most of the units found during sorting. .. code-block:: python - iplotter.plot_spike_dpeth_fr(sp=spikes) # spikes is SpikeData + iplotter.plot_spike_depth_fr(sp=spikes) # spikes is SpikeData CDF --- diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index db5d292..bbaff45 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -102,14 +102,14 @@ def plot_zscores( Returns ------- - ordered_cluster_ids: np.array + sorted_cluster_ids: np.array if indices is True, the function will return the cluster ids as displayed in the z bar graph """ - if self.cmap == "viridis": + if self.cmap is None: self.cmap = "vlag" - index_array = self._plot_scores( + sorted_cluster_ids = self._plot_scores( data="zscore", figsize=figsize, sorting_index=sorting_index, @@ -118,7 +118,7 @@ def plot_zscores( show_stim=show_stim, ) if indices: - return index_array + return sorted_cluster_ids def plot_raw_firing( self, @@ -150,19 +150,19 @@ def plot_raw_firing( Returns ------- - ordered_cluster_ids: Optional[np.array] + ordered_cluster_ids: Optional[dict] if indices is True, the function will return the cluster ids as displayed in the z bar graph """ - if self.cmap == "vlag": + if self.cmap is None: self.cmap = "viridis" - index_array = self._plot_scores( + sorted_cluster_ids = self._plot_scores( data="raw-data", figsize=figsize, sorting_index=sorting_index, bar=bar, indices=indices, show_stim=show_stim ) if indices: - return index_array + return sorted_cluster_ids def _plot_scores( self, @@ -187,13 +187,13 @@ def _plot_scores( bar: list[int] If given a list with min for the cbar at index 0 and the max at index 1. Overrides cbar generation indices: bool, default False - If true will return the cluster ids sorted in the order they appear in the graph + If true will return the cluster ids sorted in the order they appear in the graph as a dict of stimuli show_stim: bool, default True Show lines where stim onset and offset are Returns ------- - ordered_cluster_ids: Optional[np.array] + sorted_cluster_ids: Optional[dict] if indices is True, the function will return the cluster ids as displayed in the z bar graph """ @@ -208,10 +208,7 @@ def _plot_scores( if figsize is None: figsize = self.figsize - if self.cmap is None: - cmap = "vlag" - else: - cmap = self.cmap + cmap = self.cmap if self.y_axis is None: y_axis = "Units" @@ -222,7 +219,7 @@ def _plot_scores( assert len(bar) == 2, f"Please give z_bar as [min, max], you entered {bar}" stim_lengths = self._get_event_lengths() - + sorted_cluster_ids = {} for stimulus in z_scores.keys(): if len(np.shape(z_scores)) < 3: sub_zscores = np.expand_dims(z_scores[stimulus], axis=1) @@ -251,17 +248,18 @@ def _plot_scores( event_window = np.logical_and(bins >= 0, bins <= length) z_score_sorting_index = np.argsort(-np.sum(sub_zscores[:, sorting_index, event_window], axis=1)) - + sorted_cluster_ids[stimulus] = self.data.cluster_ids[z_score_sorting_index] sorted_z_scores = sub_zscores[z_score_sorting_index, :, :] - # nan_mask = np.all( - # np.isnan(sorted_z_scores) | np.equal(sorted_z_scores, 0) | np.isinf(sorted_z_scores), axis=2 - # ) - - # sorted_z_scores = sorted_z_scores[~nan_mask] if len(np.shape(sorted_z_scores)) == 2: sorted_z_scores = np.expand_dims(sorted_z_scores, axis=1) + nan_mask = np.all( + np.all(np.isnan(sorted_z_scores) | np.equal(sorted_z_scores, 0) | np.isinf(sorted_z_scores), axis=2), + axis=1, + ) + sorted_z_scores = sorted_z_scores[~nan_mask] + if bar is not None: vmax = bar[1] vmin = bar[0] @@ -320,7 +318,11 @@ def _plot_scores( ] ) cax.spines["bottom"].set_visible(False) - plt.colorbar(im, cax=cax, label="Z scores") # Similar to fig.colorbar(im, cax = cax) + if data == "zscore": + cbar_label = "Z scores" + else: + cbar_label = "Raw Firing" + plt.colorbar(im, cax=cax, label=cbar_label) # Similar to fig.colorbar(im, cax = cax) plt.title(f"{stimulus}") plt.figure(dpi=self.dpi) plt.show() @@ -329,7 +331,7 @@ def _plot_scores( sorting_index = None if indices: - return self.data.cluster_ids[z_score_sorting_index] + return sorted_cluster_ids def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): """ @@ -556,7 +558,7 @@ def plot_sm_fr( plt.figure(dpi=self.dpi) plt.show() - def plot_zscores_ind(self, z_bar: Optional[list[int]] = None): + def plot_zscores_ind(self, z_bar: Optional[list[int]] = None, show_stim: bool = True): """ Function for plotting z scored heatmaps by trial group rather than all trial groups on the same set of axes. In This function all data is ordered based on the most responsive unit/trial group. Rows can be different units @@ -567,6 +569,8 @@ def plot_zscores_ind(self, z_bar: Optional[list[int]] = None): ---------- z_bar: list[int] If given a list with min z score for the cbar at index 0 and the max at index 1. Overrides cbar generation + show_stim: bool, default: True + Whether to mark at the stim onset and offset """ try: z_scores = self.data.z_scores @@ -631,22 +635,23 @@ def plot_zscores_ind(self, z_bar: Optional[list[int]] = None): ax.set_xticks([i * bins_length for i in range(7)]) ax.set_xticklabels([round(bins[i * bins_length], 4) if i < 7 else z_window[1] for i in range(7)]) ax.set_ylabel(y_axis, fontsize="small") - ax.axvline( - zero_point, - 0, - np.shape(sorted_z_scores)[0], - color="black", - linestyle=":", - linewidth=0.5, - ) - ax.axvline( - end_point, - 0, - np.shape(sorted_z_scores)[0], - color="black", - linestyle=":", - linewidth=0.5, - ) + if show_stim: + ax.axvline( + zero_point, + 0, + np.shape(sorted_z_scores)[0], + color="black", + linestyle=":", + linewidth=0.5, + ) + ax.axvline( + end_point, + 0, + np.shape(sorted_z_scores)[0], + color="black", + linestyle=":", + linewidth=0.5, + ) self._despine(ax) ax.spines["bottom"].set_visible(False) ax.spines["left"].set_visible(False)