From 149c9e8a9449b2b729c837a2f6720cab4fcb74c8 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 12 Sep 2023 17:46:00 -0400 Subject: [PATCH 1/7] add rough amplitude cutoff metric --- src/spikeanalysis/spike_data.py | 46 ++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/src/spikeanalysis/spike_data.py b/src/spikeanalysis/spike_data.py index 6290d5d..2412fea 100644 --- a/src/spikeanalysis/spike_data.py +++ b/src/spikeanalysis/spike_data.py @@ -432,7 +432,35 @@ def get_waveforms(self, wf_window: tuple = (-40, 41), n_wfs: int = 500): self._return_to_dir(current_dir) - def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bool = False): + def get_amplitudes(self, std: float = 2): + """ + function for assessing amplitude distribution. + + Parameters + ---------- + std: float, default: 2 + The number of standard deviations to use when assessing the desired spread of the data + Returns + ------- + None. + + """ + + waveforms = self.waveforms + n_waveforms = waveforms.shape[1] + amplitudes = waveforms.max(axis=3) - waveforms.min(axis=3) + max_amplitudes = amplitudes.max(axis=2) + mean_amplitudes = max_amplitudes.mean(axis=1) + std_amplitudes = max_amplitudes.std(axis=1) + z_index = np.zeros((max_amplitudes.shape)) + for row_index in range(max_amplitudes.shape[0]): + z_index[row_index] = (max_amplitudes[row_index] - mean_amplitudes[row_index]) / std_amplitudes[row_index] + + amplitude_index = np.where(np.logical_and(z_index < std, z_index > -std), 1, 0).sum(axis=1) / n_waveforms + + self.amplitude_index = amplitude_index + + def qc_preprocessing(self, idthres: float, rpv: float, sil: float, amp_cutoff: float, recurated: bool = False): """ function for curating data based on qc metrics and refractory periods @@ -484,24 +512,34 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bo self.refractory_period_violations = np.load("refractory_period_violations.npy") except FileNotFoundError: raise Exception("refractory period violations not calculated") + try: + _ = self.amplitude_index + except AttributeError: + try: + self.amplitude_index = np.load("amplitude_distribution.npy") + except FileNotFoundError: + raise Exception("ampltiude scores not calculated") assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length" assert len(self.silhouette_scores) == len( self.refractory_period_violations ), "Refractory period violations should be same length as qc" + assert len(self.amplitude_index) == len(self.silhouette_scores), "amplitudes scores must be the same length" iso_d_thres = np.where(self.isolation_distances > idthres, True, False) sil_thres = np.where(self.silhouette_scores > sil, True, False) rpv_thres = np.where(self.refractory_period_violations < rpv, True, False) - + amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False) threshold = np.logical_and(iso_d_thres, sil_thres) threshold = np.logical_and(threshold, rpv_thres) + threshold = np.logical_and(threshold, amp_thres) self._qc_threshold = threshold self._isolation_threshold = idthres self._sil_threshold = sil self._rpv = rpv + self._amp_cutoff = amp_cutoff if self.CACHING: np.save("qc_threshold.npy", threshold) @@ -510,6 +548,7 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bo print("Current qc_preprocessing values led to 0 units.") print(f"Iso: {np.sum(iso_d_thres)}, sil: {np.sum(sil)}") print(f"RPV: {np.sum(rpv)}") + print(f"amp cutoff: {np.sum(amp_thres)}") self._return_to_dir(current_dir) @@ -524,7 +563,8 @@ def set_qc(self): threshold = self._qc_threshold except AttributeError: raise Exception( - f"Must run qc functions first ('generate_pcs', 'generate_qcmetrics', 'refractory_violation')" + f"Must run qc functions first ('generate_pcs', 'generate_qcmetrics', 'refractory_violation'" + f"'get_amplitudes') " ) self._cids = self._cids[threshold] From 33acb4b1aebff047672d9d63bff079a1818aef86 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 12 Sep 2023 18:52:28 -0400 Subject: [PATCH 2/7] update test for new argument --- test/test_spike_data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_spike_data.py b/test/test_spike_data.py index 980e6b0..c1413a3 100644 --- a/test/test_spike_data.py +++ b/test/test_spike_data.py @@ -155,6 +155,7 @@ def test_save_qc_parameters(spikes, tmp_path): spikes._isolation_threshold = 10 spikes._rpv = 0.02 spikes._sil_threshold = 0.4 + spikes._amp_cutoff = 0.98 spikes.save_qc_parameters() have_json = False @@ -280,12 +281,13 @@ def test_qc_preprocessing(spikes, tmp_path): id = np.array([10, 30, 20]) sil = np.array([0.1, 0.4, 0.5]) ref = np.array([0.3, 0.001, 0.1]) + amp = np.array([0.98, 0.98, 0.98]) np.save("isolation_distances.npy", id) np.save("silhouette_scores.npy", sil) np.save("refractory_period_violations.npy", ref) spikes.CACHING = True - spikes.qc_preprocessing(15, 0.02, 0.35) + spikes.qc_preprocessing(15, 0.02, 0.35, 0.97) assert isinstance(spikes._qc_threshold, np.ndarray) From 323eb4f4beb2a1c084f09c0d5701c5d358ef2f8e Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 12 Sep 2023 18:55:05 -0400 Subject: [PATCH 3/7] save test data --- test/test_spike_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_spike_data.py b/test/test_spike_data.py index c1413a3..d554d5e 100644 --- a/test/test_spike_data.py +++ b/test/test_spike_data.py @@ -286,6 +286,7 @@ def test_qc_preprocessing(spikes, tmp_path): np.save("isolation_distances.npy", id) np.save("silhouette_scores.npy", sil) np.save("refractory_period_violations.npy", ref) + np.save("ampltiude_distribution.npy", amp) spikes.CACHING = True spikes.qc_preprocessing(15, 0.02, 0.35, 0.97) From e443ba6c1954cfa91c4a23255367d5c481a5ac61 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 12 Sep 2023 18:57:31 -0400 Subject: [PATCH 4/7] fix typo --- test/test_spike_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_spike_data.py b/test/test_spike_data.py index d554d5e..e85a6c3 100644 --- a/test/test_spike_data.py +++ b/test/test_spike_data.py @@ -286,7 +286,7 @@ def test_qc_preprocessing(spikes, tmp_path): np.save("isolation_distances.npy", id) np.save("silhouette_scores.npy", sil) np.save("refractory_period_violations.npy", ref) - np.save("ampltiude_distribution.npy", amp) + np.save("amplitude_distribution.npy", amp) spikes.CACHING = True spikes.qc_preprocessing(15, 0.02, 0.35, 0.97) From 37ccb5664a16be127bfa88742b606b413fe64aba Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 08:31:14 -0400 Subject: [PATCH 5/7] add test for get amplitudes --- test/test_spike_data.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/test_spike_data.py b/test/test_spike_data.py index e85a6c3..1ad4329 100644 --- a/test/test_spike_data.py +++ b/test/test_spike_data.py @@ -345,3 +345,21 @@ def test_load_waveforms(spikes, tmp_path): spikes._file_path = file_path os.chdir(spikes._file_path) + + +def test_get_amplitudes(spikes): + samples = np.random.normal(loc=1.0, scale=1, size=(82)) + samples2 = samples * 0.5 + + waveforms = np.random.rand(2, 10, 4, 82) + waveforms[1, :, 1, :] = samples + waveforms[1, ::2, 1, :] = samples2 + + spikes.waveforms = waveforms + spikes.get_amplitudes() + assert len(spikes.amplitude_index) == 2, "function failed" + print(spikes.amplitude_index) + + assert spikes.amplitude_index[1] == 1.0 + + assert spikes.amplitude_index[0] < spikes.amplitude_index[1] From 89cfb0038b934f20108f643eae0d91619e5c370b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 08:47:13 -0400 Subject: [PATCH 6/7] fix test conditions --- test/test_spike_data.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_spike_data.py b/test/test_spike_data.py index 1ad4329..944eed0 100644 --- a/test/test_spike_data.py +++ b/test/test_spike_data.py @@ -348,10 +348,14 @@ def test_load_waveforms(spikes, tmp_path): def test_get_amplitudes(spikes): - samples = np.random.normal(loc=1.0, scale=1, size=(82)) - samples2 = samples * 0.5 - - waveforms = np.random.rand(2, 10, 4, 82) + samples = np.random.normal(loc=5.0, scale=1, size=(82)) + samples2 = samples * 50 + + large_std = samples + large_std2 = samples * 1000 + waveforms = np.zeros((2, 1000, 4, 82)) + waveforms[0, :, 2, :] = large_std + waveforms[0, :40, 2, :] = large_std2 waveforms[1, :, 1, :] = samples waveforms[1, ::2, 1, :] = samples2 From 2b8471ec0966eb0f9c1ae33480026996cdca6f34 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 09:34:35 -0400 Subject: [PATCH 7/7] make any specific qc optional, fix test for this --- src/spikeanalysis/spike_data.py | 70 ++++++++++++++++++++++----------- test/test_spike_data.py | 13 +++++- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/src/spikeanalysis/spike_data.py b/src/spikeanalysis/spike_data.py index 2412fea..8039192 100644 --- a/src/spikeanalysis/spike_data.py +++ b/src/spikeanalysis/spike_data.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Union +from typing import Union, Optional import os import numpy as np @@ -460,19 +460,29 @@ def get_amplitudes(self, std: float = 2): self.amplitude_index = amplitude_index - def qc_preprocessing(self, idthres: float, rpv: float, sil: float, amp_cutoff: float, recurated: bool = False): + def qc_preprocessing( + self, + idthres: Optional[float] = None, + rpv: Optional[float] = None, + sil: Optional[float] = None, + amp_cutoff: Optional[float] = None, + recurated: bool = False, + ): """ function for curating data based on qc metrics and refractory periods Parameters ---------- - idthres : float + idthres : Optional[float], default: None The cutoff isolation distance, 0 means no curation. - rpv : float + rpv : Optional[float], default: None Fractional rate of refractory period violations, 0 is no violations and 1 would be all violations okay - sil : float + sil : Optional[float], default: None Minimum silhouette score, [-1, 1], where bigger is better. - recurated : bool, optional + amp_cutoff: Optional[float], default = None + The percentage of spikes allowed to be over the user specified standard deviations (default 2) given as the + desired percentage. E.g. 0.98 means 98% of spikes are within 2 stds. + recurated : bool, default: False If data has been recurated in phy since the last data run. The default is False. Raises @@ -504,35 +514,49 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, amp_cutoff: f self.silhouette_scores = np.load("silhouette_scores.npy") self.isolation_distances = np.load("isolation_distances.npy") except FileNotFoundError: - raise Exception("qc metrics has not been run") + if idthres is None and sil is None: + pass + else: + raise Exception("qc metrics has not been run") try: _ = self.refractory_period_violations except AttributeError: try: self.refractory_period_violations = np.load("refractory_period_violations.npy") except FileNotFoundError: - raise Exception("refractory period violations not calculated") + if rpv is None: + pass + else: + raise Exception("refractory period violations not calculated") try: _ = self.amplitude_index except AttributeError: try: self.amplitude_index = np.load("amplitude_distribution.npy") except FileNotFoundError: - raise Exception("ampltiude scores not calculated") - - assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length" - assert len(self.silhouette_scores) == len( - self.refractory_period_violations - ), "Refractory period violations should be same length as qc" - assert len(self.amplitude_index) == len(self.silhouette_scores), "amplitudes scores must be the same length" - - iso_d_thres = np.where(self.isolation_distances > idthres, True, False) - sil_thres = np.where(self.silhouette_scores > sil, True, False) - rpv_thres = np.where(self.refractory_period_violations < rpv, True, False) - amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False) - threshold = np.logical_and(iso_d_thres, sil_thres) - threshold = np.logical_and(threshold, rpv_thres) - threshold = np.logical_and(threshold, amp_thres) + if amp_cutoff is None: + pass + else: + raise Exception("amplitude scores not calculated") + + if idthres is not None: + assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length" + iso_d_thres = np.where(self.isolation_distances > idthres, True, False) + sil_thres = np.where(self.silhouette_scores > sil, True, False) + threshold = np.logical_and(iso_d_thres, sil_thres) + else: + threshold = np.array([True] * len(self._cids)) + + if rpv is not None: + assert len(self.refractory_period_violations) == len( + self._cids + ), "mismatch between refactory period and cids" + rpv_thres = np.where(self.refractory_period_violations < rpv, True, False) + threshold = np.logical_and(threshold, rpv_thres) + if amp_cutoff is not None: + assert len(self.amplitude_index) == len(self._cids), "mismatch between amplitudes and cids" + amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False) + threshold = np.logical_and(threshold, amp_thres) self._qc_threshold = threshold diff --git a/test/test_spike_data.py b/test/test_spike_data.py index 944eed0..025da04 100644 --- a/test/test_spike_data.py +++ b/test/test_spike_data.py @@ -278,16 +278,24 @@ def test_qc_preprocessing(spikes, tmp_path): file_path = spikes._file_path spikes._file_path = spikes._file_path / tmp_path os.chdir(spikes._file_path) - id = np.array([10, 30, 20]) + ids = np.array([10, 30, 20]) sil = np.array([0.1, 0.4, 0.5]) ref = np.array([0.3, 0.001, 0.1]) amp = np.array([0.98, 0.98, 0.98]) - np.save("isolation_distances.npy", id) + np.save("isolation_distances.npy", ids) np.save("silhouette_scores.npy", sil) np.save("refractory_period_violations.npy", ref) np.save("amplitude_distribution.npy", amp) spikes.CACHING = True + cids = spikes._cids + spikes._cids = np.array( + [ + 0, + 1, + 2, + ] + ) spikes.qc_preprocessing(15, 0.02, 0.35, 0.97) assert isinstance(spikes._qc_threshold, np.ndarray) @@ -295,6 +303,7 @@ def test_qc_preprocessing(spikes, tmp_path): assert spikes._qc_threshold[0] == False assert spikes._qc_threshold[1] == True assert spikes._qc_threshold[2] == False + spikes._cids = cids spikes._file_path = file_path os.chdir(file_path)