diff --git a/src/spikeanalysis/spike_data.py b/src/spikeanalysis/spike_data.py index 6290d5d..2412fea 100644 --- a/src/spikeanalysis/spike_data.py +++ b/src/spikeanalysis/spike_data.py @@ -432,7 +432,35 @@ def get_waveforms(self, wf_window: tuple = (-40, 41), n_wfs: int = 500): self._return_to_dir(current_dir) - def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bool = False): + def get_amplitudes(self, std: float = 2): + """ + function for assessing amplitude distribution. + + Parameters + ---------- + std: float, default: 2 + The number of standard deviations to use when assessing the desired spread of the data + Returns + ------- + None. + + """ + + waveforms = self.waveforms + n_waveforms = waveforms.shape[1] + amplitudes = waveforms.max(axis=3) - waveforms.min(axis=3) + max_amplitudes = amplitudes.max(axis=2) + mean_amplitudes = max_amplitudes.mean(axis=1) + std_amplitudes = max_amplitudes.std(axis=1) + z_index = np.zeros((max_amplitudes.shape)) + for row_index in range(max_amplitudes.shape[0]): + z_index[row_index] = (max_amplitudes[row_index] - mean_amplitudes[row_index]) / std_amplitudes[row_index] + + amplitude_index = np.where(np.logical_and(z_index < std, z_index > -std), 1, 0).sum(axis=1) / n_waveforms + + self.amplitude_index = amplitude_index + + def qc_preprocessing(self, idthres: float, rpv: float, sil: float, amp_cutoff: float, recurated: bool = False): """ function for curating data based on qc metrics and refractory periods @@ -484,24 +512,34 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bo self.refractory_period_violations = np.load("refractory_period_violations.npy") except FileNotFoundError: raise Exception("refractory period violations not calculated") + try: + _ = self.amplitude_index + except AttributeError: + try: + self.amplitude_index = np.load("amplitude_distribution.npy") + except FileNotFoundError: + raise Exception("ampltiude scores not calculated") assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length" assert len(self.silhouette_scores) == len( self.refractory_period_violations ), "Refractory period violations should be same length as qc" + assert len(self.amplitude_index) == len(self.silhouette_scores), "amplitudes scores must be the same length" iso_d_thres = np.where(self.isolation_distances > idthres, True, False) sil_thres = np.where(self.silhouette_scores > sil, True, False) rpv_thres = np.where(self.refractory_period_violations < rpv, True, False) - + amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False) threshold = np.logical_and(iso_d_thres, sil_thres) threshold = np.logical_and(threshold, rpv_thres) + threshold = np.logical_and(threshold, amp_thres) self._qc_threshold = threshold self._isolation_threshold = idthres self._sil_threshold = sil self._rpv = rpv + self._amp_cutoff = amp_cutoff if self.CACHING: np.save("qc_threshold.npy", threshold) @@ -510,6 +548,7 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bo print("Current qc_preprocessing values led to 0 units.") print(f"Iso: {np.sum(iso_d_thres)}, sil: {np.sum(sil)}") print(f"RPV: {np.sum(rpv)}") + print(f"amp cutoff: {np.sum(amp_thres)}") self._return_to_dir(current_dir) @@ -524,7 +563,8 @@ def set_qc(self): threshold = self._qc_threshold except AttributeError: raise Exception( - f"Must run qc functions first ('generate_pcs', 'generate_qcmetrics', 'refractory_violation')" + f"Must run qc functions first ('generate_pcs', 'generate_qcmetrics', 'refractory_violation'" + f"'get_amplitudes') " ) self._cids = self._cids[threshold]