Implemented sd_ratio as quality metric #2146

Merged
merged 23 commits into from
Nov 22, 2023
Commits
1cf9a97
Work on SD test
DradeAW Oct 31, 2023
9900bca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2023
ed29d08
Work on SD test + circular import
DradeAW Oct 31, 2023
690d33d
Implemented SD test
DradeAW Nov 2, 2023
ded2c32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2023
a1f9a3f
Merge branch 'main' into sd_test
DradeAW Nov 2, 2023
5ae3219
Added doc for SD test
DradeAW Nov 2, 2023
288029a
Multiple fixed for `sd_ratio` metric
DradeAW Nov 2, 2023
9d5c270
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2023
894ab2c
Fix `sd_ratio`
DradeAW Nov 2, 2023
edc0e91
Merge branch 'sd_test' of github.com:DradeAW/spikeinterface into sd_test
DradeAW Nov 2, 2023
cb477f9
Suggestions from Zach and Alessio for `sd_ratio`
DradeAW Nov 2, 2023
dcdbb1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2023
0ed6796
oops
DradeAW Nov 2, 2023
660de12
Fix `sd_ratio` recordless
DradeAW Nov 2, 2023
47581f5
Fixed `sd_ratio`
DradeAW Nov 2, 2023
602f70f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2023
7cd707d
Merge branch 'main' into sd_test
DradeAW Nov 2, 2023
33232bf
Alphabetical order for quality metrics in docs
DradeAW Nov 3, 2023
0469792
Merge branch 'main' into sd_test
DradeAW Nov 14, 2023
f18eb9c
Merge branch 'main' into sd_test
DradeAW Nov 17, 2023
848ee93
Added option to correct for template itself in `sd_ratio`
DradeAW Nov 20, 2023
cd5eb3f
Merge branch 'main' into sd_test
alejoe91 Nov 22, 2023
5 changes: 3 additions & 2 deletions doc/modules/qualitymetrics.rst
@@ -35,11 +35,12 @@ For more details about each metric and it's availability and use within SpikeInt
 qualitymetrics/isolation_distance
 qualitymetrics/l_ratio
 qualitymetrics/nearest_neighbor
+qualitymetrics/noise_cutoff
 qualitymetrics/presence_ratio
+qualitymetrics/sd_ratio
+qualitymetrics/silhouette_score
 qualitymetrics/sliding_rp_violations
 qualitymetrics/snr
-qualitymetrics/noise_cutoff
-qualitymetrics/silhouette_score
 qualitymetrics/synchrony


2 changes: 2 additions & 0 deletions doc/modules/qualitymetrics/references.rst
@@ -21,6 +21,8 @@ References

 .. [Llobet] Llobet Victor, Wyngaard Aurélien and Barbour Boris. “Automatic post-processing and merging of multiple spike-sorting analyses with Lussac”. BioRxiv (2022).

+.. [Pouzat] Pouzat Christophe, Mazor Ofer and Laurent Gilles. “Using noise signature to optimize spike-sorting and to assess neuronal classification quality”. Journal of Neuroscience Methods (2002).
+
 .. [Rousseeuw] Peter J Rousseeuw. Silhouettes: A graphical aid to the interpretation and validation of cluster analysis. Journal of computational and applied mathematics, 20(C):53–65, 1987.

 .. [Schmitzer-Torbert] Schmitzer-Torbert, Neil, and A. David Redish. “Neuronal Activity in the Rodent Dorsal Striatum in Sequential Navigation: Separation of Spatial and Reward Responses on the Multiple T Task.” Journal of neurophysiology 91.5 (2004): 2259–2272. Web.
38 changes: 38 additions & 0 deletions doc/modules/qualitymetrics/sd_ratio.rst
@@ -0,0 +1,38 @@
Standard Deviation (SD) ratio (:code:`sd_ratio`)
==============================================

Calculation
-----------

All spikes from the same neuron should have the same shape. This means that at the peak of the spike, the standard deviation of the voltage should be the same as that of noise. If spikes from multiple neurons are grouped into a single unit, the standard deviation of spike amplitudes would likely be increased.

This metric, first described by [Pouzat]_ and then adapted by Wyngaard, Llobet & Barbour (in preparation), returns the ratio between both standard deviations:

.. math::

    S = \frac{\sigma_{\mathrm{unit}}}{\sigma_{\mathrm{noise}}}

To remove the effect of drift on spike amplitudes, :math:`\sigma_{\mathrm{unit}}` is computed from the differences between consecutive spike amplitudes: the standard deviation of these differences is divided by :math:`\sqrt{2}`, because the difference of two independent values with the same standard deviation has a standard deviation :math:`\sqrt{2}` times larger.
In addition, to remove the effect of bursts (which can have lower amplitudes), you can specify a censored period (by default 4.0 ms): spikes occurring less than this period after another spike are not considered.
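
The effect of the drift correction can be checked numerically. The following is a minimal sketch on synthetic data; the amplitude, drift slope and noise level are made-up values for illustration and do not come from the pull request. It compares the plain standard deviation of the amplitudes with the one obtained from consecutive differences divided by :math:`\sqrt{2}`.

```python
import numpy as np

rng = np.random.default_rng(0)

n_spikes = 20_000
sigma = 5.0  # assumed noise SD (arbitrary units)

# Synthetic amplitudes: constant spike amplitude + slow linear drift + noise.
drift = np.linspace(0.0, 30.0, n_spikes)
amplitudes = -100.0 + drift + rng.normal(0.0, sigma, n_spikes)

# Plain SD: inflated by the drift.
naive_sd = np.std(amplitudes)

# SD of consecutive differences divided by sqrt(2): the slow drift mostly cancels.
corrected_sd = np.std(np.diff(amplitudes)) / np.sqrt(2)

print(f"naive SD: {naive_sd:.2f}, drift-corrected SD: {corrected_sd:.2f}")
```

With these assumed values, the plain estimate is clearly larger than ``sigma`` while the difference-based estimate stays close to it.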


Expectation and use
-------------------

For a unit representing a single neuron, this metric should return a value close to one. However, for contaminated units, the value can be significantly higher.
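
To illustrate this expectation, here is a small simulation under stated assumptions: the helper ``sd_ratio_from_amplitudes`` is hypothetical (it is not part of SpikeInterface), the noise SD is taken as known, and the two amplitudes used for the contaminated unit are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
sigma_noise = 5.0  # assumed noise SD on the best channel


def sd_ratio_from_amplitudes(amplitudes, sigma_noise):
    """Hypothetical helper: drift-corrected SD of the amplitudes over the noise SD."""
    unit_sd = np.std(np.diff(amplitudes)) / np.sqrt(2)
    return unit_sd / sigma_noise


# Clean unit: a single neuron with one underlying amplitude, plus noise.
clean = -100.0 + rng.normal(0.0, sigma_noise, 10_000)

# Contaminated unit: spikes from two neurons with different amplitudes, mixed at random.
true_amplitudes = rng.choice([-100.0, -80.0], size=10_000)
contaminated = true_amplitudes + rng.normal(0.0, sigma_noise, 10_000)

clean_ratio = sd_ratio_from_amplitudes(clean, sigma_noise)
contaminated_ratio = sd_ratio_from_amplitudes(contaminated, sigma_noise)
print(f"clean: {clean_ratio:.2f}, contaminated: {contaminated_ratio:.2f}")
```

The clean unit should give a ratio near one, and the mixed unit a ratio well above one, because the spread between the two underlying amplitudes adds to the noise.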


Example code
------------

.. code-block:: python

    import spikeinterface.qualitymetrics as sqm

    sd_ratio = sqm.compute_sd_ratio(wvf_extractor, censored_period_ms=4.0)
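
What the ``censored_period_ms`` argument does can be sketched in plain NumPy. This is an illustration of the idea only, not the library's implementation: the helper name ``censor_spikes``, the 30 kHz sampling frequency, the choice to measure the period from the last kept spike, and the handling of a spike falling exactly on the boundary are all assumptions made for this example.

```python
import numpy as np


def censor_spikes(spike_train, censored_period):
    """Hypothetical helper: indices of spikes that fall within `censored_period`
    samples of the last spike that was kept."""
    dropped = []
    last_kept = None
    for i, t in enumerate(spike_train):
        if last_kept is not None and t - last_kept <= censored_period:
            dropped.append(i)  # too close to the previous kept spike
        else:
            last_kept = t
    return np.asarray(dropped, dtype=np.int64)


sampling_frequency = 30_000.0  # Hz, assumed
censored_period = int(round(4.0 * 1e-3 * sampling_frequency))  # in samples

# Toy spike train (in samples) and matching amplitudes; burst spikes have lower amplitudes.
spike_train = np.array([0, 50, 130, 200, 1000, 1100, 5000], dtype=np.int64)
amplitudes = np.array([-100.0, -70.0, -98.0, -75.0, -101.0, -72.0, -99.0])

dropped = censor_spikes(spike_train, censored_period)
kept_amplitudes = np.delete(amplitudes, dropped)
print(dropped, kept_amplitudes)
```

Only the amplitudes of the kept spikes would then enter the standard deviation, so the lower-amplitude spikes inside bursts do not inflate it.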


Literature
----------

Introduced by [Pouzat]_ (2002).
Expanded by Wyngaard, Llobet and Barbour (in preparation).
105 changes: 103 additions & 2 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -13,8 +13,8 @@
 import numpy as np
 import warnings

-from ..postprocessing import correlogram_for_one_segment
-from ..core import get_noise_levels
+from ..postprocessing import compute_spike_amplitudes, correlogram_for_one_segment
+from ..core import WaveformExtractor, get_noise_levels
 from ..core.template_tools import (
     get_template_extremum_channel,
     get_template_extremum_amplitude,
@@ -1365,3 +1365,104 @@ def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters,
spike_train = spike_trains[spike_clusters == i]
n_v = _compute_nb_violations_numba(spike_train, t_r)
nb_rp_violations[i] += n_v


def compute_sd_ratio(
    wvf_extractor: WaveformExtractor,
    censored_period_ms: float = 4.0,
    correct_for_drift: bool = True,
    correct_for_template_itself: bool = True,
    unit_ids=None,
    **kwargs,
):
    """
    Computes the SD (Standard Deviation) of each unit's spike amplitudes, and compares it to the SD of noise.
    In this case, noise refers to the global voltage trace on the same channel as the best channel of the unit.
    (ideally (not implemented yet), the noise would be computed outside of spikes from the unit itself).

    Parameters
    ----------
    wvf_extractor : WaveformExtractor
        The waveform extractor object.
    censored_period_ms : float, default: 4.0
        The censored period in milliseconds. This is to remove any potential bursts that could affect the SD.
    correct_for_drift : bool, default: True
        If True, will subtract the amplitudes sequentially to significantly reduce the impact of drift.
    correct_for_template_itself : bool, default: True
        If True, will take into account that the template itself impacts the standard deviation of the noise,
        and will make a rough estimation of what that impact is (and remove it).
    unit_ids : list or None, default: None
        The list of unit ids to compute this metric. If None, all units are used.
    **kwargs :
        Keyword arguments for computing spike amplitudes and extremum channel.
    TODO: Take jitter into account.

    Returns
    -------
    sd_ratio : dict
        The SD ratio for each unit ID.
    """

    from ..curation.curation_tools import _find_duplicated_spikes_keep_first_iterative

    censored_period = int(round(censored_period_ms * 1e-3 * wvf_extractor.sampling_frequency))
    if unit_ids is None:
        unit_ids = wvf_extractor.unit_ids

    if not wvf_extractor.has_recording():
        warnings.warn(
            "The `sd_ratio` metric cannot work with a recordless WaveformExtractor object. "
            "SD ratio metric will be set to NaN"
        )
Comment on lines +1413 to +1416
Member

we have to decide a global behavior for this. I would say lets raise no ?

        return {unit_id: np.nan for unit_id in unit_ids}

    if wvf_extractor.is_extension("spike_amplitudes"):
        amplitudes_ext = wvf_extractor.load_extension("spike_amplitudes")
        spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit")
    else:
        warnings.warn(
            "The `sd_ratio` metric requires the `spike_amplitudes` waveform extension. "
            "Use the `postprocessing.compute_spike_amplitudes()` function. "
            "SD ratio metric will be set to NaN"
        )
        return {unit_id: np.nan for unit_id in unit_ids}

    noise_levels = get_noise_levels(
        wvf_extractor.recording, return_scaled=amplitudes_ext._params["return_scaled"], method="std"
    )
    best_channels = get_template_extremum_channel(wvf_extractor, outputs="index", **kwargs)
    n_spikes = wvf_extractor.sorting.count_num_spikes_per_unit()

    sd_ratio = {}
    for unit_id in unit_ids:
        spk_amp = []

        for segment_index in range(wvf_extractor.get_num_segments()):
            spike_train = wvf_extractor.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype(
                np.int64
            )
            censored_indices = _find_duplicated_spikes_keep_first_iterative(spike_train, censored_period)
Member

Not sure this is a good idea this cross dependency in this way between curation and qualitymetrics.

Contributor Author

What would you suggest as an alternative?

Member

that's ok for now, but we can discuss higher-level hierarchy later

            spk_amp.append(np.delete(spike_amplitudes[segment_index][unit_id], censored_indices))
        spk_amp = np.concatenate(spk_amp)

        if correct_for_drift:
            unit_std = np.std(np.diff(spk_amp)) / np.sqrt(2)
        else:
            unit_std = np.std(spk_amp)

        best_channel = best_channels[unit_id]
        std_noise = noise_levels[best_channel]

        if correct_for_template_itself:
            template = wvf_extractor.get_template(unit_id, force_dense=True)[:, best_channel]

            # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template.
            # Var[X] = E[X^2] - E[X]^2, hence the squared mean in the second term.
            # TODO: Take into account that templates for different segments might differ.
            p = wvf_extractor.nsamples * n_spikes[unit_id] / wvf_extractor.get_total_samples()
            total_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2

            std_noise = np.sqrt(std_noise**2 - total_variance)

        sd_ratio[unit_id] = unit_std / std_noise

    return sd_ratio
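
The variance term used when `correct_for_template_itself` is set can be sanity-checked on synthetic data. The sketch below is not SpikeInterface code; the toy template, the number of spikes and the trace length are made-up values. It builds a trace that is zero everywhere except for `n_spikes` non-overlapping copies of a template, then compares its variance with the closed form `p * mean(template**2) - p**2 * mean(template)**2`, where `p` is the fraction of samples covered by the template.

```python
import numpy as np

# Toy template on the best channel (assumed values).
template = np.array([0.0, -2.0, -10.0, -4.0, 3.0, 1.0])
n_samples_template = template.size
n_spikes = 50
total_samples = 10_000

# Place the template at regular, non-overlapping positions in an all-zero trace.
trace = np.zeros(total_samples)
starts = np.arange(n_spikes) * 100
for start in starts:
    trace[start : start + n_samples_template] = template

# Fraction of samples covered by the template.
p = n_samples_template * n_spikes / total_samples

# Var[X] = E[X^2] - E[X]^2, with E[X] = p * mean(template)
# and E[X^2] = p * mean(template**2).
predicted_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2
measured_variance = np.var(trace)

print(predicted_variance, measured_variance)
```

Subtracting this variance from the squared noise level removes the template's own contribution to the noise estimate, under the assumption that the unit's spikes do not overlap.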
2 changes: 2 additions & 0 deletions src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -14,6 +14,7 @@
     compute_synchrony_metrics,
     compute_firing_ranges,
     compute_amplitude_cv_metrics,
+    compute_sd_ratio,
 )

 from .pca_metrics import (
@@ -46,4 +47,5 @@
     "synchrony": compute_synchrony_metrics,
     "firing_range": compute_firing_ranges,
     "drift": compute_drift_metrics,
+    "sd_ratio": compute_sd_ratio,
 }
19 changes: 14 additions & 5 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -34,6 +34,7 @@
     compute_synchrony_metrics,
     compute_firing_ranges,
     compute_amplitude_cv_metrics,
+    compute_sd_ratio,
 )


@@ -70,7 +71,7 @@ def _simulated_data():


 def _waveform_extractor_simple():
-    recording, sorting = toy_example(duration=50, seed=10)
+    recording, sorting = toy_example(duration=80, seed=10, firing_rate=6.0)
     recording = recording.save(folder=cache_folder / "rec1")
     sorting = sorting.save(folder=cache_folder / "sort1")
     folder = cache_folder / "waveform_folder1"
@@ -86,6 +87,7 @@ def _waveform_extractor_simple():
         overwrite=True,
     )
     _ = compute_principal_components(we, n_components=5, mode="by_channel_local")
+    _ = compute_spike_amplitudes(we, return_scaled=True)
     return we


@@ -227,7 +229,7 @@ def test_calculate_firing_range(waveform_extractor_simple):

 def test_calculate_amplitude_cutoff(waveform_extractor_simple):
     we = waveform_extractor_simple
-    spike_amps = compute_spike_amplitudes(we)
+    spike_amps = we.load_extension("spike_amplitudes").get_data()
     amp_cuts = compute_amplitude_cutoffs(we, num_histogram_bins=10)
     print(amp_cuts)

@@ -238,7 +240,7 @@ def test_calculate_amplitude_cutoff(waveform_extractor_simple):

 def test_calculate_amplitude_median(waveform_extractor_simple):
     we = waveform_extractor_simple
-    spike_amps = compute_spike_amplitudes(we)
+    spike_amps = we.load_extension("spike_amplitudes").get_data()
     amp_medians = compute_amplitude_medians(we)
     print(spike_amps, amp_medians)

@@ -249,7 +251,6 @@ def test_calculate_amplitude_median(waveform_extractor_simple):

 def test_calculate_amplitude_cv_metrics(waveform_extractor_simple):
     we = waveform_extractor_simple
-    spike_amps = compute_spike_amplitudes(we)
     amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(we, average_num_spikes_per_bin=20)
     print(amp_cv_median)
     print(amp_cv_range)
@@ -379,6 +380,13 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
     # assert np.allclose(list(drift_mads_gt.values()), list(drift_mads.values()), rtol=0.05)


+def test_calculate_sd_ratio(waveform_extractor_simple):
+    sd_ratio = compute_sd_ratio(waveform_extractor_simple)
+
+    assert np.all(list(sd_ratio.keys()) == waveform_extractor_simple.unit_ids)
+    assert np.allclose(list(sd_ratio.values()), 1, atol=0.2, rtol=0)
+
+
 if __name__ == "__main__":
     sim_data = _simulated_data()
     we = _waveform_extractor_simple()
@@ -390,5 +398,6 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
     # test_calculate_sliding_rp_violations(we)
     # test_calculate_drift_metrics(we)
     # test_synchrony_metrics(we)
-    test_calculate_firing_range(we)
+    # test_calculate_firing_range(we)
     # test_calculate_amplitude_cv_metrics(we)
+    test_calculate_sd_ratio(we)
@@ -279,6 +279,9 @@ def test_recordingless(self):

         # check metrics are the same
         for metric_name in qm_rec.columns:
+            if metric_name == "sd_ratio":
+                continue
+
             # rtol is addedd for sliding_rp_violation, for a reason I do not have to explore now. Sam.
             assert np.allclose(qm_rec[metric_name].values, qm_no_rec[metric_name].values, rtol=1e-02)
