diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 6a42b12bb5..38add13c02 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -499,7 +499,7 @@ def compute_sliding_rp_violations( ) -def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs): +def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs): """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of "synchrony_size" spikes at the exact same sample index. @@ -509,6 +509,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k The waveform extractor object. synchrony_sizes : list or tuple, default: (2, 4, 8) The synchrony sizes to compute. + unit_ids : list or None, default: None + List of unit ids to compute the synchrony metrics. If None, all units are used. Returns ------- @@ -526,6 +528,9 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k sorting = waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) + if unit_ids is None: + unit_ids = sorting.unit_ids + # Pre-allocate synchrony counts synchrony_counts = {} for synchrony_size in synchrony_sizes: @@ -538,20 +543,20 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True) # add counts for this segment - for unit_index in np.arange(len(sorting.unit_ids)): + for unit_id in unit_ids: + unit_index = sorting.unit_ids.index(unit_id) spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index] # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: - synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) + synchrony_counts[synchrony_size][unit_id] += np.count_nonzero(spike_complexity >= synchrony_size) # add counts for this segment synchrony_metrics_dict = { f"sync_spike_{synchrony_size}": { - unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id] - for unit_index, unit_id in enumerate(sorting.unit_ids) + unit_id: synchrony_counts[synchrony_size][unit_id] / spike_counts[unit_id] for unit_id in unit_ids } for synchrony_size in synchrony_sizes } @@ -565,7 +570,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k _default_params["synchrony"] = dict(synchrony_sizes=(0, 2, 4)) -def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(0.05, 0.95), unit_ids=None): +def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): """Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -575,7 +580,7 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(0.05, 0 The waveform extractor object. bin_size_s : float, default: 5 The size of the bin in seconds. - percentiles : tuple, default: (0.05, 0.95) + percentiles : tuple, default: (5, 95) The percentiles to compute. unit_ids : list or None List of unit ids to compute the firing range. If None, all units are used. @@ -617,13 +622,13 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(0.05, 0 return firing_ranges -_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(0.05, 0.95)) +_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(5, 95)) def compute_amplitude_cv_metrics( waveform_extractor, average_num_spikes_per_bin=50, - percentiles=(0.05, 0.95), + percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", unit_ids=None, @@ -726,7 +731,7 @@ def compute_amplitude_cv_metrics( _default_params["amplitude_cv"] = dict( - average_num_spikes_per_bin=50, percentiles=(0.05, 0.95), min_num_bins=10, amplitude_extension="spike_amplitudes" + average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes" )