Percentiles need 0-100 and add unit_ids to synchrony metrics
alejoe91 committed Sep 19, 2023
1 parent 92d458e commit 26cfd5d
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -499,7 +499,7 @@ def compute_sliding_rp_violations(
    )


-def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs):
+def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs):
    """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
    "synchrony_size" spikes at the exact same sample index.
@@ -509,6 +509,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k
        The waveform extractor object.
    synchrony_sizes : list or tuple, default: (2, 4, 8)
        The synchrony sizes to compute.
+    unit_ids : list or None, default: None
+        List of unit ids to compute the synchrony metrics. If None, all units are used.

    Returns
    -------
@@ -526,6 +528,9 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k
    sorting = waveform_extractor.sorting
    spikes = sorting.to_spike_vector(concatenated=False)

+    if unit_ids is None:
+        unit_ids = sorting.unit_ids
+
    # Pre-allocate synchrony counts
    synchrony_counts = {}
    for synchrony_size in synchrony_sizes:
@@ -538,20 +543,20 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k
        unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True)

        # add counts for this segment
-        for unit_index in np.arange(len(sorting.unit_ids)):
+        for unit_id in unit_ids:
+            unit_index = sorting.unit_ids.index(unit_id)
            spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index]
            # some segments/units might have no spikes
            if len(spikes_per_unit) == 0:
                continue
            spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])]
            for synchrony_size in synchrony_sizes:
-                synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size)
+                synchrony_counts[synchrony_size][unit_id] += np.count_nonzero(spike_complexity >= synchrony_size)

    # add counts for this segment
    synchrony_metrics_dict = {
        f"sync_spike_{synchrony_size}": {
-            unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id]
-            for unit_index, unit_id in enumerate(sorting.unit_ids)
+            unit_id: synchrony_counts[synchrony_size][unit_id] / spike_counts[unit_id] for unit_id in unit_ids
        }
        for synchrony_size in synchrony_sizes
    }
@@ -565,7 +570,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k
_default_params["synchrony"] = dict(synchrony_sizes=(0, 2, 4))


-def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(0.05, 0.95), unit_ids=None):
+def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs):
    """Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution
    computed in non-overlapping time bins.
@@ -575,7 +580,7 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(0.05, 0
        The waveform extractor object.
    bin_size_s : float, default: 5
        The size of the bin in seconds.
-    percentiles : tuple, default: (0.05, 0.95)
+    percentiles : tuple, default: (5, 95)
        The percentiles to compute.
    unit_ids : list or None
        List of unit ids to compute the firing range. If None, all units are used.
@@ -617,13 +622,13 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(0.05, 0
    return firing_ranges


-_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(0.05, 0.95))
+_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(5, 95))


def compute_amplitude_cv_metrics(
    waveform_extractor,
    average_num_spikes_per_bin=50,
-    percentiles=(0.05, 0.95),
+    percentiles=(5, 95),
    min_num_bins=10,
    amplitude_extension="spike_amplitudes",
    unit_ids=None,
@@ -726,7 +731,7 @@ def compute_amplitude_cv_metrics(


_default_params["amplitude_cv"] = dict(
-    average_num_spikes_per_bin=50, percentiles=(0.05, 0.95), min_num_bins=10, amplitude_extension="spike_amplitudes"
+    average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes"
)
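
A short usage sketch of the two updated entry points follows. It is not part of the commit: the toy-data helpers (generate_recording, generate_sorting, extract_waveforms from spikeinterface.core) and the chosen parameter values are illustrative assumptions. The point it illustrates is that percentiles are now given on a 0-100 scale, as np.percentile expects, and that compute_synchrony_metrics accepts an optional unit_ids argument.

# Minimal sketch against the updated signatures; assumes a spikeinterface
# version that includes these changes.
import spikeinterface.core as si
from spikeinterface.qualitymetrics.misc_metrics import (
    compute_firing_ranges,
    compute_synchrony_metrics,
)

# Toy recording/sorting pair, then a WaveformExtractor kept in memory.
recording = si.generate_recording(num_channels=4, durations=[60.0])
sorting = si.generate_sorting(num_units=5, durations=[60.0])
we = si.extract_waveforms(recording, sorting, mode="memory", sparse=False)

# Percentiles are expressed on a 0-100 scale: (5, 95) is the 5th-95th
# percentile range, whereas np.percentile would read (0.05, 0.95) as the
# 0.05th and 0.95th percentiles.
firing_ranges = compute_firing_ranges(we, bin_size_s=5, percentiles=(5, 95))

# unit_ids can restrict the synchrony metrics to a subset of units;
# with unit_ids=None (the default) all units are used.
sync_metrics = compute_synchrony_metrics(we, synchrony_sizes=(2, 4, 8), unit_ids=sorting.unit_ids[:2])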


