Skip to content

Commit

Permalink
make any specific qc optional, fix test for this
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Sep 13, 2023
1 parent 89cfb00 commit 2b8471e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 25 deletions.
70 changes: 47 additions & 23 deletions src/spikeanalysis/spike_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Union
from typing import Union, Optional
import os

import numpy as np
Expand Down Expand Up @@ -460,19 +460,29 @@ def get_amplitudes(self, std: float = 2):

self.amplitude_index = amplitude_index

def qc_preprocessing(self, idthres: float, rpv: float, sil: float, amp_cutoff: float, recurated: bool = False):
def qc_preprocessing(
self,
idthres: Optional[float] = None,
rpv: Optional[float] = None,
sil: Optional[float] = None,
amp_cutoff: Optional[float] = None,
recurated: bool = False,
):
"""
function for curating data based on qc metrics and refractory periods
Parameters
----------
idthres : float
idthres : Optional[float], default: None
The cutoff isolation distance, 0 means no curation.
rpv : float
rpv : Optional[float], default: None
Fractional rate of refractory period violations: 0 means no violations are allowed and 1 means all violations are okay.
sil : float
sil : Optional[float], default: None
Minimum silhouette score, [-1, 1], where bigger is better.
recurated : bool, optional
amp_cutoff : Optional[float], default: None
The fraction of spikes that must fall within the user-specified number of standard deviations (default 2),
given as a decimal. E.g. 0.98 means 98% of spikes must be within 2 stds.
recurated : bool, default: False
If data has been recurated in phy since the last data run. The default is False.
Raises
Expand Down Expand Up @@ -504,35 +514,49 @@ def qc_preprocessing(self, idthres: float, rpv: float, sil: float, amp_cutoff: f
self.silhouette_scores = np.load("silhouette_scores.npy")
self.isolation_distances = np.load("isolation_distances.npy")
except FileNotFoundError:
raise Exception("qc metrics has not been run")
if idthres is None and sil is None:
pass
else:
raise Exception("qc metrics has not been run")
try:
_ = self.refractory_period_violations
except AttributeError:
try:
self.refractory_period_violations = np.load("refractory_period_violations.npy")
except FileNotFoundError:
raise Exception("refractory period violations not calculated")
if rpv is None:
pass
else:
raise Exception("refractory period violations not calculated")
try:
_ = self.amplitude_index
except AttributeError:
try:
self.amplitude_index = np.load("amplitude_distribution.npy")
except FileNotFoundError:
raise Exception("ampltiude scores not calculated")

assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length"
assert len(self.silhouette_scores) == len(
self.refractory_period_violations
), "Refractory period violations should be same length as qc"
assert len(self.amplitude_index) == len(self.silhouette_scores), "amplitudes scores must be the same length"

iso_d_thres = np.where(self.isolation_distances > idthres, True, False)
sil_thres = np.where(self.silhouette_scores > sil, True, False)
rpv_thres = np.where(self.refractory_period_violations < rpv, True, False)
amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False)
threshold = np.logical_and(iso_d_thres, sil_thres)
threshold = np.logical_and(threshold, rpv_thres)
threshold = np.logical_and(threshold, amp_thres)
if amp_cutoff is None:
pass
else:
raise Exception("amplitude scores not calculated")

if idthres is not None:
assert len(self.silhouette_scores) == len(self.isolation_distances), "Qc metrics should be same length"
iso_d_thres = np.where(self.isolation_distances > idthres, True, False)
sil_thres = np.where(self.silhouette_scores > sil, True, False)
threshold = np.logical_and(iso_d_thres, sil_thres)
else:
threshold = np.array([True] * len(self._cids))

if rpv is not None:
assert len(self.refractory_period_violations) == len(
self._cids
), "mismatch between refactory period and cids"
rpv_thres = np.where(self.refractory_period_violations < rpv, True, False)
threshold = np.logical_and(threshold, rpv_thres)
if amp_cutoff is not None:
assert len(self.amplitude_index) == len(self._cids), "mismatch between amplitudes and cids"
amp_thres = np.where(self.amplitude_index > amp_cutoff, True, False)
threshold = np.logical_and(threshold, amp_thres)

self._qc_threshold = threshold

Expand Down
13 changes: 11 additions & 2 deletions test/test_spike_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,23 +278,32 @@ def test_qc_preprocessing(spikes, tmp_path):
file_path = spikes._file_path
spikes._file_path = spikes._file_path / tmp_path
os.chdir(spikes._file_path)
id = np.array([10, 30, 20])
ids = np.array([10, 30, 20])
sil = np.array([0.1, 0.4, 0.5])
ref = np.array([0.3, 0.001, 0.1])
amp = np.array([0.98, 0.98, 0.98])

np.save("isolation_distances.npy", id)
np.save("isolation_distances.npy", ids)
np.save("silhouette_scores.npy", sil)
np.save("refractory_period_violations.npy", ref)
np.save("amplitude_distribution.npy", amp)
spikes.CACHING = True
cids = spikes._cids
spikes._cids = np.array(
[
0,
1,
2,
]
)
spikes.qc_preprocessing(15, 0.02, 0.35, 0.97)

assert isinstance(spikes._qc_threshold, np.ndarray)

assert spikes._qc_threshold[0] == False
assert spikes._qc_threshold[1] == True
assert spikes._qc_threshold[2] == False
spikes._cids = cids
spikes._file_path = file_path
os.chdir(file_path)

Expand Down

0 comments on commit 2b8471e

Please sign in to comment.