Skip to content

Commit

Permalink
working on keep list
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Nov 6, 2023
1 parent 4954bbc commit 8edf467
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions src/spikeanalysis/spike_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def _plot_scores(
bar: Optional[list[int]] = None,
indices: bool = False,
show_stim: bool = True,

) -> Optional[np.array]:
"""
Function to plot heatmaps of firing rate data
Expand Down Expand Up @@ -342,7 +343,7 @@ def _plot_scores(
if indices:
return sorted_cluster_ids

def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, include_ids: list | np.nadarry | None = None):
"""
Function to plot rasters
Expand All @@ -353,6 +354,7 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
of [start, stop] format
show_stim: bool, default True
Show lines where stim onset and offset are
include_ids: list | np.ndarray | None, default: None
"""
from .analysis_utils import histogram_functions as hf

Expand All @@ -372,6 +374,16 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):

event_lengths = self._get_event_lengths()


if include_ids is not None:
cluster_indices = self.data.cluster_ids
keep_list = []
for cid in include_ids:
keep_list.append(np.where(cluster_indices==cid)[0][0])
keep_list = np.array(keep_list)
else:
keep_list = np.arange(0, len(cluster_indices), 1)

for idx, stimulus in enumerate(psths.keys()):
bins = psths[stimulus]["bins"]
psth = psths[stimulus]["psth"]
Expand All @@ -384,8 +396,10 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])]
bins = bins[np.logical_and(bins >= sub_window[0], bins <= sub_window[1])]

for idx in range(np.shape(psth)[0]):
psth_sub = np.squeeze(psth[idx])
for idy in range(np.shape(psth)[0]):
if idy not in keep_list:
pass
psth_sub = np.squeeze(psth[idy])

if np.sum(psth_sub) == 0:
continue
Expand Down Expand Up @@ -426,7 +440,7 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
sns.despine()
else:
self._despine(ax)
plt.title(f"{self.data.cluster_ids[idx]} stim: {stimulus}", size=7)
plt.title(f"{self.data.cluster_ids[idy]} stim: {stimulus}", size=7)
plt.figure(dpi=self.dpi)
plt.show()

Expand All @@ -436,6 +450,7 @@ def plot_sm_fr(
time_bin_ms: Union[float, list[float]],
sm_time_ms: Union[float, list[float]],
show_stim: bool = True,
include_ids: list | np.ndarray | None = None,
):
"""
Function to plot smoothed firing rates
Expand All @@ -452,7 +467,9 @@ def plot_sm_fr(
stimulus
show_stim: bool, default True
Show lines where stim onset and offset are
include_ids: list | np.ndarray | None
The ids to include for plotting
"""
import matplotlib as mpl
from .analysis_utils import histogram_functions as hf
Expand Down Expand Up @@ -488,6 +505,15 @@ def plot_sm_fr(
number of bins is{len(time_bin_ms)} and should be {NUM_STIM}"
time_bin_size = np.array(time_bin_ms) / 1000

if include_ids is not None:
cluster_indices = self.data.cluster_ids
keep_list = []
for cid in include_ids:
keep_list.append(np.where(cluster_indices==cid)[0][0])
keep_list = np.array(keep_list)
else:
keep_list = np.arange(0, len(cluster_indices), 1)

stim_trial_groups = self._get_trial_groups()
event_lengths = self._get_event_lengths_all()
for idx, stimulus in enumerate(psths.keys()):
Expand Down Expand Up @@ -519,6 +545,8 @@ def plot_sm_fr(
stderr = np.zeros((len(tg_set), len(bins)))
event_len = np.zeros((len(tg_set)))
for cluster_number in range(np.shape(psth)[0]):
if cluster_number not in keep_list:
pass
smoothed_psth = gaussian_smoothing(psth[cluster_number], bin_size, sm_std)

for trial_number, trial in enumerate(tg_set):
Expand Down

0 comments on commit 8edf467

Please sign in to comment.