diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index bc7d2578fa..074d7cda9f 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -13,21 +13,49 @@ except ModuleNotFoundError as err:
     HAVE_NUMBA = False

+# TODO: here the default is 50 ms but in the docs it says 100 ms?
+# _set_params, _select_extension_data, _run, _get_data I think are
+# sorting analyzer things. Docstrings can be added here and propagated to
+# all sorting analyzer functions OR can be described in the class docstring.
+# Otherwise it is quite hard to understand where they are called from,
+# as they are not called internally on the class.
+

 class ComputeCorrelograms(AnalyzerExtension):
     """
     Compute auto and cross correlograms.

+    In the extracellular electrophysiology context, a correlogram
+    is a visualisation of the results of a cross-correlation
+    between two spike trains. The cross-correlation slides one spike train
+    along another sample-by-sample, taking the correlation at each 'lag'. This results
+    in a plot with 'lag' (i.e. time offset) on the x-axis and 'correlation'
+    (i.e. how similar the two spike trains are) on the y-axis. Often, the
+    values on the y-axis are given as the correlation, but they may
+    be the covariance or raw counts.
+
+    Correlograms are often used to determine whether a unit has
+    ISI violations. In this context, a 'window' around spikes is first
+    specified. For example, if a window of 100 ms is taken, we will
+    take the correlation at lags from -50 ms to +50 ms around the spike peak.
+    In theory, we can have as many lags as we have samples. Often, this
+    visualisation is too high-resolution, so instead the lags are binned
+    (e.g. -50 to -45 ms, -45 to -40 ms, ..., 45 to 50 ms bins). When using
+    counts as output, binning the lags involves adding up all counts across
+    a range of lags.
+
     Parameters
     ----------
     sorting_analyzer: SortingAnalyzer
         A SortingAnalyzer object
     window_ms : float, default: 50.0
-        The window in ms
+        The full window, in ms, around the spike over which to compute the
+        correlation. For example, if 50 ms, the correlations will be computed
+        at lags -25 ms ... +25 ms.
     bin_ms : float, default: 1.0
-        The bin size in ms
+        The bin size in ms. This determines the bin size over which to
+        combine lags. For example, with a window size of -25 ms to 25 ms, and
+        bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ...
     method : "auto" | "numpy" | "numba", default: "auto"
-        If "auto" and numba is installed, numba is used, otherwise numpy is used
+        If "auto" and numba is installed, numba is used, otherwise numpy is used.

     Returns
     -------
@@ -40,7 +68,7 @@ class ComputeCorrelograms(AnalyzerExtension):
         The bin edges in ms

     Returns
-    ------- 
+    -------
     isi_histograms : np.array
         2D array with ISI histograms (num_units, num_bins)
     bins : np.array
@@ -83,13 +111,18 @@ def _get_data(self):

 compute_correlograms_sorting_analyzer = ComputeCorrelograms.function_factory()


+# TODO: ask when this function is used vs. compute_correlograms_on_sorting()?
 def compute_correlograms(
     sorting_analyzer_or_sorting,
     window_ms: float = 50.0,
     bin_ms: float = 1.0,
     method: str = "auto",
 ):
-
+    """
+    Convenience entry function to handle computation of
+    correlograms based on the method used. See ComputeCorrelograms()
+    for parameters.
+    """
     if isinstance(sorting_analyzer_or_sorting, MockWaveformExtractor):
         sorting_analyzer_or_sorting = sorting_analyzer_or_sorting.sorting
@@ -107,6 +140,35 @@ def compute_correlograms(


 def _make_bins(sorting, window_ms, bin_ms):
+    """
+    Create the bins for the correlogram.
+
+    The correlogram bins are centered around zero but do not
+    include the results from zero lag. Bins extend in both the
+    positive and negative directions, starting at zero.
+
+    For example, given a window_ms of 50 ms and a bin_ms of
+    5 ms, the bins in unit ms will be:
+    [-25 to -20, ..., -5 to 0, 0 to 5, ..., 20 to 25].
+
+    The window size will be clipped if not divisible by the bin size.
+    The bin edges are output in ms; the window size and bin size are
+    output in samples, not seconds.
+
+    Parameters
+    ----------
+    See ComputeCorrelograms() for parameters.
+
+    Returns
+    -------
+    bins : np.ndarray
+        The bin edges in ms
+    window_size : int
+        The window size in samples
+    bin_size : int
+        The bin size in samples
+    """
     fs = sorting.sampling_frequency

     window_size = int(round(fs * window_ms / 2 * 1e-3))
@@ -120,6 +182,9 @@ def _make_bins(sorting, window_ms, bin_ms):
     return bins, window_size, bin_size


+# TODO: in another PR, coerce this input into `correlogram_for_one_segment()`
+# to provide a numpy and numba version. Consider window_size and bin_size
+# being taken as ms to match the general API.
 def compute_autocorrelogram_from_spiketrain(spike_times, window_size, bin_size):
     """
     Computes the auto-correlogram from a given spike train.
@@ -145,6 +210,9 @@ def compute_autocorrelogram_from_spiketrain(spike_times, window_size, bin_size):
     return _compute_autocorr_numba(spike_times.astype(np.int64), window_size, bin_size)


+# TODO: in another PR, coerce this input into `correlogram_for_one_segment()`
+# to provide a numpy and numba version. Consider window_size and bin_size
+# being taken as ms to match the general API.
 def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_size, bin_size):
     """
     Computes the cross-correlogram between two given spike trains.
@@ -175,7 +243,18 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_

 def compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
     """
-    Computes several cross-correlogram in one course from several clusters.
+    Entry function to compute correlograms across all units in a `Sorting`
+    object (i.e. the correlogram is computed for each unit against every
+    other unit, at all lags within the window).
+
+    Returns
+    -------
+    correlograms : np.array
+        A (num_units, num_units, num_bins) array in which unit x unit
+        correlation matrices are stacked at all determined time bins. Note
+        the true correlation is not returned; instead, counts of matches are.
+    bins : np.array
+        The bin edges in ms
     """
     assert method in ("auto", "numba", "numpy")
@@ -228,11 +307,48 @@ def compute_correlograms_numpy(sorting, window_size, bin_size):

     return correlograms


+# TODO: it would be better for this function to take num_half_bins
+# and num_bins directly. However, this will break misc_metrics.slidingRP_violations().
+# Check tests for slidingRP_violations(), write one covering that functionality
+# if required, then make the refactoring.
+# TODO: also make clear that the output is always counts, not correlation / covariance matrices
 def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size):
     """
-    Called by compute_correlograms_numpy
-    """
+    A very well optimized algorithm for the cross-correlation of
+    spike trains, copied from the phy package, written by Cyrille Rossant.
+
+    This method does not perform a cross-correlation in the typical
+    way (sliding and computing correlations, or via Fourier transform).
+    Instead, the time difference between each spike and every other spike
+    within the window is directly computed and stored as a count in the
+    relevant bin.
+
+    Initially, the spike_times array is shifted by 1 position, and the difference
+    computed. This gives the time differences between the closest spikes
+    (skipping the zero-lag case). Next, the differences between
+    spike times in samples are converted into units relative to
+    bin_size ('binarized'). Spikes whose binarized difference to
+    their closest neighbouring spike falls outside the half window are
+    masked and not compared in later iterations. Finally, the indices of the
+    (num_units, num_units, num_bins) correlogram in which there is
+    a match are found and incremented accordingly. This repeats
+    for all shifts along the spike train until no spikes have a corresponding
+    match within the window size.
+
+    # TODO: is every combination really checked in this shifting procedure?
+    Parameters
+    ----------
+    spike_times : np.ndarray
+        An array of spike times (in samples, not seconds). This contains
+        spikes from all units.
+    spike_labels : np.ndarray
+        An array of labels indicating the unit of the corresponding spike in
+        `spike_times`.
+    window_size : int
+        The window size over which to perform the cross-correlation, in samples
+    bin_size : int
+        The size of the bins into which to group lags, in samples.
+        TODO: come up with some standard terminology and way of describing
+        this from within the module.
+    """
     num_half_bins = int(window_size // bin_size)
     num_bins = int(2 * num_half_bins)
     num_units = len(np.unique(spike_labels))
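
To make the shift-and-mask procedure described in the new `correlogram_for_one_segment` docstring concrete, here is a minimal toy sketch for a single unit. The name `autocorrelogram_counts_sketch` and the example data are illustrative only, not part of this diff, and this is not the phy/spikeinterface implementation; it just demonstrates the idea of binarizing lagged spike-time differences and masking spikes once their shift-th neighbour leaves the half window.

    import numpy as np

    def autocorrelogram_counts_sketch(spike_times, window_size, bin_size):
        # Count spike-time differences per lag bin for a single unit.
        # Bins cover lags in [-window_size, window_size), each of width bin_size.
        num_half_bins = window_size // bin_size
        counts = np.zeros(2 * num_half_bins, dtype=np.int64)
        mask = np.ones(len(spike_times), dtype=bool)
        shift = 1
        while shift < len(spike_times) and mask[:-shift].any():
            # Difference between each spike and its shift-th neighbour
            # (spike_times is assumed sorted, so differences are >= 0).
            diff = spike_times[shift:] - spike_times[:-shift]
            diff_b = diff // bin_size  # 'binarized' differences
            # Mask spikes whose shift-th neighbour already falls outside
            # the half window; later shifts can only be further away.
            mask[:-shift][diff_b >= num_half_bins] = False
            kept = diff_b[mask[:-shift]]
            # Increment the positive-lag bins; the autocorrelogram is
            # symmetric, so mirror the counts onto the negative lags.
            np.add.at(counts, num_half_bins + kept, 1)
            np.add.at(counts, num_half_bins - 1 - kept, 1)
            shift += 1
        return counts

    spikes = np.array([10, 12, 19, 30, 65])  # spike times in samples
    print(autocorrelogram_counts_sketch(spikes, window_size=30, bin_size=10))
    # -> [1 2 3 3 2 1]

Note how each shift compares every spike with its shift-th neighbour in one vectorized step, which is what makes the phy approach fast compared with a naive pairwise loop.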
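The window/bin convention documented above (window_ms is the full width, so lags span plus/minus window_ms / 2; the window is clipped to a multiple of bin_ms; bin edges come back in ms while the sizes stay in samples) can also be checked with a small sketch. `make_bins_sketch` and the 30 kHz sampling frequency are assumptions for illustration, not this module's code, though the half-window conversion mirrors the `window_size = int(round(fs * window_ms / 2 * 1e-3))` line visible in the diff.

    import numpy as np

    def make_bins_sketch(fs, window_ms=50.0, bin_ms=1.0):
        # Half window and bin width, converted from ms to samples.
        window_size = int(round(fs * window_ms / 2 * 1e-3))
        bin_size = int(round(fs * bin_ms * 1e-3))
        # Clip the window so it is divisible by the bin size.
        window_size -= window_size % bin_size
        # Bin edges converted back to ms; the two sizes stay in samples.
        bins = np.arange(-window_size, window_size + bin_size, bin_size) * 1e3 / fs
        return bins, window_size, bin_size

    bins, window_size, bin_size = make_bins_sketch(fs=30_000.0)
    # window_ms=50 -> lags span -25 ms ... +25 ms, matching the parameter docs
    assert bins[0] == -25.0 and bins[-1] == 25.0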
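For reviewers who want to exercise this code path end-to-end, a hedged usage sketch via the SortingAnalyzer extension machinery follows. The function names (`generate_ground_truth_recording`, `create_sorting_analyzer`, `compute`, `get_data`) are taken from the current spikeinterface docs and should be verified against the version this PR targets.

    import spikeinterface.full as si

    # Generate a small synthetic recording/sorting pair to test with.
    recording, sorting = si.generate_ground_truth_recording(durations=[10.0])
    sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)

    # Compute the extension, then pull out (counts, bin edges in ms).
    ext = sorting_analyzer.compute("correlograms", window_ms=50.0, bin_ms=1.0)
    correlograms, bins = ext.get_data()
    # correlograms has shape (num_units, num_units, num_bins); entries are counts.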