From 8edf467eab76653a49e3b20da7a072a1fb047a7b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 6 Nov 2023 15:47:30 -0500 Subject: [PATCH] working on keep list --- src/spikeanalysis/spike_plotter.py | 38 ++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index a36fbbc..6a54eb8 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -181,6 +181,7 @@ def _plot_scores( bar: Optional[list[int]] = None, indices: bool = False, show_stim: bool = True, + ) -> Optional[np.array]: """ Function to plot heatmaps of firing rate data @@ -342,7 +343,7 @@ def _plot_scores( if indices: return sorted_cluster_ids - def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): + def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, include_ids: list | np.nadarry | None = None): """ Function to plot rasters @@ -353,6 +354,7 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): of [start, stop] format show_stim: bool, default True Show lines where stim onset and offset are + include_ids: list | np.ndarray | None, default: None """ from .analysis_utils import histogram_functions as hf @@ -372,6 +374,16 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): event_lengths = self._get_event_lengths() + + if include_ids is not None: + cluster_indices = self.data.cluster_ids + keep_list = [] + for cid in include_ids: + keep_list.append(np.where(cluster_indices==cid)[0][0]) + keep_list = np.array(keep_list) + else: + keep_list = np.arange(0, len(cluster_indices), 1) + for idx, stimulus in enumerate(psths.keys()): bins = psths[stimulus]["bins"] psth = psths[stimulus]["psth"] @@ -384,8 +396,10 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])] bins = bins[np.logical_and(bins >= sub_window[0], bins <= sub_window[1])] - for idx in range(np.shape(psth)[0]): - psth_sub = np.squeeze(psth[idx]) + for idy in range(np.shape(psth)[0]): + if idy not in keep_list: + pass + psth_sub = np.squeeze(psth[idy]) if np.sum(psth_sub) == 0: continue @@ -426,7 +440,7 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): sns.despine() else: self._despine(ax) - plt.title(f"{self.data.cluster_ids[idx]} stim: {stimulus}", size=7) + plt.title(f"{self.data.cluster_ids[idy]} stim: {stimulus}", size=7) plt.figure(dpi=self.dpi) plt.show() @@ -436,6 +450,7 @@ def plot_sm_fr( time_bin_ms: Union[float, list[float]], sm_time_ms: Union[float, list[float]], show_stim: bool = True, + include_ids: list | np.ndarray | None = None, ): """ Function to plot smoothed firing rates @@ -452,7 +467,9 @@ def plot_sm_fr( stimulus show_stim: bool, default True Show lines where stim onset and offset are - + include_ids: list | np.ndarray | None + The ids to include for plotting + """ import matplotlib as mpl from .analysis_utils import histogram_functions as hf @@ -488,6 +505,15 @@ def plot_sm_fr( number of bins is{len(time_bin_ms)} and should be {NUM_STIM}" time_bin_size = np.array(time_bin_ms) / 1000 + if include_ids is not None: + cluster_indices = self.data.cluster_ids + keep_list = [] + for cid in include_ids: + keep_list.append(np.where(cluster_indices==cid)[0][0]) + keep_list = np.array(keep_list) + else: + keep_list = np.arange(0, len(cluster_indices), 1) + stim_trial_groups = self._get_trial_groups() event_lengths = self._get_event_lengths_all() for idx, stimulus in enumerate(psths.keys()): @@ -519,6 +545,8 @@ def plot_sm_fr( stderr = np.zeros((len(tg_set), len(bins))) event_len = np.zeros((len(tg_set))) for cluster_number in range(np.shape(psth)[0]): + if cluster_number not in keep_list: + pass smoothed_psth = gaussian_smoothing(psth[cluster_number], bin_size, sm_std) for trial_number, trial in enumerate(tg_set):