From cf5cb08c10629be8a4aca1f27b6c6756bf950067 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 25 Aug 2023 09:48:25 -0400 Subject: [PATCH 1/2] calculate refractory period violations with samples rather than time --- src/spikeanalysis/spike_data.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/spikeanalysis/spike_data.py b/src/spikeanalysis/spike_data.py index 76071ed..38a1517 100644 --- a/src/spikeanalysis/spike_data.py +++ b/src/spikeanalysis/spike_data.py @@ -224,15 +224,13 @@ def refractory_violation(self, ref_dur_ms: float): """ print("calculating refractory period violation fraction") self._goto_file_path() - ref_dur = ref_dur_ms / 1000 + ref_dur_samples = ref_dur_ms / 1000 * self._sampling_rate spike_clusters = np.squeeze(np.load("spike_clusters.npy")) violations = np.zeros((len(set(spike_clusters)))) violations[:] = np.nan - try: - spike_times = self.spike_times - except AttributeError: - spike_times = self.raw_spike_times / self._sampling_rate + + spike_times = self.raw_spike_times / self._sampling_rate for idx, cluster in enumerate(tqdm(set(spike_clusters))): spikes = spike_times[self.spike_clusters == cluster] @@ -240,7 +238,7 @@ def refractory_violation(self, ref_dur_ms: float): if len(spikes) < 10: continue else: - num_violations = float(len(np.where(np.diff(spikes) <= ref_dur)[0])) + num_violations = float(len(np.where(np.diff(spikes) <= ref_dur_samples)[0])) total_spikes = len(spikes) violations[idx] = num_violations / total_spikes From 8c5b95d88d74dc4c357618c2435eab88b1ffa09a Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 25 Aug 2023 09:52:36 -0400 Subject: [PATCH 2/2] use spike times in samples --- src/spikeanalysis/spike_data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeanalysis/spike_data.py b/src/spikeanalysis/spike_data.py index 38a1517..6290d5d 100644 --- a/src/spikeanalysis/spike_data.py +++ b/src/spikeanalysis/spike_data.py @@ -228,9 +228,7 @@ def refractory_violation(self, ref_dur_ms: float): spike_clusters = np.squeeze(np.load("spike_clusters.npy")) violations = np.zeros((len(set(spike_clusters)))) violations[:] = np.nan - - - spike_times = self.raw_spike_times / self._sampling_rate + spike_times = self.raw_spike_times for idx, cluster in enumerate(tqdm(set(spike_clusters))): spikes = spike_times[self.spike_clusters == cluster]