From 2b8471ec0966eb0f9c1ae33480026996cdca6f34 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 09:34:35 -0400 Subject: [PATCH] make any specific qc optional, fix test for this --- src/spikeanalysis/spike_data.py | 70 ++++++++++++++++++++++----------- test/test_spike_data.py | 13 +++++- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/src/spikeanalysis/spike_data.py b/src/spikeanalysis/spike_data.py index 2412fea..8039192 100644 --- a/src/spikeanalysis/spike_data.py +++ b/src/spikeanalysis/spike_data.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Union +from typing import Union, Optional import os import numpy as np @@ -460,19 +460,29 @@ def get_amplitudes(self, std: float = 2): self.amplitude_index = amplitude_index - def qc_preprocessing(self, idthres: float, rpv: float, sil: float, amp_cutoff: float, recurated: bool = False): + def qc_preprocessing( + self, + idthres: Optional[float] = None, + rpv: Optional[float] = None, + sil: Optional[float] = None, + amp_cutoff: Optional[float] = None, + recurated: bool = False, + ): """ function for curating data based on qc metrics and refractory periods Parameters ---------- - idthres : float + idthres : Optional[float], default: None The cutoff isolation distance, 0 means no curation. - rpv : float + rpv : Optional[float], default: None Fractional rate of refractory period violations, 0 is no violations and 1 would be all violations okay - sil : float + sil : Optional[float], default: None Minimum silhouette score, [-1, 1], where bigger is better. - recurated : bool, optional + amp_cutoff: Optional[float], default = None + The percentage of spikes allowed to be over the user specified standard deviations (default 2) given as the + desired percentage. E.g. 0.98 means 98% of spikes are within 2 stds. + recurated : bool, default: False If data has been recurated in phy since the last data run. The default is False. Raises @@ -504,35 +514,49 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, amp_cutoff: f self.silhouette_scores = np.load("silhouette_scores.npy") self.isolation_distances = np.load("isolation_distances.npy") except FileNotFoundError: - raise Exception("qc metrics has not been run") + if idthres is None and sil is None: + pass + else: + raise Exception("qc metrics has not been run") try: _ = self.refractory_period_violations except AttributeError: try: self.refractory_period_violations = np.load("refractory_period_violations.npy") except FileNotFoundError: - raise Exception("refractory period violations not calculated") + if rpv is None: + pass + else: + raise Exception("refractory period violations not calculated") try: _ = self.amplitude_index except AttributeError: try: self.amplitude_index = np.load("amplitude_distribution.npy") except FileNotFoundError: - raise Exception("ampltiude scores not calculated") - - assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length" - assert len(self.silhouette_scores) == len( - self.refractory_period_violations - ), "Refractory period violations should be same length as qc" - assert len(self.amplitude_index) == len(self.silhouette_scores), "amplitudes scores must be the same length" - - iso_d_thres = np.where(self.isolation_distances > idthres, True, False) - sil_thres = np.where(self.silhouette_scores > sil, True, False) - rpv_thres = np.where(self.refractory_period_violations < rpv, True, False) - amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False) - threshold = np.logical_and(iso_d_thres, sil_thres) - threshold = np.logical_and(threshold, rpv_thres) - threshold = np.logical_and(threshold, amp_thres) + if amp_cutoff is None: + pass + else: + raise Exception("amplitude scores not calculated") + + if idthres is not None: + assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length" + iso_d_thres = np.where(self.isolation_distances > idthres, True, False) + sil_thres = np.where(self.silhouette_scores > sil, True, False) + threshold = np.logical_and(iso_d_thres, sil_thres) + else: + threshold = np.array([True] * len(self._cids)) + + if rpv is not None: + assert len(self.refractory_period_violations) == len( + self._cids + ), "mismatch between refactory period and cids" + rpv_thres = np.where(self.refractory_period_violations < rpv, True, False) + threshold = np.logical_and(threshold, rpv_thres) + if amp_cutoff is not None: + assert len(self.amplitude_index) == len(self._cids), "mismatch between amplitudes and cids" + amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False) + threshold = np.logical_and(threshold, amp_thres) self._qc_threshold = threshold diff --git a/test/test_spike_data.py b/test/test_spike_data.py index 944eed0..025da04 100644 --- a/test/test_spike_data.py +++ b/test/test_spike_data.py @@ -278,16 +278,24 @@ def test_qc_preprocessing(spikes, tmp_path): file_path = spikes._file_path spikes._file_path = spikes._file_path / tmp_path os.chdir(spikes._file_path) - id = np.array([10, 30, 20]) + ids = np.array([10, 30, 20]) sil = np.array([0.1, 0.4, 0.5]) ref = np.array([0.3, 0.001, 0.1]) amp = np.array([0.98, 0.98, 0.98]) - np.save("isolation_distances.npy", id) + np.save("isolation_distances.npy", ids) np.save("silhouette_scores.npy", sil) np.save("refractory_period_violations.npy", ref) np.save("amplitude_distribution.npy", amp) spikes.CACHING = True + cids = spikes._cids + spikes._cids = np.array( + [ + 0, + 1, + 2, + ] + ) spikes.qc_preprocessing(15, 0.02, 0.35, 0.97) assert isinstance(spikes._qc_threshold, np.ndarray) @@ -295,6 +303,7 @@ def test_qc_preprocessing(spikes, tmp_path): assert spikes._qc_threshold[0] == False assert spikes._qc_threshold[1] == True assert spikes._qc_threshold[2] == False + spikes._cids = cids spikes._file_path = file_path os.chdir(file_path)