Extend correlogram docstrings #3011

Closed
132 changes: 124 additions & 8 deletions src/spikeinterface/postprocessing/correlograms.py
@@ -13,21 +13,49 @@
except ModuleNotFoundError as err:
HAVE_NUMBA = False

# TODO: here the default is 50 ms but in the docs it says 100 ms?
# _set_params, _select_extension_data, _run, _get_data I think are
# sorting analyzer things. Docstrings can be added here and propagated to
# all sorting analyer functions OR can be described in the class docstring.
# otherwise these are quite hard to understand where they are called in the
# code as not called internally on the class.


class ComputeCorrelograms(AnalyzerExtension):
"""
Compute auto and cross correlograms.

In the extracellular electrophysiology context, a correlogram
is a visualisation of the results of a cross-correlation
between two spike trains. The cross-correlation slides one spike train
along another sample-by-sample, taking the correlation at each 'lag'. This results
in a plot with 'lag' (i.e. time offset) on the x-axis and 'correlation'
(i.e. how similar the two spike trains are) on the y-axis. Often, the
values on the y-axis are given as the correlation, but they may also
be the covariance or simple counts.

Correlograms are often used to determine whether a unit has
ISI violations. In this context, a 'window' around spikes is first
specified. For example, if a window of 100 ms is taken, we will
take the correlation at lags from -100 ms to +100 ms around the spike peak.
In theory, we can have as many lags as we have samples. Often, this
visualisation is too high resolution and instead the lags are binned
(e.g. 0-5 ms, 5-10 ms, ..., 95-100 ms bins). When using counts as output,
binning the lags involves adding up all counts across a range of lags.
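The lag-binning described above can be sketched in a few lines of numpy. This is a toy illustration of the concept, not the extension's implementation; `binned_lag_counts` and its inputs are hypothetical names:

```python
import numpy as np

def binned_lag_counts(spikes_a, spikes_b, window, bin_size):
    """Count spike-time differences (b - a) falling into bins that
    span [-window, window]. Times are in integer sample units."""
    edges = np.arange(-window, window + bin_size, bin_size)
    # All pairwise lags between the two trains
    lags = (spikes_b[np.newaxis, :] - spikes_a[:, np.newaxis]).ravel()
    lags = lags[np.abs(lags) <= window]  # keep lags inside the window
    counts, _ = np.histogram(lags, bins=edges)
    return counts, edges

# Two toy spike trains (sample indices)
a = np.array([10, 50, 90])
b = np.array([12, 55, 91])
counts, edges = binned_lag_counts(a, b, window=25, bin_size=5)
```

With these toy trains, only the near-coincident pairs (lags 2, 5 and 1 samples) survive the window, landing in the first two positive-lag bins.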

Parameters
----------
sorting_analyzer: SortingAnalyzer
A SortingAnalyzer object
window_ms : float, default: 50.0
The window in ms
The window around the spike to compute the correlation in ms. For example,
if the window is 50 ms, the correlations will be computed at lags from -25 ms to +25 ms.
bin_ms : float, default: 1.0
The bin size in ms
The bin size in ms. This determines the bin size over which to
combine lags. For example, with a window size of -25 ms to 25 ms and a
bin size of 1 ms, the correlations will be binned into bins -25 ms to -24 ms, -24 ms to -23 ms, ...
method : "auto" | "numpy" | "numba", default: "auto"
If "auto" and numba is installed, numba is used, otherwise numpy is used
If "auto" and numba is installed, numba is used, otherwise numpy is used.

Returns
-------
@@ -40,7 +68,7 @@ class ComputeCorrelograms(AnalyzerExtension):
The bin edges in ms

Returns
-------
-------
isi_histograms : np.array
2D array with ISI histograms (num_units, num_bins)
bins : np.array
@@ -83,13 +111,18 @@ def _get_data(self):
compute_correlograms_sorting_analyzer = ComputeCorrelograms.function_factory()


# TODO: ask when is this function used vs. compute_correlograms_on_sorting()?
def compute_correlograms(
sorting_analyzer_or_sorting,
window_ms: float = 50.0,
bin_ms: float = 1.0,
method: str = "auto",
):

"""
Convenience entry function to handle computation of
correlograms based on the method used. See ComputeCorrelograms()
for parameters.
"""
if isinstance(sorting_analyzer_or_sorting, MockWaveformExtractor):
sorting_analyzer_or_sorting = sorting_analyzer_or_sorting.sorting

@@ -107,6 +140,35 @@ def compute_correlograms(


def _make_bins(sorting, window_ms, bin_ms):
"""
Create the bins for the autocorrelogram, in samples.

The autocorrelogram bins are centered around zero but do not
include the results from zero lag. Each bin increases in
a positive / negative direction starting at zero.

For example, given a window_ms of 50 ms and a bin_ms of
5 ms, the bins in unit ms will be:
[-25 to -20, ..., -5 to 0, 0 to 5, ..., 20 to 25].

The window size will be clipped if not divisible by the bin size.
The bins are output in sample units, not seconds.

Parameters
----------
See ComputeCorrelograms() for parameters.

Returns
-------
bins : np.ndarray
The bin edges in ms
window_size : int
The window size in samples
bin_size : int
The bin size in samples

"""
fs = sorting.sampling_frequency

window_size = int(round(fs * window_ms / 2 * 1e-3))
@@ -120,6 +182,9 @@ def _make_bins(sorting, window_ms, bin_ms):
return bins, window_size, bin_size
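The ms-to-samples conversion and clipping this function performs can be sketched as follows. The half-window formula is taken from the line shown above; the bin-size conversion and clipping step are assumptions reconstructed from the docstring, and `make_bins_sketch` is a hypothetical name, not the module's implementation:

```python
import numpy as np

def make_bins_sketch(fs, window_ms, bin_ms):
    # Half-window and bin width converted from ms to samples
    window_size = int(round(fs * window_ms / 2 * 1e-3))
    bin_size = int(round(fs * bin_ms * 1e-3))
    # Clip the half-window so it is divisible by the bin size
    window_size -= window_size % bin_size
    # Bin edges expressed back in ms, centred on zero lag
    bins = np.arange(-window_size, window_size + bin_size, bin_size) * 1e3 / fs
    return bins, window_size, bin_size

# 30 kHz sampling, 50 ms window, 5 ms bins -> edges -25, -20, ..., 25 ms
bins, window_size, bin_size = make_bins_sketch(fs=30_000.0, window_ms=50.0, bin_ms=5.0)
```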


# TODO: in another PR, coerce this input into `correlogram_for_one_segment()`
# to provide a numpy and numba version. Consider window_size and bin_size
# being taken as ms to match general API.
def compute_autocorrelogram_from_spiketrain(spike_times, window_size, bin_size):
"""
Computes the auto-correlogram from a given spike train.
@@ -145,6 +210,9 @@ def compute_autocorrelogram_from_spiketrain(spike_times, window_size, bin_size):
return _compute_autocorr_numba(spike_times.astype(np.int64), window_size, bin_size)


# TODO: in another PR, coerce this input into `correlogram_for_one_segment()`
# to provide a numpy and numba version. Consider window_size and bin_size
# being taken as ms to match general API.
def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_size, bin_size):
"""
Computes the cross-correlogram between two given spike trains.
@@ -175,7 +243,18 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_

def compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto"):
"""
Computes several cross-correlogram in one course from several clusters.
Entry function to compute correlograms across all units in a `Sorting`
object (i.e. the correlation between the spike trains is computed at all
lags within the window, for each unit against every other unit).

Returns
-------
correlograms : np.array
A (num_units, num_units, num_bins) array in which the unit-by-unit
correlograms are stacked across all lag bins. Note the true correlation
is not returned; instead, the count of spike-time matches in each bin is.
bins : np.array
The bin edges in ms
"""
assert method in ("auto", "numba", "numpy")
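The (num_units, num_units, num_bins) layout described in the docstring can be illustrated with a brute-force toy computation. This is pure numpy over a handful of spikes, not the optimized algorithm the module uses:

```python
import numpy as np

# Toy data: spike times (in samples) and the unit label of each spike
spike_times = np.array([100, 110, 200, 205, 300])
spike_labels = np.array([0, 1, 0, 1, 0])
window_size, bin_size = 20, 10  # samples
num_half_bins = window_size // bin_size
num_bins = 2 * num_half_bins
num_units = 2

correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64)
for i, (t1, u1) in enumerate(zip(spike_times, spike_labels)):
    for j, (t2, u2) in enumerate(zip(spike_times, spike_labels)):
        if i == j:
            continue  # skip the zero-lag self-pair
        lag = t2 - t1
        if -window_size <= lag < window_size:
            # Bin index 0 corresponds to the most negative lag
            correlograms[u1, u2, (lag + window_size) // bin_size] += 1
```

Only the near-coincident pairs (lags of 5 and 10 samples, in both directions) fall inside the window here, so `correlograms[0, 1]` and `correlograms[1, 0]` each hold two counts while the autocorrelograms stay empty.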

@@ -228,11 +307,48 @@ def compute_correlograms_numpy(sorting, window_size, bin_size):
return correlograms


# TODO: it would be better for this function to take num_half_bins
# and num_bins directly. However, this will break misc_metrics.slidingRP_violations()
# Check tests for slidingRP_violations(), write one covering that functionality
# if required, then make the refactoring.
# TODO: also make clear the output are always counts, not correlation / covariance matrices
def correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size):
"""
Called by compute_correlograms_numpy
"""
A very well optimized algorithm for the cross-correlation of
spike trains, copied from the phy package written by Cyrille Rossant.

This method does not perform a cross-correlation in the typical
way (sliding and computing correlations, or via Fourier transform).
Instead, the time difference between every pair of spikes within the
window is directly computed and stored as a count in the relevant bin.

Initially, the spike_times array is shifted by 1 position, and the differences
computed. This gives the time differences between the closest spikes
(skipping the zero-lag case). Next, the differences between
spike times in samples are converted into units relative to
bin_size ('binarized'). Spikes whose binarized difference to
their closest neighbouring spike falls outside the correlogram window are
masked and excluded from further comparisons. Finally, the indices of the
(num_units, num_units, num_bins) correlogram in which there is
a match are found and incremented appropriately. This repeats
for all shifts along the spike train until no spikes have a corresponding
match within the window size.

# TODO: is every combination really checked in this shifting procedure?

Parameters
----------
spike_times : np.ndarray
An array of spike times (in samples, not seconds). This contains
spikes from all units.
spike_labels : np.ndarray
An array of labels indicating the unit of the corresponding spike in
`spike_times`.
window_size : int
The window size over which to perform the cross-correlation, in samples
bin_size : int
The size of the bins into which lags are grouped, in samples. TODO: come up with some standard terminology and way of describing this from within the module.
"""
num_half_bins = int(window_size // bin_size)
num_bins = int(2 * num_half_bins)
num_units = len(np.unique(spike_labels))
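The shifting procedure described in the docstring can be sketched for the simplified single-unit, positive-lag case. This assumes sorted spike times; `shifted_lag_counts` is a hypothetical illustration, not the phy implementation:

```python
import numpy as np

def shifted_lag_counts(spike_times, window_size, bin_size):
    # Compare each spike to its shift-th following neighbour, bin the
    # lag, and mask spikes whose lags have left the window.
    num_half_bins = window_size // bin_size
    counts = np.zeros(num_half_bins, dtype=np.int64)
    mask = np.ones(len(spike_times), dtype=bool)
    shift = 1
    while shift < len(spike_times) and mask[:-shift].any():
        lags = spike_times[shift:] - spike_times[:-shift]
        binarized = lags // bin_size
        # Once a spike's lag falls outside the window it stays outside
        # for all larger shifts (times are sorted), so mask it out.
        mask[:-shift] &= binarized < num_half_bins
        np.add.at(counts, binarized[mask[:-shift]], 1)
        shift += 1
    return counts

spike_times = np.array([0, 2, 3, 10])
counts = shifted_lag_counts(spike_times, window_size=8, bin_size=2)
```

Each pass handles one shift; a spike drops out of the comparison as soon as its lag to the shifted partner exceeds the window, so the loop terminates once every spike is masked or the shift exhausts the train.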