From a84781538bb1ffb44bb3953b0f3dea33fe15a1f3 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 7 Nov 2023 16:03:02 -0500 Subject: [PATCH] add sorted --- src/spikeanalysis/spike_plotter.py | 35 ++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index 6a54eb8..94952eb 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -343,7 +343,7 @@ def _plot_scores( if indices: return sorted_cluster_ids - def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, include_ids: list | np.nadarry | None = None): + def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, include_ids: list | np.nadarry | None = None, sorted: bool = False): """ Function to plot rasters @@ -355,6 +355,8 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i show_stim: bool, default True Show lines where stim onset and offset are include_ids: list | np.ndarray | None, default: None + sub ids to include + sorted: bool, default = False """ from .analysis_utils import histogram_functions as hf @@ -384,6 +386,12 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i else: keep_list = np.arange(0, len(cluster_indices), 1) + if sorted: + sorted_indices = [] + for value_id in include_ids: + sorted_indices.append(np.nonzero(self.cluster_ids==value_id)[0][0]) + sorted_indices = np.array(sorted_indices) + for idx, stimulus in enumerate(psths.keys()): bins = psths[stimulus]["bins"] psth = psths[stimulus]["psth"] @@ -394,6 +402,10 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i tg_set = np.unique(trial_groups) psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])] + + if sorted: + psth = psth[sorted_indices,...] + bins = bins[np.logical_and(bins >= sub_window[0], bins <= sub_window[1])] for idy in range(np.shape(psth)[0]): @@ -440,7 +452,10 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i sns.despine() else: self._despine(ax) - plt.title(f"{self.data.cluster_ids[idy]} stim: {stimulus}", size=7) + if sorted: + plt.title(f"{stimulus}: {self.data.cluster_ids[sorted_indices][idy]}", fontsize=8) + else: + plt.title(f"{stimulus}: {self.data.cluster_ids[idy]}", fontsize=8) plt.figure(dpi=self.dpi) plt.show() @@ -451,6 +466,7 @@ def plot_sm_fr( sm_time_ms: Union[float, list[float]], show_stim: bool = True, include_ids: list | np.ndarray | None = None, + sorted: bool = False, ): """ Function to plot smoothed firing rates @@ -469,6 +485,7 @@ def plot_sm_fr( Show lines where stim onset and offset are include_ids: list | np.ndarray | None The ids to include for plotting + sorted: bool = False, """ import matplotlib as mpl @@ -514,6 +531,12 @@ def plot_sm_fr( else: keep_list = np.arange(0, len(cluster_indices), 1) + if sorted: + sorted_indices = [] + for value_id in include_ids: + sorted_indices.append(np.nonzero(self.cluster_ids==value_id)[0][0]) + sorted_indices = np.array(sorted_indices) + stim_trial_groups = self._get_trial_groups() event_lengths = self._get_event_lengths_all() for idx, stimulus in enumerate(psths.keys()): @@ -531,6 +554,8 @@ def plot_sm_fr( trial_groups = stim_trial_groups[stimulus] sub_window = windows[idx] psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])] + if sorted: + psth = psth[sorted_indices, ...] bins = bins[np.logical_and(bins > sub_window[0], bins < sub_window[1])] events = event_lengths[stimulus] tg_set = np.unique(trial_groups) @@ -590,8 +615,10 @@ def plot_sm_fr( sns.despine() else: self._despine(ax) - - plt.title(f"{stimulus}: {self.data.cluster_ids[cluster_number]}", fontsize=8) + if sorted: + plt.title(f"{stimulus}: {self.data.cluster_ids[sorted_indices][cluster_number]}", fontsize=8) + else: + plt.title(f"{stimulus}: {self.data.cluster_ids[cluster_number]}", fontsize=8) plt.figure(dpi=self.dpi) plt.show()