Merge branch 'main' into import-docs
zm711 authored Nov 17, 2023
2 parents 7595d9c + efc042e commit f7cc7b8
Showing 40 changed files with 815 additions and 175 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/full-test-with-codecov.yml
@@ -52,6 +52,8 @@ jobs:
- name: Shows installed packages by pip, git-annex and cached testing files
uses: ./.github/actions/show-test-environment
- name: run tests
+ env:
+ HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell
run: |
source ${{ github.workspace }}/test_env/bin/activate
pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1
2 changes: 2 additions & 0 deletions .github/workflows/full-test.yml
@@ -132,6 +132,8 @@ jobs:
- name: Test core
run: ./.github/run_tests.sh core
- name: Test extractors
+ env:
+ HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell
if: ${{ steps.modules-changed.outputs.EXTRACTORS_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }}
run: ./.github/run_tests.sh "extractors and not streaming_extractors"
- name: Test preprocessing
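Both workflow changes above export the same variable before the test step runs: Maxwell HDF5 recordings are written with a third-party HDF5 compression plugin, and libhdf5 only discovers such plugins through `HDF5_PLUGIN_PATH`. A hedged sketch of the same setup outside CI (the paths and file names below are hypothetical):

```python
import os

# HDF5 discovers third-party compression filters through this variable;
# it must point at the plugin directory before any HDF5 file is opened.
os.environ["HDF5_PLUGIN_PATH"] = "/path/to/hdf5_plugin_path_maxwell"  # hypothetical path

import h5py  # imported after the variable is set, mirroring the workflow ordering

# Hypothetical Maxwell recording; without the plugin, reading the compressed
# datasets would fail with a missing-filter error.
with h5py.File("maxwell_recording.raw.h5", "r") as f:
    print(list(f.keys()))
```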
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -6,7 +6,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
- rev: 23.10.1
+ rev: 23.11.0
hooks:
- id: black
files: ^src/
2 changes: 1 addition & 1 deletion README.md
@@ -67,7 +67,7 @@ With SpikeInterface, users can:

## Documentation

- Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.99.0).
+ Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.99.1).

Detailed documentation of the development version of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/latest).

2 changes: 2 additions & 0 deletions doc/modules/preprocessing.rst
@@ -74,6 +74,8 @@ dtype (unless specified otherwise):
Some scaling pre-processors, such as :code:`whiten()` or :code:`zscore()`, will force the output to :code:`float32`.

+ When converting from a :code:`float` to an :code:`int`, the value will first be rounded to the nearest integer.
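A minimal illustration of the documented rule, assuming numpy-style rounding (round half to even) before the cast:

```python
import numpy as np

traces = np.array([1.2, 3.7, -0.6], dtype="float32")

# Round first, then cast -- matching the documented behavior.
as_int = np.round(traces).astype("int16")
print(as_int)  # [ 1  4 -1]

# A bare cast would instead truncate toward zero:
print(traces.astype("int16"))  # [ 1  3  0]
```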


Available preprocessing
-----------------------
13 changes: 13 additions & 0 deletions doc/releases/0.99.1.rst
@@ -0,0 +1,13 @@
.. _release0.99.1:

SpikeInterface 0.99.1 release notes
-----------------------------------

14th November 2023

Minor release with some bug fixes.

* Fix crash when default start / end frame arguments on motion interpolation are used (#2176)
* Fix bug in `make_match_count_matrix()` when computing matching events (#2182, #2191, #2196)
* Fix maxwell tests by setting HDF5_PLUGIN_PATH env in action (#2161)
* Add read_npz_sorting to extractors module (#2183)
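For the last item, a hedged usage sketch of the newly exposed reader (the path is hypothetical, and the file is assumed to have been written previously in SpikeInterface's .npz sorting format):

```python
from spikeinterface.extractors import read_npz_sorting

# Hypothetical path to a sorting saved in SpikeInterface's .npz format.
sorting = read_npz_sorting("my_sorting.npz")
print(sorting.unit_ids, sorting.count_num_spikes_per_unit())
```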
7 changes: 7 additions & 0 deletions doc/whatisnew.rst
@@ -8,6 +8,7 @@ Release notes
.. toctree::
:maxdepth: 1

+ releases/0.99.1.rst
releases/0.99.0.rst
releases/0.98.2.rst
releases/0.98.1.rst
@@ -32,6 +33,12 @@ Release notes
releases/0.9.1.rst


+ Version 0.99.1
+ ==============

+ * Minor release with some bug fixes


Version 0.99.0
==============

143 changes: 96 additions & 47 deletions src/spikeinterface/comparison/comparisontools.py
@@ -63,6 +63,9 @@ def compute_agreement_score(num_matches, num1, num2):
def do_count_event(sorting):
"""
Count events for each unit in a sorting.
+ Kept for backward compatibility: sorting.count_num_spikes_per_unit() does the same.
Parameters
----------
sorting: SortingExtractor
@@ -75,14 +78,7 @@ def do_count_event(sorting):
"""
import pandas as pd

- unit_ids = sorting.get_unit_ids()
- ev_counts = np.zeros(len(unit_ids), dtype="int64")
- for segment_index in range(sorting.get_num_segments()):
- ev_counts += np.array(
- [len(sorting.get_unit_spike_train(u, segment_index=segment_index)) for u in unit_ids], dtype="int64"
- )
- event_counts = pd.Series(ev_counts, index=unit_ids)
- return event_counts
+ return pd.Series(sorting.count_num_spikes_per_unit())


def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids,
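Above, `do_count_event()` was reduced to a thin wrapper around the Sorting API. A hedged sketch of what the new one-liner returns, on toy data (assumes spikeinterface's `NumpySorting`; the real helper accepts any SortingExtractor):

```python
import numpy as np
import pandas as pd
from spikeinterface.core import NumpySorting

# Toy sorting: three spikes for unit 0, one spike for unit 1.
sorting = NumpySorting.from_times_labels(
    np.array([100, 150, 200, 300]), np.array([0, 1, 0, 0]), sampling_frequency=30_000.0
)

# What do_count_event() now returns: a pandas Series of spike counts per unit,
# summed over all segments (a single segment here).
event_counts = pd.Series(sorting.count_num_spikes_per_unit())
print(event_counts)  # unit 0 -> 3, unit 1 -> 1
```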
@@ -133,11 +129,9 @@ def compute_matching_matrix(
delta_frames,
):
"""
- Compute a matrix representing the matches between two spike trains.
- Given two spike trains, this function finds matching spikes based on a temporal proximity criterion
- defined by `delta_frames`. The resulting matrix indicates the number of matches between units
- in `spike_frames_train1` and `spike_frames_train2`.
+ Internal function used by `make_match_count_matrix()`.
+ This function is for one segment only.
+ The loop over segments is done in `make_match_count_matrix()`.
Parameters
----------
@@ -164,31 +158,9 @@
A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents
the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`.
- Notes
- -----
- This algorithm identifies matching spikes between two ordered spike trains.
- By iterating through each spike in the first train, it compares them against spikes in the second train,
- determining matches based on the two spikes frames being within `delta_frames` of each other.
- To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `,
- which signifies the minimal index in the second spike train that might match the upcoming spike
- in the first train.
- The logic can be summarized as follows:
- 1. Iterate through each spike in the first train
- 2. For each spike, find the first match in the second train.
- 3. Save the index of the first match as the new `second_train_search_start `
- 3. For each match, find as many matches as possible from the first match onwards.
- An important condition here is that the same spike is not matched twice. This is managed by keeping track
- of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2`
- For more details on the rationale behind this approach, refer to the documentation of this module and/or
- the metrics section in SpikeForest documentation.
"""

- matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16)
+ matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint64)

# Used to avoid the same spike matching twice
last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64)
@@ -216,11 +188,11 @@
unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2]

if (
- frame1 != last_match_frame1[unit_index1, unit_index2]
- and frame2 != last_match_frame2[unit_index1, unit_index2]
+ index1 != last_match_frame1[unit_index1, unit_index2]
+ and index2 != last_match_frame2[unit_index1, unit_index2]
):
- last_match_frame1[unit_index1, unit_index2] = frame1
- last_match_frame2[unit_index1, unit_index2] = frame2
+ last_match_frame1[unit_index1, unit_index2] = index1
+ last_match_frame2[unit_index1, unit_index2] = index2

matching_matrix[unit_index1, unit_index2] += 1

@@ -232,10 +204,65 @@
return compute_matching_matrix


- def make_match_count_matrix(sorting1, sorting2, delta_frames):
+ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=False):
"""
Computes a matrix representing the matches between two Sorting objects.
Given two spike trains, this function finds matching spikes based on a temporal proximity criterion
defined by `delta_frames`. The resulting matrix indicates the number of matches between units
in `spike_frames_train1` and `spike_frames_train2` for each pair of units.
Note that this algo is not symmetric and is biased with `sorting1` representing ground truth for the comparison
Parameters
----------
sorting1 : Sorting
An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order.
sorting2 : Sorting
An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order.
delta_frames : int
The inclusive upper limit on the frame difference for which two spikes are considered matching. That is
if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at
`spike_frames_train1[i]` and `spike_frames_train2[j]` are considered matching.
ensure_symmetry: bool, default False
If ensure_symmetry=True, then the algo is run two times by switching sorting1 and sorting2.
And the minimum of the two results is taken.
Returns
-------
matching_matrix : ndarray
A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents
the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`.
Notes
-----
This algorithm identifies matching spikes between two ordered spike trains.
By iterating through each spike in the first train, it compares them against spikes in the second train,
determining matches based on the two spikes frames being within `delta_frames` of each other.
To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `,
which signifies the minimal index in the second spike train that might match the upcoming spike
in the first train.
The logic can be summarized as follows:
1. Iterate through each spike in the first train
2. For each spike, find the first match in the second train.
3. Save the index of the first match as the new `second_train_search_start `
3. For each match, find as many matches as possible from the first match onwards.
An important condition here is that the same spike is not matched twice. This is managed by keeping track
of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2`
There are corner cases where a spike can be counted twice in the spiketrain 2 if there are bouts of bursting activity
(below delta_frames) in the spiketrain 1. To ensure that the number of matches does not exceed the number of spikes,
we apply a final clip.
For more details on the rationale behind this approach, refer to the documentation of this module and/or
the metrics section in SpikeForest documentation.
"""

num_units_sorting1 = sorting1.get_num_units()
num_units_sorting2 = sorting2.get_num_units()
- matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)
+ matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint64)

spike_vector1_segments = sorting1.to_spike_vector(concatenated=False)
spike_vector2_segments = sorting2.to_spike_vector(concatenated=False)
@@ -257,7 +284,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames):
unit_indices1_sorted = spike_vector1["unit_index"]
unit_indices2_sorted = spike_vector2["unit_index"]

- matching_matrix += get_optimized_compute_matching_matrix()(
+ matching_matrix_seg = get_optimized_compute_matching_matrix()(
sample_frames1_sorted,
sample_frames2_sorted,
unit_indices1_sorted,
@@ -267,6 +294,26 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames):
delta_frames,
)

+ if ensure_symmetry:
+ matching_matrix_seg_switch = get_optimized_compute_matching_matrix()(
+ sample_frames2_sorted,
+ sample_frames1_sorted,
+ unit_indices2_sorted,
+ unit_indices1_sorted,
+ num_units_sorting2,
+ num_units_sorting1,
+ delta_frames,
+ )
+ matching_matrix_seg = np.maximum(matching_matrix_seg, matching_matrix_seg_switch.T)

+ matching_matrix += matching_matrix_seg

+ # ensure the number of matches does not exceed the number of spikes in train 2;
+ # this is a simple way to handle corner cases for bursting in sorting1
+ spike_count2 = np.array(list(sorting2.count_num_spikes_per_unit().values()))
+ spike_count2 = spike_count2[np.newaxis, :]
+ matching_matrix = np.clip(matching_matrix, None, spike_count2)

# Build a data frame from the matching matrix
import pandas as pd

@@ -277,12 +324,12 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames):
return match_event_counts_df


- def make_agreement_scores(sorting1, sorting2, delta_frames):
+ def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True):
"""
Make the agreement matrix.
No threshold (min_score) is applied at this step.
- Note : this computation is symmetric.
+ Note : this computation is symmetric by default.
Inverting sorting1 and sorting2 gives the transposed matrix.
Parameters
@@ -293,7 +340,9 @@ def make_agreement_scores(sorting1, sorting2, delta_frames):
The second sorting extractor
delta_frames: int
Number of frames to consider spikes coincident
+ ensure_symmetry : bool, default: True
+ If True, the algorithm is run twice, switching sorting1 and sorting2,
+ and the element-wise maximum of the two results is taken.
Returns
-------
agreement_scores: array (float)
Expand All @@ -309,7 +358,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames):
event_counts1 = pd.Series(ev_counts1, index=unit1_ids)
event_counts2 = pd.Series(ev_counts2, index=unit2_ids)

- match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
+ match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=ensure_symmetry)

agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2)

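The Notes added to `make_match_count_matrix()` describe a forward-only scan over two sorted spike trains. Below is a heavily simplified, hedged sketch of that idea for a single pair of units (pure Python, greedy one-to-one matching; the real numba kernel instead tracks `last_match_frame1`/`last_match_frame2` per unit pair, and the caller clips the counts by the spike counts of `sorting2`):

```python
import numpy as np

def count_matches(frames1, frames2, delta_frames):
    """Greedy count of pairs with |frames1[i] - frames2[j]| <= delta_frames,
    consuming each spike of train 2 at most once. Both inputs must be sorted."""
    matches = 0
    search_start = 0  # plays the role of `second_train_search_start`
    for frame1 in frames1:
        # Skip spikes of train 2 too early to match this or any later spike.
        while search_start < len(frames2) and frames2[search_start] < frame1 - delta_frames:
            search_start += 1
        # Match the first in-window spike, then consume it so it cannot match twice.
        if search_start < len(frames2) and frames2[search_start] <= frame1 + delta_frames:
            matches += 1
            search_start += 1
    return matches

f1 = np.array([10, 20, 30, 70])
f2 = np.array([12, 31, 500])
print(count_matches(f1, f2, delta_frames=5))  # -> 2  (10~12 and 30~31)
```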
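At the function level, a hedged usage sketch of the new `ensure_symmetry` flag on toy sortings (assumes `NumpySorting` and the internal module path below; not an official example):

```python
import numpy as np
from spikeinterface.core import NumpySorting
from spikeinterface.comparison.comparisontools import make_match_count_matrix

fs = 30_000.0
sorting1 = NumpySorting.from_times_labels(np.array([100, 200, 300]), np.array([0, 0, 0]), fs)
sorting2 = NumpySorting.from_times_labels(np.array([102, 198, 305]), np.array([0, 0, 0]), fs)

# Historical, asymmetric counting (sorting1 acts as the ground truth):
counts = make_match_count_matrix(sorting1, sorting2, delta_frames=10)

# Symmetric counting: the kernel is run in both directions and reconciled.
counts_sym = make_match_count_matrix(sorting1, sorting2, delta_frames=10, ensure_symmetry=True)
print(counts_sym)  # pandas DataFrame: rows = sorting1 units, columns = sorting2 units
```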
9 changes: 8 additions & 1 deletion src/spikeinterface/comparison/paircomparisons.py
@@ -28,6 +28,7 @@ def __init__(
delta_time=0.4,
match_score=0.5,
chance_score=0.1,
+ ensure_symmetry=False,
n_jobs=1,
verbose=False,
):
@@ -55,6 +56,8 @@ def __init__(
self.unit1_ids = self.sorting1.get_unit_ids()
self.unit2_ids = self.sorting2.get_unit_ids()

+ self.ensure_symmetry = ensure_symmetry

self._do_agreement()
self._do_matching()

@@ -84,7 +87,9 @@ def _do_agreement(self):
self.event_counts2 = do_count_event(self.sorting2)

# matrix of event match count for each pair
- self.match_event_count = make_match_count_matrix(self.sorting1, self.sorting2, self.delta_frames)
+ self.match_event_count = make_match_count_matrix(
+ self.sorting1, self.sorting2, self.delta_frames, ensure_symmetry=self.ensure_symmetry
+ )

# agreement matrix score for each pair
self.agreement_scores = make_agreement_scores_from_count(
@@ -151,6 +156,7 @@ def __init__(
delta_time=delta_time,
match_score=match_score,
chance_score=chance_score,
+ ensure_symmetry=True,
n_jobs=n_jobs,
verbose=verbose,
)
@@ -283,6 +289,7 @@ def __init__(
delta_time=delta_time,
match_score=match_score,
chance_score=chance_score,
+ ensure_symmetry=False,
n_jobs=n_jobs,
verbose=verbose,
)
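Finally, a hedged end-to-end sketch of how the flag flows through the public comparison API after this change (toy data; assumes spikeinterface >= 0.99.1):

```python
import numpy as np
from spikeinterface.core import NumpySorting
from spikeinterface.comparison import compare_two_sorters, compare_sorter_to_ground_truth

fs = 30_000.0
gt = NumpySorting.from_times_labels(np.array([100, 200, 300]), np.array([0, 0, 0]), fs)
tested = NumpySorting.from_times_labels(np.array([101, 199, 302]), np.array([0, 0, 0]), fs)

# Two-sorter comparison: symmetric agreement (ensure_symmetry=True internally).
cmp_pair = compare_two_sorters(gt, tested, delta_time=0.4)

# Ground-truth comparison: asymmetric, GT-biased counting (ensure_symmetry=False internally).
cmp_gt = compare_sorter_to_ground_truth(gt, tested, delta_time=0.4)

print(cmp_pair.agreement_scores)
print(cmp_gt.agreement_scores)
```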
