Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented sd_ratio as quality metric #2146

Merged
merged 23 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1cf9a97
Work on SD test
DradeAW Oct 31, 2023
9900bca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2023
ed29d08
Work on SD test + circular import
DradeAW Oct 31, 2023
690d33d
Implemented SD test
DradeAW Nov 2, 2023
ded2c32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2023
a1f9a3f
Merge branch 'main' into sd_test
DradeAW Nov 2, 2023
5ae3219
Added doc for SD test
DradeAW Nov 2, 2023
288029a
Multiple fixed for `sd_ratio` metric
DradeAW Nov 2, 2023
9d5c270
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2023
894ab2c
Fix `sd_ratio`
DradeAW Nov 2, 2023
edc0e91
Merge branch 'sd_test' of github.com:DradeAW/spikeinterface into sd_test
DradeAW Nov 2, 2023
cb477f9
Suggestions from Zach and Alessio for `sd_ratio`
DradeAW Nov 2, 2023
dcdbb1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2023
0ed6796
oops
DradeAW Nov 2, 2023
660de12
Fix `sd_ratio` recordless
DradeAW Nov 2, 2023
47581f5
Fixed `sd_ratio`
DradeAW Nov 2, 2023
602f70f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2023
7cd707d
Merge branch 'main' into sd_test
DradeAW Nov 2, 2023
33232bf
Alphabetical order for quality metrics in docs
DradeAW Nov 3, 2023
0469792
Merge branch 'main' into sd_test
DradeAW Nov 14, 2023
f18eb9c
Merge branch 'main' into sd_test
DradeAW Nov 17, 2023
848ee93
Added option to correct for template itself in `sd_ratio`
DradeAW Nov 20, 2023
cd5eb3f
Merge branch 'main' into sd_test
alejoe91 Nov 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/modules/qualitymetrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ For more details about each metric and it's availability and use within SpikeInt
qualitymetrics/snr
qualitymetrics/noise_cutoff
qualitymetrics/silhouette_score
qualitymetrics/sd_ratio
qualitymetrics/synchrony


Expand Down
2 changes: 2 additions & 0 deletions doc/modules/qualitymetrics/references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ References

.. [Llobet] Llobet Victor, Wyngaard Aurélien and Barbour Boris. “Automatic post-processing and merging of multiple spike-sorting analyses with Lussac“. BioRxiv (2022).

.. [Pouzat] Pouzat Christophe, Mazor Ofer and Laurent Gilles. “Using noise signature to optimize spike-sorting and to assess neuronal classification quality“. Journal of Neuroscience Methods (2002).

.. [Rousseeuw] Peter J Rousseeuw. Silhouettes: A graphical aid to the interpretation and validation of cluster analysis. Journal of computational and applied mathematics, 20(C):53–65, 1987.

.. [Schmitzer-Torbert] Schmitzer-Torbert, Neil, and A. David Redish. “Neuronal Activity in the Rodent Dorsal Striatum in Sequential Navigation: Separation of Spatial and Reward Responses on the Multiple T Task.” Journal of neurophysiology 91.5 (2004): 2259–2272. Web.
Expand Down
38 changes: 38 additions & 0 deletions doc/modules/qualitymetrics/sd_ratio.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
Standard Deviation (SD) ratio (:code:`sd_ratio`)
==============================================

Calculation
-----------

All spikes from the same neuron should have the same shape. This means that at the peak of the spike, the standard deviation of the voltage should be the same as that of noise. If spikes from multiple neurons are grouped in a single unit, chances are the standard deviation of spikes amplitude is going to increase.
DradeAW marked this conversation as resolved.
Show resolved Hide resolved

This metric, first described [Pouzat]_ then adapted by Wyngaard, Llobet & Barbour (paper in writing), returns the ratio between both standard deviations:
DradeAW marked this conversation as resolved.
Show resolved Hide resolved

.. math::
S = \frac{\sigma_{\mathrm{unit}}}{\sigma_{\mathrm{noise}}}

To remove the effect of drift on spikes amplitude, :math:`\sigma_{\mathrm{unit}}` is computed by subtracting each spike amplitude, and dividing the resulting standard deviation by :math:`\sqrt{2}`.
Also to remove the effect of bursts (which can have lower amplitudes), you can specify a censored period (by default 4.0 ms) where spikes happening less than this period after another spike will not be considered.


Expectation and use
-------------------

For a unit representing a single neuron, this metric should return a value close to one. However for unit that are contaminated, the value can be significantly higher.
DradeAW marked this conversation as resolved.
Show resolved Hide resolved


Example code
------------

.. code-block:: python

import spikeinterface.qualitymetrics as sqm
DradeAW marked this conversation as resolved.
Show resolved Hide resolved

sd_ratio = sqm.compute_sd_ratio(wvf_extractor, censored_period_ms=4.0)


Literature
----------

Introduced by [Pouzat]_ (2002).
Expanded by Wyngaard, Llobet and Barbour (in writing).
DradeAW marked this conversation as resolved.
Show resolved Hide resolved
69 changes: 67 additions & 2 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import numpy as np
import warnings

from ..postprocessing import correlogram_for_one_segment
from ..core import get_noise_levels
from ..postprocessing import compute_spike_amplitudes, correlogram_for_one_segment
from ..core import WaveformExtractor, get_noise_levels
from ..core.template_tools import (
get_template_extremum_channel,
get_template_extremum_amplitude,
Expand Down Expand Up @@ -1365,3 +1365,68 @@ def _compute_rp_violations_numba(nb_rp_violations, spike_trains, spike_clusters,
spike_train = spike_trains[spike_clusters == i]
n_v = _compute_nb_violations_numba(spike_train, t_r)
nb_rp_violations[i] += n_v


def compute_sd_ratio(
wvf_extractor: WaveformExtractor,
censored_period_ms: float = 4.0,
correct_for_drift: bool = True,
unit_ids=None,
**kwargs,
):
"""
Computes the SD (Standard Deviation) of each unit spikes amplitude, and compare it to that of noise.
DradeAW marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
waveform_extractor : WaveformExtractor
The waveform extractor object.
censored_period_ms : float
DradeAW marked this conversation as resolved.
Show resolved Hide resolved
The censored period in milliseconds. This is to remove any potential bursts that could affect the SD.
correct_for_drift: bool
DradeAW marked this conversation as resolved.
Show resolved Hide resolved
If True, will subtract the amplitudes sequentiially to significantly reduce the impact of drift.
unit_ids : list or None
DradeAW marked this conversation as resolved.
Show resolved Hide resolved
The list of unit ids to compute this metric. If None, all units are used.
**kwargs:
Keyword arguments for computing spike amplitudes and extremum channel.
TODO: Possibly remove spikes when computing noise?
TODO: Take jitter into account.

Returns
-------
num_spikes : dict
The number of spikes, across all segments, for each unit ID.
"""

from ..curation.curation_tools import _find_duplicated_spikes_keep_first_iterative

censored_period = int(round(censored_period_ms * 1e-3 * wvf_extractor.sampling_frequency))
if unit_ids is None:
unit_ids = wvf_extractor.unit_ids

spikes_amplitude = compute_spike_amplitudes(wvf_extractor, outputs="by_unit", return_scaled=True, **kwargs)
noise_levels = get_noise_levels(wvf_extractor.recording, return_scaled=True, method="std")
best_channels = get_template_extremum_channel(wvf_extractor, outputs="index", **kwargs)

sd_ratio = []
for segment_index in range(len(spikes_amplitude)):
sd_ratio.append({})

for unit_id, spk_amp in spikes_amplitude[segment_index].items():
if unit_id not in unit_ids:
continue

spike_train = wvf_extractor.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype(
np.int64
)
censored_indices = _find_duplicated_spikes_keep_first_iterative(spike_train, censored_period)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this is a good idea this cross dependency in this way between curation and qualitymetrics.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would you suggest as an alternative?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's ok for now, but we can discuss higher-level hierarchy later

spk_amp = np.delete(spk_amp, censored_indices)

if correct_for_drift:
unit_std = np.std(np.diff(spk_amp)) / np.sqrt(2)
else:
unit_std = np.std(spk_amp)

sd_ratio[segment_index][unit_id] = unit_std / noise_levels[best_channels[unit_id]]

return sd_ratio
2 changes: 2 additions & 0 deletions src/spikeinterface/qualitymetrics/quality_metric_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
compute_synchrony_metrics,
compute_firing_ranges,
compute_amplitude_cv_metrics,
compute_sd_ratio,
)

from .pca_metrics import (
Expand Down Expand Up @@ -46,4 +47,5 @@
"synchrony": compute_synchrony_metrics,
"firing_range": compute_firing_ranges,
"drift": compute_drift_metrics,
"sd_ratio": compute_sd_ratio,
}
14 changes: 12 additions & 2 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
compute_synchrony_metrics,
compute_firing_ranges,
compute_amplitude_cv_metrics,
compute_sd_ratio,
)


Expand Down Expand Up @@ -70,7 +71,7 @@ def _simulated_data():


def _waveform_extractor_simple():
recording, sorting = toy_example(duration=50, seed=10)
recording, sorting = toy_example(duration=80, seed=10)
recording = recording.save(folder=cache_folder / "rec1")
sorting = sorting.save(folder=cache_folder / "sort1")
folder = cache_folder / "waveform_folder1"
Expand Down Expand Up @@ -379,6 +380,14 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
# assert np.allclose(list(drift_mads_gt.values()), list(drift_mads.values()), rtol=0.05)


def test_calculate_sd_ratio(waveform_extractor_simple):
sd_ratio = compute_sd_ratio(waveform_extractor_simple)

assert len(sd_ratio) == waveform_extractor_simple.get_num_segments()
assert np.all(list(sd_ratio[0].keys()) == waveform_extractor_simple.unit_ids)
assert np.allclose(np.array(list(sd_ratio[0].values())), 1, atol=0.5, rtol=0)


if __name__ == "__main__":
sim_data = _simulated_data()
we = _waveform_extractor_simple()
Expand All @@ -390,5 +399,6 @@ def test_calculate_drift_metrics(waveform_extractor_simple):
# test_calculate_sliding_rp_violations(we)
# test_calculate_drift_metrics(we)
# test_synchrony_metrics(we)
test_calculate_firing_range(we)
# test_calculate_firing_range(we)
# test_calculate_amplitude_cv_metrics(we)
test_calculate_sd_ratio(we)
Loading