diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index e96fcb570c..7c22260dbe 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -262,16 +262,16 @@ def _compute_correlograms_numpy(sorting, window_size, bin_size): for seg_index in range(num_seg): spike_times = spikes[seg_index]["sample_index"] - spike_labels = spikes[seg_index]["unit_index"] + spike_unit_indices = spikes[seg_index]["unit_index"] - c0 = correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size) + c0 = correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size) correlograms += c0 return correlograms -def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size): +def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size): """ A very well optimized algorithm for the cross-correlation of spike trains, copied from the Phy package, written by Cyrille Rossant. @@ -281,7 +281,7 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size spike_times : np.ndarray An array of spike times (in samples, not seconds). This contains spikes from all units. - spike_labels : np.ndarray + spike_unit_indices : np.ndarray An array of labels indicating the unit of the corresponding spike in `spike_times`. window_size : int @@ -315,7 +315,7 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size match within the window size. """ num_bins, num_half_bins = _compute_num_bins(window_size, bin_size) - num_units = len(np.unique(spike_labels)) + num_units = len(np.unique(spike_unit_indices)) correlograms = np.zeros((num_units, num_units, num_bins), dtype="int64") @@ -347,12 +347,12 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size # to be incremented, taking into account the spike unit labels. if sign == 1: indices = np.ravel_multi_index( - (spike_labels[+shift:][m], spike_labels[:-shift][m], spike_diff_b[m] + num_half_bins), + (spike_unit_indices[+shift:][m], spike_unit_indices[:-shift][m], spike_diff_b[m] + num_half_bins), correlograms.shape, ) else: indices = np.ravel_multi_index( - (spike_labels[:-shift][m], spike_labels[+shift:][m], spike_diff_b[m] + num_half_bins), + (spike_unit_indices[:-shift][m], spike_unit_indices[+shift:][m], spike_diff_b[m] + num_half_bins), correlograms.shape, ) @@ -411,12 +411,12 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): for seg_index in range(sorting.get_num_segments()): spike_times = spikes[seg_index]["sample_index"] - spike_labels = spikes[seg_index]["unit_index"] + spike_unit_indices = spikes[seg_index]["unit_index"] _compute_correlograms_one_segment_numba( correlograms, spike_times.astype(np.int64, copy=False), - spike_labels.astype(np.int32, copy=False), + spike_unit_indices.astype(np.int32, copy=False), window_size, bin_size, num_half_bins, @@ -433,7 +433,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): cache=False, ) def _compute_correlograms_one_segment_numba( - correlograms, spike_times, spike_labels, window_size, bin_size, num_half_bins + correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins ): """ Compute the correlograms using `numba` for speed. @@ -455,7 +455,7 @@ def _compute_correlograms_one_segment_numba( spike_times : np.ndarray An array of spike times (in samples, not seconds). This contains spikes from all units. - spike_labels : np.ndarray + spike_unit_indices : np.ndarray An array of labels indicating the unit of the corresponding spike in `spike_times`. window_size : int @@ -494,4 +494,4 @@ def _compute_correlograms_one_segment_numba( bin = diff // bin_size - correlograms[spike_labels[i], spike_labels[j], num_half_bins + bin] += 1 + correlograms[spike_unit_indices[i], spike_unit_indices[j], num_half_bins + bin] += 1 diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 3b43921a0b..66d84c9565 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -189,18 +189,18 @@ def test_compute_correlograms(fill_all_bins, on_time_bin, multi_segment): counts should double when two segments with identical spike times / labels are used. """ sampling_frequency = 30000 - window_ms, bin_ms, spike_times, spike_labels, expected_bins, expected_result_auto, expected_result_corr = ( + window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr = ( generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, on_time_bin) ) if multi_segment: sorting = NumpySorting.from_times_labels( - times_list=[spike_times], labels_list=[spike_labels], sampling_frequency=sampling_frequency + times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency ) else: sorting = NumpySorting.from_times_labels( times_list=[spike_times, spike_times], - labels_list=[spike_labels, spike_labels], + labels_list=[spike_unit_indices, spike_unit_indices], sampling_frequency=sampling_frequency, ) expected_result_auto *= 2 @@ -235,13 +235,13 @@ def test_compute_correlograms_different_units(method): spike_times = np.array([0, 4, 8, 16]) / 1000 * sampling_frequency spike_times.astype(int) - spike_labels = np.array([0, 1, 0, 1]) + spike_unit_indices = np.array([0, 1, 0, 1]) window_ms = 40 bin_ms = 5 sorting = NumpySorting.from_times_labels( - times_list=[spike_times], labels_list=[spike_labels], sampling_frequency=sampling_frequency + times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency ) result, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) @@ -323,7 +323,7 @@ def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin # Now, make a set of times that increase by `base_diff_time` e.g. # if base_diff_time=0.0051 then our spike times are [`0.0051, 0.0102, ...]` spike_times = np.repeat(np.arange(num_filled_bins), num_units) * base_diff_time - spike_labels = np.tile(np.arange(num_units), int(spike_times.size / num_units)) + spike_unit_indices = np.tile(np.arange(num_units), int(spike_times.size / num_units)) spike_times *= sampling_frequency spike_times = spike_times.astype(int) @@ -368,4 +368,4 @@ def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin expected_result_corr = expected_result_auto.copy() expected_result_corr[int(num_bins / 2)] = num_filled_bins - return window_ms, bin_ms, spike_times, spike_labels, expected_bins, expected_result_auto, expected_result_corr + return window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr