Skip to content

Commit

Permalink
add rough amplitude cutoff metric
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Sep 12, 2023
1 parent d901e65 commit 149c9e8
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions src/spikeanalysis/spike_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,35 @@ def get_waveforms(self, wf_window: tuple = (-40, 41), n_wfs: int = 500):

self._return_to_dir(current_dir)

def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bool = False):
def get_amplitudes(self, std: float = 2):
"""
function for assessing amplitude distribution.
Parameters
----------
std: float, default: 2
The number of standard deviations to use when assessing the desired spread of the data
Returns
-------
None.
"""

waveforms = self.waveforms
n_waveforms = waveforms.shape[1]
amplitudes = waveforms.max(axis=3) - waveforms.min(axis=3)
max_amplitudes = amplitudes.max(axis=2)
mean_amplitudes = max_amplitudes.mean(axis=1)
std_amplitudes = max_amplitudes.std(axis=1)
z_index = np.zeros((max_amplitudes.shape))
for row_index in range(max_amplitudes.shape[0]):
z_index[row_index] = (max_amplitudes[row_index] - mean_amplitudes[row_index]) / std_amplitudes[row_index]

amplitude_index = np.where(np.logical_and(z_index < std, z_index > -std), 1, 0).sum(axis=1) / n_waveforms

self.amplitude_index = amplitude_index

def qc_preprocessing(self, idthres: float, rpv: float, sil: float, amp_cutoff: float, recurated: bool = False):
"""
function for curating data based on qc metrics and refractory periods
Expand Down Expand Up @@ -484,24 +512,34 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bo
self.refractory_period_violations = np.load("refractory_period_violations.npy")
except FileNotFoundError:
raise Exception("refractory period violations not calculated")
try:
_ = self.amplitude_index
except AttributeError:
try:
self.amplitude_index = np.load("amplitude_distribution.npy")
except FileNotFoundError:
raise Exception("ampltiude scores not calculated")

assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length"
assert len(self.silhouette_scores) == len(
self.refractory_period_violations
), "Refractory period violations should be same length as qc"
assert len(self.amplitude_index) == len(self.silhouette_scores), "amplitudes scores must be the same length"

iso_d_thres = np.where(self.isolation_distances > idthres, True, False)
sil_thres = np.where(self.silhouette_scores > sil, True, False)
rpv_thres = np.where(self.refractory_period_violations < rpv, True, False)

amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False)
threshold = np.logical_and(iso_d_thres, sil_thres)
threshold = np.logical_and(threshold, rpv_thres)
threshold = np.logical_and(threshold, amp_thres)

self._qc_threshold = threshold

self._isolation_threshold = idthres
self._sil_threshold = sil
self._rpv = rpv
self._amp_cutoff = amp_cutoff

if self.CACHING:
np.save("qc_threshold.npy", threshold)
Expand All @@ -510,6 +548,7 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bo
print("Current qc_preprocessing values led to 0 units.")
print(f"Iso: {np.sum(iso_d_thres)}, sil: {np.sum(sil)}")
print(f"RPV: {np.sum(rpv)}")
print(f"amp cutoff: {np.sum(amp_thres)}")

self._return_to_dir(current_dir)

Expand All @@ -524,7 +563,8 @@ def set_qc(self):
threshold = self._qc_threshold
except AttributeError:
raise Exception(
f"Must run qc functions first ('generate_pcs', 'generate_qcmetrics', 'refractory_violation')"
f"Must run qc functions first ('generate_pcs', 'generate_qcmetrics', 'refractory_violation'"
f"'get_amplitudes') "
)

self._cids = self._cids[threshold]
Expand Down

0 comments on commit 149c9e8

Please sign in to comment.