Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an additional qc metric-Amplitude #40

Merged
merged 9 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 85 additions & 21 deletions src/spikeanalysis/spike_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Union
from typing import Union, Optional
import os

import numpy as np
Expand Down Expand Up @@ -432,19 +432,57 @@ def get_waveforms(self, wf_window: tuple = (-40, 41), n_wfs: int = 500):

self._return_to_dir(current_dir)

def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bool = False):
def get_amplitudes(self, std: float = 2):
"""
function for assessing amplitude distribution.

Parameters
----------
std: float, default: 2
The number of standard deviations to use when assessing the desired spread of the data
Returns
-------
None.

"""

waveforms = self.waveforms
n_waveforms = waveforms.shape[1]
amplitudes = waveforms.max(axis=3) - waveforms.min(axis=3)
max_amplitudes = amplitudes.max(axis=2)
mean_amplitudes = max_amplitudes.mean(axis=1)
std_amplitudes = max_amplitudes.std(axis=1)
z_index = np.zeros((max_amplitudes.shape))
for row_index in range(max_amplitudes.shape[0]):
z_index[row_index] = (max_amplitudes[row_index] - mean_amplitudes[row_index]) / std_amplitudes[row_index]

amplitude_index = np.where(np.logical_and(z_index < std, z_index > -std), 1, 0).sum(axis=1) / n_waveforms

self.amplitude_index = amplitude_index

def qc_preprocessing(
self,
idthres: Optional[float] = None,
rpv: Optional[float] = None,
sil: Optional[float] = None,
amp_cutoff: Optional[float] = None,
recurated: bool = False,
):
"""
function for curating data based on qc metrics and refractory periods

Parameters
----------
idthres : float
idthres : Optional[float], default: None
The cutoff isolation distance, 0 means no curation.
rpv : float
rpv : Optional[float], default: None
Fractional rate of refractory period violations, 0 is no violations and 1 would be all violations okay
sil : float
sil : Optional[float], default: None
Minimum silhouette score, [-1, 1], where bigger is better.
recurated : bool, optional
amp_cutoff: Optional[float], default = None
The percentage of spikes allowed to be over the user specified standard deviations (default 2) given as the
desired percentage. E.g. 0.98 means 98% of spikes are within 2 stds.
recurated : bool, default: False
If data has been recurated in phy since the last data run. The default is False.

Raises
Expand Down Expand Up @@ -476,32 +514,56 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bo
self.silhouette_scores = np.load("silhouette_scores.npy")
self.isolation_distances = np.load("isolation_distances.npy")
except FileNotFoundError:
raise Exception("qc metrics has not been run")
if idthres is None and sil is None:
pass
else:
raise Exception("qc metrics has not been run")
try:
_ = self.refractory_period_violations
except AttributeError:
try:
self.refractory_period_violations = np.load("refractory_period_violations.npy")
except FileNotFoundError:
raise Exception("refractory period violations not calculated")

assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length"
assert len(self.silhouette_scores) == len(
self.refractory_period_violations
), "Refractory period violations should be same length as qc"

iso_d_thres = np.where(self.isolation_distances > idthres, True, False)
sil_thres = np.where(self.silhouette_scores > sil, True, False)
rpv_thres = np.where(self.refractory_period_violations < rpv, True, False)

threshold = np.logical_and(iso_d_thres, sil_thres)
threshold = np.logical_and(threshold, rpv_thres)
if rpv is None:
pass
else:
raise Exception("refractory period violations not calculated")
try:
_ = self.amplitude_index
except AttributeError:
try:
self.amplitude_index = np.load("amplitude_distribution.npy")
except FileNotFoundError:
if amp_cutoff is None:
pass
else:
raise Exception("amplitude scores not calculated")

if idthres is not None:
assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length"
iso_d_thres = np.where(self.isolation_distances > idthres, True, False)
sil_thres = np.where(self.silhouette_scores > sil, True, False)
threshold = np.logical_and(iso_d_thres, sil_thres)
else:
threshold = np.array([True] * len(self._cids))

if rpv is not None:
assert len(self.refractory_period_violations) == len(
self._cids
), "mismatch between refactory period and cids"
rpv_thres = np.where(self.refractory_period_violations < rpv, True, False)
threshold = np.logical_and(threshold, rpv_thres)
if amp_cutoff is not None:
assert len(self.amplitude_index) == len(self._cids), "mismatch between amplitudes and cids"
amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False)
threshold = np.logical_and(threshold, amp_thres)

self._qc_threshold = threshold

self._isolation_threshold = idthres
self._sil_threshold = sil
self._rpv = rpv
self._amp_cutoff = amp_cutoff

if self.CACHING:
np.save("qc_threshold.npy", threshold)
Expand All @@ -510,6 +572,7 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, recurated: bo
print("Current qc_preprocessing values led to 0 units.")
print(f"Iso: {np.sum(iso_d_thres)}, sil: {np.sum(sil)}")
print(f"RPV: {np.sum(rpv)}")
print(f"amp cutoff: {np.sum(amp_thres)}")

self._return_to_dir(current_dir)

Expand All @@ -524,7 +587,8 @@ def set_qc(self):
threshold = self._qc_threshold
except AttributeError:
raise Exception(
f"Must run qc functions first ('generate_pcs', 'generate_qcmetrics', 'refractory_violation')"
f"Must run qc functions first ('generate_pcs', 'generate_qcmetrics', 'refractory_violation'"
f"'get_amplitudes') "
)

self._cids = self._cids[threshold]
Expand Down
40 changes: 37 additions & 3 deletions test/test_spike_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def test_save_qc_parameters(spikes, tmp_path):
spikes._isolation_threshold = 10
spikes._rpv = 0.02
spikes._sil_threshold = 0.4
spikes._amp_cutoff = 0.98
spikes.save_qc_parameters()
have_json = False

Expand Down Expand Up @@ -277,21 +278,32 @@ def test_qc_preprocessing(spikes, tmp_path):
file_path = spikes._file_path
spikes._file_path = spikes._file_path / tmp_path
os.chdir(spikes._file_path)
id = np.array([10, 30, 20])
ids = np.array([10, 30, 20])
sil = np.array([0.1, 0.4, 0.5])
ref = np.array([0.3, 0.001, 0.1])
amp = np.array([0.98, 0.98, 0.98])

np.save("isolation_distances.npy", id)
np.save("isolation_distances.npy", ids)
np.save("silhouette_scores.npy", sil)
np.save("refractory_period_violations.npy", ref)
np.save("amplitude_distribution.npy", amp)
spikes.CACHING = True
spikes.qc_preprocessing(15, 0.02, 0.35)
cids = spikes._cids
spikes._cids = np.array(
[
0,
1,
2,
]
)
spikes.qc_preprocessing(15, 0.02, 0.35, 0.97)

assert isinstance(spikes._qc_threshold, np.ndarray)

assert spikes._qc_threshold[0] == False
assert spikes._qc_threshold[1] == True
assert spikes._qc_threshold[2] == False
spikes._cids = cids
spikes._file_path = file_path
os.chdir(file_path)

Expand Down Expand Up @@ -342,3 +354,25 @@ def test_load_waveforms(spikes, tmp_path):

spikes._file_path = file_path
os.chdir(spikes._file_path)


def test_get_amplitudes(spikes):
samples = np.random.normal(loc=5.0, scale=1, size=(82))
samples2 = samples * 50

large_std = samples
large_std2 = samples * 1000
waveforms = np.zeros((2, 1000, 4, 82))
waveforms[0, :, 2, :] = large_std
waveforms[0, :40, 2, :] = large_std2
waveforms[1, :, 1, :] = samples
waveforms[1, ::2, 1, :] = samples2

spikes.waveforms = waveforms
spikes.get_amplitudes()
assert len(spikes.amplitude_index) == 2, "function failed"
print(spikes.amplitude_index)

assert spikes.amplitude_index[1] == 1.0

assert spikes.amplitude_index[0] < spikes.amplitude_index[1]
Loading