diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index ae071a55e0..68047a1ad5 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -21,11 +21,7 @@ from .correlograms import ( ComputeCorrelograms, compute_correlograms, - compute_autocorrelogram_from_spiketrain, - compute_crosscorrelogram_from_spiketrain, correlogram_for_one_segment, - compute_correlograms_numba, - compute_correlograms_numpy, ) from .isi import ( diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index bc7d2578fa..7c22260dbe 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -16,35 +16,55 @@ class ComputeCorrelograms(AnalyzerExtension): """ - Compute auto and cross correlograms. + Compute auto and cross correlograms of unit spike times. Parameters ---------- sorting_analyzer: SortingAnalyzer A SortingAnalyzer object window_ms : float, default: 50.0 - The window in ms + The window around the spike to compute the correlation in ms. For example, + if 50 ms, the correlations will be computed at lags -25 ms ... 25 ms. bin_ms : float, default: 1.0 - The bin size in ms + The bin size in ms. This determines the bin size over which to + combine lags. For example, with a window size of -25 ms to 25 ms, and + bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ... method : "auto" | "numpy" | "numba", default: "auto" - If "auto" and numba is installed, numba is used, otherwise numpy is used + If "auto" and numba is installed, numba is used, otherwise numpy is used. Returns ------- - ccgs : np.array + correlogram : np.array Correlograms with shape (num_units, num_units, num_bins) - The diagonal of ccgs is the auto correlogram. - ccgs[A, B, :] is the symetrie of ccgs[B, A, :] - ccgs[A, B, :] have to be read as the histogram of spiketimesA - spiketimesB + The diagonal of the correlogram (e.g. correlogram[A, A, :]) + holds the unit auto correlograms. The off-diagonal elements + are the cross-correlograms between units, where correlogram[A, B, :] + and correlogram[B, A, :] represent cross-correlation between + the same pair of units, applied in opposite directions, + correlogram[A, B, :] = correlogram[B, A, ::-1]. bins : np.array The bin edges in ms - Returns - ------- - isi_histograms : np.array - 2D array with ISI histograms (num_units, num_bins) - bins : np.array - 1D array with bins in ms + Notes + ----- + In the extracellular electrophysiology context, a correlogram + is a visualisation of the results of a cross-correlation + between two spike trains. The cross-correlation slides one spike train + along another sample-by-sample, taking the correlation at each 'lag'. This results + in a plot with 'lag' (i.e. time offset) on the x-axis and 'correlation' + (i.e. how similar to two spike trains are) on the y-axis. In this + implementation, the y-axis result is the 'counts' of spike matches per + time bin (rather than a computer correlation or covariance). + + In the present implementation, a 'window' around spikes is first + specified. For example, if a window of 100 ms is taken, we will + take the correlation at lags from -50 ms to +50 ms around the spike peak. + In theory, we can have as many lags as we have samples. Often, this + visualisation is too high resolution and instead the lags are binned + (e.g. -50 to -45 ms, ..., -5 to 0 ms, 0 to 5 ms, ...., 45 to 50 ms). + When using counts as output, binning the lags involves adding up all counts across + a range of lags. + """ @@ -71,7 +91,7 @@ def _select_extension_data(self, unit_ids): return new_data def _run(self, verbose=False): - ccgs, bins = compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) + ccgs, bins = _compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) self.data["ccgs"] = ccgs self.data["bins"] = bins @@ -89,7 +109,10 @@ def compute_correlograms( bin_ms: float = 1.0, method: str = "auto", ): - + """ + Compute correlograms using Numba or Numpy. + See ComputeCorrelograms() for details. + """ if isinstance(sorting_analyzer_or_sorting, MockWaveformExtractor): sorting_analyzer_or_sorting = sorting_analyzer_or_sorting.sorting @@ -98,7 +121,7 @@ def compute_correlograms( sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method ) else: - return compute_correlograms_on_sorting( + return _compute_correlograms_on_sorting( sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method ) @@ -107,6 +130,33 @@ def compute_correlograms( def _make_bins(sorting, window_ms, bin_ms): + """ + Create the bins for the correlogram, in samples. + + The autocorrelogram bins are centered around zero. Each bin + increases in a positive / negative direction starting at zero. + + For example, given a window_ms of 50 ms and a bin_ms of + 5 ms, the bins in unit ms will be: + [-25 to -20, ..., -5 to 0, 0 to 5, ..., 20 to 25]. + + The window size will be clipped if not divisible by the bin size. + + Parameters + ---------- + See ComputeCorrelograms() for parameters. + + Returns + ------- + + bins : np.ndarray + The bins edges in ms + window_size : int + The window size in samples + bin_size : int + The bin size in samples + + """ fs = sorting.sampling_frequency window_size = int(round(fs * window_ms / 2 * 1e-3)) @@ -120,62 +170,56 @@ def _make_bins(sorting, window_ms, bin_ms): return bins, window_size, bin_size -def compute_autocorrelogram_from_spiketrain(spike_times, window_size, bin_size): +def _compute_num_bins(window_size, bin_size): """ - Computes the auto-correlogram from a given spike train. - - This implementation only works if you have numba installed, to accelerate the - computation time. + Internal function to compute number of bins, expects + window_size and bin_size are already divisible. These are + typically generated in `_make_bins()`. - Parameters - ---------- - spike_times: np.ndarray - The ordered spike train to compute the auto-correlogram. - window_size: int - Compute the auto-correlogram between -window_size and +window_size (in sampling time). - bin_size: int - Size of a bin (in sampling time). Returns ------- - tuple (auto_corr, bins) - auto_corr: np.ndarray[int64] - The computed auto-correlogram. + num_bins : int + The total number of bins to span the window, in samples + half_num_bins : int + Half the number of bins. The bins are an equal number + of bins that look forward and backwards from zero, e.g. + [..., -10 to -5, -5 to 0, 0 to 5, 5 to 10, ...] + """ - assert HAVE_NUMBA - return _compute_autocorr_numba(spike_times.astype(np.int64), window_size, bin_size) + num_half_bins = int(window_size // bin_size) + num_bins = int(2 * num_half_bins) + return num_bins, num_half_bins -def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_size, bin_size): + +def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): """ - Computes the cros-correlogram between two given spike trains. + Computes cross-correlograms from multiple units. - This implementation only works if you have numba installed, to accelerate the - computation time. + Entry function to compute correlograms across all units in a `Sorting` + object (i.e. spike trains at all determined offsets will be computed + for each unit against every other unit). Parameters ---------- - spike_times1: np.ndarray - The ordered spike train to compare against the second one. - spike_times2: np.ndarray - The ordered spike train that serves as a reference for the cross-correlogram. - window_size: int - Compute the auto-correlogram between -window_size and +window_size (in sampling time). - bin_size: int - Size of a bin (in sampling time). + sorting : Sorting + A SpikeInterface Sorting object + window_ms : float + The window size over which to perform the cross-correlation, in ms + bin_ms : float + The size of which to bin lags, in ms. + method : str + To use "numpy" or "numba". "auto" will use numba if available, + otherwise numpy. Returns ------- - tuple (auto_corr, bins) - auto_corr: np.ndarray[int64] - The computed auto-correlogram. - """ - assert HAVE_NUMBA - return _compute_crosscorr_numba(spike_times1.astype(np.int64), spike_times2.astype(np.int64), window_size, bin_size) - - -def compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): - """ - Computes several cross-correlogram in one course from several clusters. + correlograms : np.array + A (num_units, num_units, num_bins) array where unit x unit correlation + matrices are stacked at all determined time bins. Note the true + correlation is not returned but instead the count of number of matches. + bins : np.array + The bins edges in ms """ assert method in ("auto", "numba", "numpy") @@ -185,23 +229,23 @@ def compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"): bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms) if method == "numpy": - correlograms = compute_correlograms_numpy(sorting, window_size, bin_size) + correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size) if method == "numba": - correlograms = compute_correlograms_numba(sorting, window_size, bin_size) + correlograms = _compute_correlograms_numba(sorting, window_size, bin_size) return correlograms, bins # LOW-LEVEL IMPLEMENTATIONS -def compute_correlograms_numpy(sorting, window_size, bin_size): +def _compute_correlograms_numpy(sorting, window_size, bin_size): """ - Computes cross-correlograms for all units in a sorting object. + Computes correlograms for all units in a sorting object. This very elegant implementation is copied from phy package written by Cyrille Rossant. https://github.com/cortex-lab/phylib/blob/master/phylib/stats/ccg.py - The main modification is way the positive and negative are handled explicitly - for rounding reasons. + The main modification is the way positive and negative are handled + explicitly for rounding reasons. Other slight modifications have been made to fit the SpikeInterface data model (e.g. adding the ability to handle multiple segments). @@ -212,30 +256,66 @@ def compute_correlograms_numpy(sorting, window_size, bin_size): num_units = len(sorting.unit_ids) spikes = sorting.to_spike_vector(concatenated=False) - num_half_bins = int(window_size // bin_size) - num_bins = int(2 * num_half_bins) + num_bins, num_half_bins = _compute_num_bins(window_size, bin_size) correlograms = np.zeros((num_units, num_units, num_bins), dtype="int64") for seg_index in range(num_seg): spike_times = spikes[seg_index]["sample_index"] - spike_labels = spikes[seg_index]["unit_index"] + spike_unit_indices = spikes[seg_index]["unit_index"] - c0 = correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size) + c0 = correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size) correlograms += c0 return correlograms -def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size): - """ - Called by compute_correlograms_numpy +def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bin_size): """ + A very well optimized algorithm for the cross-correlation of + spike trains, copied from the Phy package, written by Cyrille Rossant. - num_half_bins = int(window_size // bin_size) - num_bins = int(2 * num_half_bins) - num_units = len(np.unique(spike_labels)) + Parameters + ---------- + spike_times : np.ndarray + An array of spike times (in samples, not seconds). + This contains spikes from all units. + spike_unit_indices : np.ndarray + An array of labels indicating the unit of the corresponding + spike in `spike_times`. + window_size : int + The window size over which to perform the cross-correlation, in samples + bin_size : int + The size of which to bin lags, in samples. + + Returns + ------- + correlograms : np.array + A (num_units, num_units, num_bins) array of correlograms + between all units at each lag time bin. + + Notes + ----- + For all spikes, time difference between this spike and + every other spike within the window is directly computed + and stored as a count in the relevant lag time bin. + + Initially, the spike_times array is shifted by 1 position, and the difference + computed. This gives the time differences between the closest spikes + (skipping the zero-lag case). Next, the differences between + spikes times in samples are converted into units relative to + bin_size ('binarized'). Spikes in which the binarized difference to + their closest neighbouring spike is greater than half the bin-size are + masked. + + Finally, the indices of the (num_units, num_units, num_bins) correlogram + that need incrementing are done so with `ravel_multi_index()`. This repeats + for all shifts along the spike_train until no spikes have a corresponding + match within the window size. + """ + num_bins, num_half_bins = _compute_num_bins(window_size, bin_size) + num_units = len(np.unique(spike_unit_indices)) correlograms = np.zeros((num_units, num_units, num_bins), dtype="int64") @@ -243,8 +323,8 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size # within the correlogram time window. mask = np.ones_like(spike_times, dtype="bool") - # The loop continues as long as there is at least one spike with - # a matching spike. + # The loop continues as long as there is at least one + # spike with a matching spike. shift = 1 while mask[:-shift].any(): # Number of time samples between spike i and spike i+shift. @@ -264,15 +344,15 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size m = mask[:-shift] # Find the indices in the raveled correlograms array that need - # to be incremented, taking into account the spike clusters. + # to be incremented, taking into account the spike unit labels. if sign == 1: indices = np.ravel_multi_index( - (spike_labels[+shift:][m], spike_labels[:-shift][m], spike_diff_b[m] + num_half_bins), + (spike_unit_indices[+shift:][m], spike_unit_indices[:-shift][m], spike_diff_b[m] + num_half_bins), correlograms.shape, ) else: indices = np.ravel_multi_index( - (spike_labels[:-shift][m], spike_labels[+shift:][m], spike_diff_b[m] + num_half_bins), + (spike_unit_indices[:-shift][m], spike_unit_indices[+shift:][m], spike_diff_b[m] + num_half_bins), correlograms.shape, ) @@ -280,35 +360,66 @@ def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size bbins = np.bincount(indices) correlograms.ravel()[: len(bbins)] += bbins + if sign == 1: + # For positive sign, the end bin is < num_half_bins (e.g. + # bin = 29, num_half_bins = 30, will go to index 59 (i.e. the + # last bin). For negative sign, the first bin is == num_half_bins + # e.g. bin = -30, with num_half_bins = 30 will go to bin 0. Therefore + # sign == 1 must mask spike_diff_b <= num_half_bins but sign == -1 + # must count all (possibly repeating across units) cases of + # spike_diff_b == num_half_bins. So we turn it back on here + # for the next loop that starts with the -1 case. + mask[:-shift][spike_diff_b == num_half_bins] = True + shift += 1 return correlograms -def compute_correlograms_numba(sorting, window_size, bin_size): +def _compute_correlograms_numba(sorting, window_size, bin_size): """ - Computes several cross-correlogram in one course - from several cluster. + Computes cross-correlograms between all units in `sorting`. This is a "brute force" method using compiled code (numba) - to accelerate the computation. + to accelerate the computation. See + `_compute_correlograms_one_segment_numba()` for details. + + Parameters + ---------- + sorting : Sorting + A SpikeInterface Sorting object + window_size : int + The window size over which to perform the cross-correlation, in samples + bin_size : int + The size of which to bin lags, in samples. + + Returns + ------- + correlograms: np.array + A (num_units, num_units, num_bins) array of correlograms + between all units at each lag time bin. Implementation: Aurélien Wyngaard """ - assert HAVE_NUMBA, "numba version of this function requires installation of numba" - num_bins = 2 * int(window_size / bin_size) + num_bins, num_half_bins = _compute_num_bins(window_size, bin_size) num_units = len(sorting.unit_ids) + spikes = sorting.to_spike_vector(concatenated=False) correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64) for seg_index in range(sorting.get_num_segments()): spike_times = spikes[seg_index]["sample_index"] - spike_labels = spikes[seg_index]["unit_index"] - - _compute_correlograms_numba( - correlograms, spike_times.astype(np.int64), spike_labels.astype(np.int32), window_size, bin_size + spike_unit_indices = spikes[seg_index]["unit_index"] + + _compute_correlograms_one_segment_numba( + correlograms, + spike_times.astype(np.int64, copy=False), + spike_unit_indices.astype(np.int32, copy=False), + window_size, + bin_size, + num_half_bins, ) return correlograms @@ -316,75 +427,71 @@ def compute_correlograms_numba(sorting, window_size, bin_size): if HAVE_NUMBA: - @numba.jit(nopython=True, nogil=True, cache=False) - def _compute_autocorr_numba(spike_times, window_size, bin_size): - num_half_bins = window_size // bin_size - num_bins = 2 * num_half_bins - - auto_corr = np.zeros(num_bins, dtype=np.int64) - - for i in range(len(spike_times)): - for j in range(i + 1, len(spike_times)): - diff = spike_times[j] - spike_times[i] - - if diff > window_size: - break - - bin = int(math.floor(diff / bin_size)) - # ~ auto_corr[num_bins//2 - bin - 1] += 1 - auto_corr[num_half_bins + bin] += 1 - # ~ print(diff, bin, num_half_bins + bin) - - bin = int(math.floor(-diff / bin_size)) - auto_corr[num_half_bins + bin] += 1 - # ~ print(diff, bin, num_half_bins + bin) - - return auto_corr + @numba.jit( + nopython=True, + nogil=True, + cache=False, + ) + def _compute_correlograms_one_segment_numba( + correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins + ): + """ + Compute the correlograms using `numba` for speed. + + The algorithm works by brute-force iteration through all + pairs of spikes (skipping those when outside of the window). + The spike-time difference and its time bin are computed + and stored in a (num_units, num_units, num_bins) + correlogram. The correlogram must be passed as an + argument and is filled in-place. + + Parameters + --------- + + correlograms: np.array + A (num_units, num_units, num_bins) array of correlograms + between all units at each lag time bin. This is passed + as counts for all segments are added to it. + spike_times : np.ndarray + An array of spike times (in samples, not seconds). + This contains spikes from all units. + spike_unit_indices : np.ndarray + An array of labels indicating the unit of the corresponding + spike in `spike_times`. + window_size : int + The window size over which to perform the cross-correlation, in samples + bin_size : int + The size of which to bin lags, in samples. + """ + start_j = 0 + for i in range(spike_times.size): + for j in range(start_j, spike_times.size): - @numba.jit(nopython=True, nogil=True, cache=False) - def _compute_crosscorr_numba(spike_times1, spike_times2, window_size, bin_size): - num_half_bins = window_size // bin_size - num_bins = 2 * num_half_bins + if i == j: + continue - cross_corr = np.zeros(num_bins, dtype=np.int64) + diff = spike_times[i] - spike_times[j] - start_j = 0 - for i in range(len(spike_times1)): - for j in range(start_j, len(spike_times2)): - diff = spike_times1[i] - spike_times2[j] + # When the diff is exactly the window size, keep going + # without iterating start_j in case this spike also has + # other diffs with other units that == window size. + if diff == window_size: + continue - if diff >= window_size: + # if the time of spike i is more than window size later than + # spike j, then spike i + 1 will also be more than a window size + # later than spike j. Iterate the start_j and check the next spike. + if diff > window_size: start_j += 1 continue + + # If the time of spike i is more than a window size earlier + # than spike j, then all following j spikes will be even later + # i spikes and so all more than a window size earlier. So move + # onto the next i. if diff < -window_size: break - bin = int(math.floor(diff / bin_size)) - # ~ bin = diff // bin_size - cross_corr[num_half_bins + bin] += 1 - # ~ print(diff, bin, num_half_bins + bin) + bin = diff // bin_size - return cross_corr - - @numba.jit( - nopython=True, - nogil=True, - cache=False, - parallel=True, - ) - def _compute_correlograms_numba(correlograms, spike_times, spike_labels, window_size, bin_size): - n_units = correlograms.shape[0] - - for i in numba.prange(n_units): - # ~ for i in range(n_units): - spike_times1 = spike_times[spike_labels == i] - - for j in range(i, n_units): - spike_times2 = spike_times[spike_labels == j] - - if i == j: - correlograms[i, j, :] += _compute_autocorr_numba(spike_times1, window_size, bin_size) - else: - cc = _compute_crosscorr_numba(spike_times1, spike_times2, window_size, bin_size) - correlograms[i, j, :] += cc - correlograms[j, i, :] += cc[::-1] + correlograms[spike_unit_indices[i], spike_unit_indices[j], num_half_bins + bin] += 1 diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index eef4af10fc..66d84c9565 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -7,12 +7,18 @@ except ModuleNotFoundError as err: HAVE_NUMBA = False - from spikeinterface import NumpySorting, generate_sorting from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeCorrelograms -from spikeinterface.postprocessing.correlograms import compute_correlograms_on_sorting, _make_bins +from spikeinterface.postprocessing.correlograms import ( + _compute_correlograms_on_sorting, + _make_bins, + compute_correlograms, +) import pytest +from pytest import param + +SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") class TestComputeCorrelograms(AnalyzerExtensionCommonTestSuite): @@ -22,17 +28,37 @@ class TestComputeCorrelograms(AnalyzerExtensionCommonTestSuite): [ dict(method="numpy"), dict(method="auto"), - pytest.param(dict(method="numba"), marks=pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")), + param(dict(method="numba"), marks=SKIP_NUMBA), ], ) def test_extension(self, params): self.run_extension_tests(ComputeCorrelograms, params) + @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) + def test_sortinganalyzer_correlograms(self, method): + """ + Test the outputs when using SortingAnalyzer against + the output passing sorting directly to `compute_correlograms`. + Sorting to `compute_correlograms` is tested extensively below + so if these match it means `SortingAnalyzer` is working. + """ + sorting_analyzer = self._prepare_sorting_analyzer("memory", sparse=False, extension_class=ComputeCorrelograms) + + params = dict(method=method, window_ms=100, bin_ms=6.5) + ext_numpy = sorting_analyzer.compute(ComputeCorrelograms.extension_name, **params) + + result_sorting, bins_sorting = compute_correlograms(self.sorting, **params) + + assert np.array_equal(result_sorting, ext_numpy.data["ccgs"]) + assert np.array_equal(bins_sorting, ext_numpy.data["bins"]) + +# Unit Tests +############ def test_make_bins(): """ Check the `_make_bins()` function that generates time bins (lags) for - the correllogram creates the expected number of bins. + the correlogram creates the expected number of bins. """ sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) @@ -45,96 +71,78 @@ def test_make_bins(): bin_ms = 2.0 bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms) assert bins.size == np.floor(window_ms / bin_ms) + 1 + assert np.array_equal(bins, np.linspace(-30, 30, bins.size)) -def _test_correlograms(sorting, window_ms, bin_ms, methods): - for method in methods: - correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) - if method == "numpy": - ref_bins = bins - else: - assert np.allclose(bins, ref_bins, atol=1e-10), f"Failed with method={method}" - - -def test_equal_results_correlograms(): - # compare that the 2 methods have same results - methods = ["numpy"] - if HAVE_NUMBA: - methods.append("numba") +@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") +@pytest.mark.parametrize("window_and_bin_ms", [(60.0, 2.0), (3.57, 1.6421)]) +def test_equal_results_correlograms(window_and_bin_ms): + """ + Test that the 2 methods have same results with some varied time bins + that are not tested in other tests. + """ + window_ms, bin_ms = window_and_bin_ms sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) - _test_correlograms(sorting, window_ms=60.0, bin_ms=2.0, methods=methods) - _test_correlograms(sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) + result_numpy, bins_numpy = _compute_correlograms_on_sorting( + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numpy" + ) + result_numba, bins_numba = _compute_correlograms_on_sorting( + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba" + ) + + assert np.array_equal(result_numpy, result_numba) + assert np.array_equal(result_numpy, result_numba) -def test_flat_cross_correlogram(): +@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) +def test_flat_cross_correlogram(method): """ Check that the correlogram (num_units x num_units x num_bins) does not vary too much across time bins (lags), for entries representing two different units. """ sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=0) - methods = ["numpy"] - if HAVE_NUMBA: - methods.append("numba") + correlograms, bins = _compute_correlograms_on_sorting(sorting, window_ms=50.0, bin_ms=1.0, method=method) + cc = correlograms[0, 1, :].copy() + m = np.mean(cc) - for method in methods: - correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=50.0, bin_ms=1.0, method=method) - cc = correlograms[0, 1, :].copy() - m = np.mean(cc) - assert np.all(cc > (m * 0.90)) - assert np.all(cc < (m * 1.10)) + assert np.all(cc > (m * 0.90)) + assert np.all(cc < (m * 1.10)) -def test_auto_equal_cross_correlograms(): +@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) +def test_auto_equal_cross_correlograms(method): """ - check if cross correlogram is the same as autocorrelogram + Check if cross correlogram is the same as autocorrelogram by removing n spike in bin zeros - numpy method: - * have problem for the left bin - * have problem on center """ - - methods = ["numpy"] - if HAVE_NUMBA: - methods.append("numba") - num_spike = 2000 spike_times = np.sort(np.unique(np.random.randint(0, 100000, num_spike))) num_spike = spike_times.size units_dict = {"1": spike_times, "2": spike_times} sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=10000.0) - for method in methods: - correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=10.0, bin_ms=0.1, method=method) + correlograms, bins = _compute_correlograms_on_sorting(sorting, window_ms=10.0, bin_ms=0.1, method=method) - num_half_bins = correlograms.shape[2] // 2 + num_half_bins = correlograms.shape[2] // 2 - cc = correlograms[0, 1, :] - ac = correlograms[0, 0, :] - cc_corrected = cc.copy() - cc_corrected[num_half_bins] -= num_spike + cc = correlograms[0, 1, :] + ac = correlograms[0, 0, :] + cc_corrected = cc.copy() + cc_corrected[num_half_bins] -= num_spike - if method == "numpy": - # numpy method have some border effect on left - assert np.array_equal(cc_corrected[1:num_half_bins], ac[1:num_half_bins]) - # numpy method have some problem on center - assert np.array_equal(cc_corrected[num_half_bins + 1 :], ac[num_half_bins + 1 :]) - else: - assert np.array_equal(cc_corrected, ac) + assert np.array_equal(cc_corrected, ac) -def test_detect_injected_correlation(): +@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) +def test_detect_injected_correlation(method): """ Inject 1.44 ms of correlation every 13 spikes and compute cross-correlation. Check that the time bin lag with the peak correlation lag is 1.44 ms (within tolerance of a sampling period). """ - methods = ["numpy"] - if HAVE_NUMBA: - methods.append("numba") - sampling_frequency = 10000.0 num_spike = 2000 rng = np.random.default_rng(seed=0) @@ -143,6 +151,7 @@ def test_detect_injected_correlation(): n = min(spike_times1.size, spike_times2.size) spike_times1 = spike_times1[:n] spike_times2 = spike_times2[:n] + # inject 1.44 ms correlation every 13 spikes injected_delta_ms = 1.44 spike_times2[::13] = spike_times1[::13] + int(injected_delta_ms / 1000 * sampling_frequency) @@ -151,15 +160,212 @@ def test_detect_injected_correlation(): units_dict = {"1": spike_times1, "2": spike_times2} sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=sampling_frequency) - for method in methods: - correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=10.0, bin_ms=0.1, method=method) + correlograms, bins = _compute_correlograms_on_sorting(sorting, window_ms=10.0, bin_ms=0.1, method=method) + + cc_01 = correlograms[0, 1, :] + cc_10 = correlograms[1, 0, :] - cc_01 = correlograms[0, 1, :] - cc_10 = correlograms[1, 0, :] + peak_location_01_ms = bins[np.argmax(cc_01)] + peak_location_02_ms = bins[np.argmax(cc_10)] - peak_location_01_ms = bins[np.argmax(cc_01)] - peak_location_02_ms = bins[np.argmax(cc_10)] + sampling_period_ms = 1000.0 / sampling_frequency + assert abs(peak_location_01_ms) - injected_delta_ms < sampling_period_ms + assert abs(peak_location_02_ms) - injected_delta_ms < sampling_period_ms - sampling_period_ms = 1000.0 / sampling_frequency - assert abs(peak_location_01_ms) - injected_delta_ms < sampling_period_ms - assert abs(peak_location_02_ms) - injected_delta_ms < sampling_period_ms + +# Functional Tests +################### +@pytest.mark.parametrize("fill_all_bins", [True, False]) +@pytest.mark.parametrize("on_time_bin", [True, False]) +@pytest.mark.parametrize("multi_segment", [True, False]) +def test_compute_correlograms(fill_all_bins, on_time_bin, multi_segment): + """ + Test the entry function `compute_correlograms` under a variety of conditions. + For specifics of `fill_all_bins` and `on_time_bin` see `generate_correlogram_test_dataset()`. + + This function tests numpy and numba in one go, to avoid over-parameterising the method. + It tests both a single-segment and multi-segment dataset. The way that segments are + handled for the correlogram is to combine counts across all segments, therefore the + counts should double when two segments with identical spike times / labels are used. + """ + sampling_frequency = 30000 + window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr = ( + generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, on_time_bin) + ) + + if multi_segment: + sorting = NumpySorting.from_times_labels( + times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency + ) + else: + sorting = NumpySorting.from_times_labels( + times_list=[spike_times, spike_times], + labels_list=[spike_unit_indices, spike_unit_indices], + sampling_frequency=sampling_frequency, + ) + expected_result_auto *= 2 + expected_result_corr *= 2 + + result_numba, bins_numba = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba") + result_numpy, bins_numpy = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method="numpy") + + for auto_idx in [(0, 0), (1, 1), (2, 2)]: + assert np.array_equal(expected_result_auto, result_numpy[auto_idx]) + assert np.array_equal(expected_result_auto, result_numba[auto_idx]) + + for auto_idx in [(1, 0), (0, 1), (0, 2), (2, 0), (1, 2), (2, 1)]: + assert np.array_equal(expected_result_corr, result_numpy[auto_idx]) + assert np.array_equal(expected_result_corr, result_numba[auto_idx]) + + +@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) +def test_compute_correlograms_different_units(method): + """ + Make a supplementary test to `test_compute_correlograms` in which all + units had the same spike train. Test here a simpler and accessible + test case with only two neurons with different spike time differences + within and across units. + + This case is simple enough to validate by hand, for example for the + result[1, 1] case we are looking at the autocorrelogram of the unit '1'. + The spike times are 4 and 16 s, therefore we expect to see a count in + the +/- 10 to 15 s bin. + """ + sampling_frequency = 30000 + spike_times = np.array([0, 4, 8, 16]) / 1000 * sampling_frequency + spike_times.astype(int) + + spike_unit_indices = np.array([0, 1, 0, 1]) + + window_ms = 40 + bin_ms = 5 + + sorting = NumpySorting.from_times_labels( + times_list=[spike_times], labels_list=[spike_unit_indices], sampling_frequency=sampling_frequency + ) + + result, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) + + assert np.array_equal(result[0, 0], np.array([0, 0, 1, 0, 0, 1, 0, 0])) + + assert np.array_equal(result[1, 1], np.array([0, 1, 0, 0, 0, 0, 1, 0])) + + assert np.array_equal(result[1, 0], np.array([0, 0, 0, 1, 1, 1, 0, 1])) + + assert np.array_equal(result[0, 1], np.array([1, 0, 1, 1, 1, 0, 0, 0])) + + +def generate_correlogram_test_dataset(sampling_frequency, fill_all_bins, hit_bin_edge): + """ + This generates a detailed correlogram test and expected outputs, for a number of + test cases: + + overflow edges : when there are counts expected in every measured bins, otherwise + counts are expected only in a (central) subset of bins. + hit_bin_edge : if `True`, the difference in spike times are created to land + exactly as multiples of the bin size, an edge case that caused + some problems in previous iterations of the algorithm. + + The approach used is to create a set of spike times which are + multiples of a 'base_diff_time'. When `hit_bin_edge` is `False` this is + set to 5.1 ms. So, we have spikes at: + 5.1 ms, 10.2 ms, 15.3 ms, ..., base_diff_time * num_filled_bins + + This means consecutive spike times are 5.1 ms apart. Then every two + spike times are 10.2 ms apart. This gives predictable bin counts, + that are maximal at the smaller bins (e.g. 5-10 s) and minimal at + the later bins (e.g. 100-105 s). Note at more than num_filled_bins the + the times will overflow to the next bin and test wont work. None of these + parameters should be changed. + + When `hit_bin_edge` is `False`, we expect that bin counts will increase from the + edge of the bins to the middle, maximum in the middle, 0 in the exact center + (-5 to 0, 0 to 5) and then decreasing until the end of the bin. For the autocorrelation, + the zero-lag case is not included and the two central bins will be zero. + + Different units are tested by repeating the spike times. This means all + results for all units autocorrelation and cross-correlation will be + identical, simplifying the tests. The only difference is that auto-correlation + does not count the zero-lag bins but cross-correlation does. Because the + spike times are identical, this means in the cross-correlation case we have + `num_filled_bins` in the central bin. By convention, this is always put + in the positive (i.e. 0-5 s) not negative (-5 to 0 s) bin. I guess it + could make sense to force it into both positive and negative bins? + + Finally, the case when the time differences are exactly the bin + size is tested. In this case the spike times are [0, 5, 10, 15, ...] + with all diffs 5 and the `bin_ms` set to 5. By convention, when spike + diffs hit the bin edge they are set into the 'right' (i.e. positive) + bin. For positive bins this does not change, but for negative bins + all entries are shifted one place to the right. + """ + num_units = 3 + + # These give us 61 bins, [-150, -145,...,0,...,145, 150] + window_ms = 300 + bin_ms = 5 + + # If overflow edges, we will have a diff at every possible + # bin e.g. the counts will be [31, 30, ..., 30, 31]. If not, + # test the case where there are zero bins e.g. [0, 0, 9, 8, ..., 8, 9, 0, 0]. + if fill_all_bins: + num_filled_bins = 60 + else: + num_filled_bins = 10 + + # If we are on a time bin, make the time delays exactly + # the same as a time bin, testing this tricky edge case. + if hit_bin_edge: + base_diff_time = bin_ms / 1000 + else: + base_diff_time = bin_ms / 1000 + 0.0001 # i.e. 0.0051 s + + # Now, make a set of times that increase by `base_diff_time` e.g. + # if base_diff_time=0.0051 then our spike times are [`0.0051, 0.0102, ...]` + spike_times = np.repeat(np.arange(num_filled_bins), num_units) * base_diff_time + spike_unit_indices = np.tile(np.arange(num_units), int(spike_times.size / num_units)) + + spike_times *= sampling_frequency + spike_times = spike_times.astype(int) + + # Here generate the expected results. This is done pretty much hard-coded + # to be as explicit as possible. + + # Generate the expected bins + num_bins = int(window_ms / bin_ms) + assert window_ms == 300, "dont change the window_ms" + assert bin_ms == 5, "dont change the bin_ms" + expected_bins = np.linspace(-150, 150, num_bins + 1) + + # In this case, all time bins are shifted to the right for the + # negative shift due to the diffs lying on the bin edge. + # [30, 31, ..., 59, 0, 59, ..., 30, 31] + if fill_all_bins and hit_bin_edge: + expected_result_auto = np.r_[np.arange(30, 60), 0, np.flip(np.arange(31, 60))] + + # In this case there are no edge effects and the bin counts + # [31, 30, ..., 59, 0, 0, 59, ..., 30, 31] + # are symmetrical + elif fill_all_bins and not hit_bin_edge: + forward = np.r_[np.arange(31, 60), 0] + expected_result_auto = np.r_[forward, np.flip(forward)] + + # Here we have many zero bins, but the existing bins are + # shifted left in the negative-bin base + # [0, 0, ..., 1, 2, 3, ..., 10, 0, 10, ..., 3, 2, 1, ..., 0] + elif not fill_all_bins and hit_bin_edge: + forward = np.r_[np.zeros(19), np.arange(10)] + expected_result_auto = np.r_[0, forward, 0, np.flip(forward)] + + # Here we have many zero bins and they are symmetrical + # [0, 0, ..., 1, 2, 3, ..., 10, 0, 10, ..., 3, 2, 1, ..., 0, 0] + elif not fill_all_bins and not hit_bin_edge: + forward = np.r_[np.zeros(19), np.arange(10), 0] + expected_result_auto = np.r_[forward, np.flip(forward)] + + # The zero-lag bins are only skipped in the autocorrelogram + # case. + expected_result_corr = expected_result_auto.copy() + expected_result_corr[int(num_bins / 2)] = num_filled_bins + + return window_ms, bin_ms, spike_times, spike_unit_indices, expected_bins, expected_result_auto, expected_result_corr