Skip to content

Commit

Permalink
Rename spike_labels -> spike_unit_indices
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jul 2, 2024
1 parent 2741273 commit a96821c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
24 changes: 12 additions & 12 deletions src/spikeinterface/postprocessing/correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,16 +262,16 @@ def _compute_correlograms_numpy(sorting, window_size, bin_size):

for seg_index in range(num_seg):
spike_times = spikes[seg_index]["sample_index"]
spike_labels = spikes[seg_index]["unit_index"]
spike_unit_indices = spikes[seg_index]["unit_index"]

c0 = correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size)
c0 = correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size)

correlograms += c0

return correlograms


def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size):
def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size):
"""
A very well optimized algorithm for the cross-correlation of
spike trains, copied from the Phy package, written by Cyrille Rossant.
Expand All @@ -281,7 +281,7 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size
spike_times : np.ndarray
An array of spike times (in samples, not seconds).
This contains spikes from all units.
spike_labels : np.ndarray
spike_unit_indices : np.ndarray
An array of labels indicating the unit of the corresponding
spike in `spike_times`.
window_size : int
Expand Down Expand Up @@ -315,7 +315,7 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size
match within the window size.
"""
num_bins, num_half_bins = _compute_num_bins(window_size, bin_size)
num_units = len(np.unique(spike_labels))
num_units = len(np.unique(spike_unit_indices))

correlograms = np.zeros((num_units, num_units, num_bins), dtype="int64")

Expand Down Expand Up @@ -347,12 +347,12 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size
# to be incremented, taking into account the spike unit labels.
if sign == 1:
indices = np.ravel_multi_index(
(spike_labels[+shift:][m], spike_labels[:-shift][m], spike_diff_b[m] + num_half_bins),
(spike_unit_indices[+shift:][m], spike_unit_indices[:-shift][m], spike_diff_b[m] + num_half_bins),
correlograms.shape,
)
else:
indices = np.ravel_multi_index(
(spike_labels[:-shift][m], spike_labels[+shift:][m], spike_diff_b[m] + num_half_bins),
(spike_unit_indices[:-shift][m], spike_unit_indices[+shift:][m], spike_diff_b[m] + num_half_bins),
correlograms.shape,
)

Expand Down Expand Up @@ -411,12 +411,12 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):

for seg_index in range(sorting.get_num_segments()):
spike_times = spikes[seg_index]["sample_index"]
spike_labels = spikes[seg_index]["unit_index"]
spike_unit_indices = spikes[seg_index]["unit_index"]

_compute_correlograms_one_segment_numba(
correlograms,
spike_times.astype(np.int64, copy=False),
spike_labels.astype(np.int32, copy=False),
spike_unit_indices.astype(np.int32, copy=False),
window_size,
bin_size,
num_half_bins,
Expand All @@ -433,7 +433,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size):
cache=False,
)
def _compute_correlograms_one_segment_numba(
correlograms, spike_times, spike_labels, window_size, bin_size, num_half_bins
correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins
):
"""
Compute the correlograms using `numba` for speed.
Expand All @@ -455,7 +455,7 @@ def _compute_correlograms_one_segment_numba(
spike_times : np.ndarray
An array of spike times (in samples, not seconds).
This contains spikes from all units.
spike_labels : np.ndarray
spike_unit_indices : np.ndarray
An array of labels indicating the unit of the corresponding
spike in `spike_times`.
window_size : int
Expand Down Expand Up @@ -494,4 +494,4 @@ def _compute_correlograms_one_segment_numba(

bin = diff // bin_size

correlograms[spike_labels[i], spike_labels[j], num_half_bins + bin] += 1
correlograms[spike_unit_indices[i], spike_unit_indices[j], num_half_bins + bin] += 1
14 changes: 7 additions & 7 deletions src/spikeinterface/postprocessing/tests/test_correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,18 +189,18 @@ def test_compute_correlograms(fill_all_bins, on_time_bin, multi_segment):
counts should double when two segments with identical spike times / labels are used.
"""
sampling_frequency = 30000
window_ms, bin_ms, spike_times, spike_labels, expected_bins, expected_result_auto, expected_result_corr = (
window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr = (
generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, on_time_bin)
)

if multi_segment:
sorting = NumpySorting.from_times_labels(
times_list=[spike_times], labels_list=[spike_labels], sampling_frequency=sampling_frequency
times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency
)
else:
sorting = NumpySorting.from_times_labels(
times_list=[spike_times, spike_times],
labels_list=[spike_labels, spike_labels],
labels_list=[spike_unit_indices, spike_unit_indices],
sampling_frequency=sampling_frequency,
)
expected_result_auto *= 2
Expand Down Expand Up @@ -235,13 +235,13 @@ def test_compute_correlograms_different_units(method):
spike_times = np.array([0, 4, 8, 16]) / 1000 * sampling_frequency
spike_times.astype(int)

spike_labels = np.array([0, 1, 0, 1])
spike_unit_indices = np.array([0, 1, 0, 1])

window_ms = 40
bin_ms = 5

sorting = NumpySorting.from_times_labels(
times_list=[spike_times], labels_list=[spike_labels], sampling_frequency=sampling_frequency
times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency
)

result, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method)
Expand Down Expand Up @@ -323,7 +323,7 @@ def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin
# Now, make a set of times that increase by `base_diff_time` e.g.
# if base_diff_time=0.0051 then our spike times are [`0.0051, 0.0102, ...]`
spike_times = np.repeat(np.arange(num_filled_bins), num_units) * base_diff_time
spike_labels = np.tile(np.arange(num_units), int(spike_times.size / num_units))
spike_unit_indices = np.tile(np.arange(num_units), int(spike_times.size / num_units))

spike_times *= sampling_frequency
spike_times = spike_times.astype(int)
Expand Down Expand Up @@ -368,4 +368,4 @@ def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin
expected_result_corr = expected_result_auto.copy()
expected_result_corr[int(num_bins / 2)] = num_filled_bins

return window_ms, bin_ms, spike_times, spike_labels, expected_bins, expected_result_auto, expected_result_corr
return window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr

0 comments on commit a96821c

Please sign in to comment.