Fix compute matching v3 (#2182)
* some change to test

* another change

* another attempt

* attempt merge

* add condition

* add auth

* fix test and simpler implementation

* small typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* avoid corner case of doing the matching loop twice

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove n_jobs

* Little docs cleanup

* Remove internal n_jobs

* Remove last internal n_jobs

* Apply suggestions from code review

* fix test

* comment to test

* docstring improvements

* variable naming

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* new proposal for compute_matching

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Heberto Mayorquin <[email protected]>
Co-authored-by: Alessio Buccino <[email protected]>
3 people authored Nov 9, 2023
1 parent fa0b034 commit 3291976
Showing 3 changed files with 111 additions and 89 deletions.
130 changes: 65 additions & 65 deletions src/spikeinterface/comparison/comparisontools.py
@@ -124,103 +124,105 @@ def get_optimized_compute_matching_matrix():

@numba.jit(nopython=True, nogil=True)
def compute_matching_matrix(
frames_spike_train1,
frames_spike_train2,
spike_frames_train1,
spike_frames_train2,
unit_indices1,
unit_indices2,
num_units_sorting1,
num_units_sorting2,
num_units_train1,
num_units_train2,
delta_frames,
):
"""
Compute a matrix representing the matches between two spike trains.
Given two spike trains, this function finds matching spikes based on a temporal proximity criterion
defined by `delta_frames`. The resulting matrix indicates the number of matches between units
in `frames_spike_train1` and `frames_spike_train2`.
in `spike_frames_train1` and `spike_frames_train2`.
Parameters
----------
frames_spike_train1 : ndarray
Array of frames for the first spike train. Should be ordered in ascending order.
frames_spike_train2 : ndarray
Array of frames for the second spike train. Should be ordered in ascending order.
spike_frames_train1 : ndarray
An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order.
spike_frames_train2 : ndarray
An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order.
unit_indices1 : ndarray
Array indicating the unit indices corresponding to each spike in `frames_spike_train1`.
An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`.
unit_indices2 : ndarray
Array indicating the unit indices corresponding to each spike in `frames_spike_train2`.
num_units_sorting1 : int
Total number of units in the first spike train.
num_units_sorting2 : int
Total number of units in the second spike train.
An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`.
num_units_train1 : int
The total count of unique units in the first spike train.
num_units_train2 : int
The total count of unique units in the second spike train.
delta_frames : int
Maximum difference in frames between two spikes to consider them as a match.
The inclusive upper limit on the frame difference for which two spikes are considered matching. That is
if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]`
and `spike_frames_train2[j]` are considered matching.
Returns
-------
matching_matrix : ndarray
A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents
the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`.
A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents
the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`.
Notes
-----
This algorithm identifies matching spikes between two ordered spike trains.
By iterating through each spike in the first train, it compares them against spikes in the second train,
determining matches based on the two spikes' frames being within `delta_frames` of each other.
To avoid redundant comparisons the algorithm maintains a reference, `lower_search_limit_in_second_train`,
To avoid redundant comparisons, the algorithm maintains a reference, `second_train_search_start`,
which signifies the minimal index in the second spike train that might match the upcoming spike
in the first train. This means that the start of the search moves forward in the second train as the
matches between the two trains are found decreasing the number of comparisons needed.
in the first train.
The logic can be summarized as follows:
1. Iterate through each spike in the first train
2. For each spike, find the first match in the second train.
3. Save the index of the first match as the new `second_train_search_start`
4. For each match, find as many matches as possible from the first match onwards.
An important condition here is thatthe same spike is not matched twice. This is managed by keeping track
of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match`
An important condition here is that the same spike is not matched twice. This is managed by keeping track
of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2`
For more details on the rationale behind this approach, refer to the documentation of this module and/or
the metrics section in SpikeForest documentation.
the metrics section in SpikeForest documentation.
"""

matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)
matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16)

# Used to avoid the same spike matching twice
previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64)

lower_search_limit_in_second_train = 0

for index1 in range(len(frames_spike_train1)):
# Keeps track of which frame in the second spike train should be used as a search start for matches
index2 = lower_search_limit_in_second_train
frame1 = frames_spike_train1[index1]

# Determine next_frame1 if current frame is not the last frame
not_in_the_last_loop = index1 < len(frames_spike_train1) - 1
if not_in_the_last_loop:
next_frame1 = frames_spike_train1[index1 + 1]

while index2 < len(frames_spike_train2):
frame2 = frames_spike_train2[index2]
not_a_match = abs(frame1 - frame2) > delta_frames
if not_a_match:
# Go to the next frame in the first train
last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64)
last_match_frame2 = -np.ones_like(matching_matrix, dtype=np.int64)

num_spike_frames_train1 = len(spike_frames_train1)
num_spike_frames_train2 = len(spike_frames_train2)

# Keeps track of which frame in the second spike train should be used as a search start for matches
second_train_search_start = 0
for index1 in range(num_spike_frames_train1):
frame1 = spike_frames_train1[index1]

for index2 in range(second_train_search_start, num_spike_frames_train2):
frame2 = spike_frames_train2[index2]
if frame2 < frame1 - delta_frames:
# no match, move the left limit forward for the next loop
second_train_search_start += 1
continue
elif frame2 > frame1 + delta_frames:
# no match, stop searching in train2 and move to the next spike in train1
break
else:
# match
unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2]

# Map the match to a matrix
row, column = unit_indices1[index1], unit_indices2[index2]

# The same spike cannot be matched twice see the notes in the docstring for more info on this constraint
if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
previous_frame1_match[row, column] = frame1
previous_frame2_match[row, column] = frame2

matching_matrix[row, column] += 1

index2 += 1
if (
frame1 != last_match_frame1[unit_index1, unit_index2]
and frame2 != last_match_frame2[unit_index1, unit_index2]
):
last_match_frame1[unit_index1, unit_index2] = frame1
last_match_frame2[unit_index1, unit_index2] = frame2

# Advance the lower_search_limit_in_second_train if the next frame in the first train does not match
not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
if not_a_match_with_next:
lower_search_limit_in_second_train = index2
matching_matrix[unit_index1, unit_index2] += 1

return matching_matrix
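
To make the input layout described in the docstring concrete, a hypothetical toy call is sketched below (not part of this diff; it assumes numba is installed and that `get_optimized_compute_matching_matrix` is imported from `spikeinterface.comparison.comparisontools` as shown in the hunk header above):

import numpy as np
from spikeinterface.comparison.comparisontools import get_optimized_compute_matching_matrix

compute_matching_matrix = get_optimized_compute_matching_matrix()

spike_frames_train1 = np.array([100, 150, 250])  # ascending frames, all units interleaved
unit_indices1 = np.array([0, 1, 0])              # unit of each spike above
spike_frames_train2 = np.array([102, 260])
unit_indices2 = np.array([0, 0])

matching = compute_matching_matrix(
    spike_frames_train1,
    spike_frames_train2,
    unit_indices1,
    unit_indices2,
    2,   # num_units_train1
    1,   # num_units_train2
    10,  # delta_frames
)
# matching -> [[2], [0]]: unit 0 of train1 matches unit 0 of train2 at 100~102 and 250~260,
# while unit 1 of train1 (the spike at 150) has no train2 spike within 10 frames.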

@@ -230,7 +232,7 @@ def compute_matching_matrix(
return compute_matching_matrix


def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
def make_match_count_matrix(sorting1, sorting2, delta_frames):
num_units_sorting1 = sorting1.get_num_units()
num_units_sorting2 = sorting2.get_num_units()
matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)
@@ -275,7 +277,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
return match_event_counts_df


def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
def make_agreement_scores(sorting1, sorting2, delta_frames):
"""
Make the agreement matrix.
No threshold (min_score) is applied at this step.
@@ -291,8 +293,6 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
The second sorting extractor
delta_frames: int
Number of frames to consider spikes coincident
n_jobs: int
Number of jobs to run in parallel
Returns
-------
@@ -309,7 +309,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
event_counts1 = pd.Series(ev_counts1, index=unit1_ids)
event_counts2 = pd.Series(ev_counts2, index=unit2_ids)

match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=n_jobs)
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)

agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2)

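For reference, a minimal usage sketch of the public helpers after this change; the `NumpySorting.from_times_labels` construction mirrors the `make_sorting` helper in the tests below and is an assumption rather than part of this diff. Note that `n_jobs` is no longer accepted by either function:

import numpy as np
from spikeinterface.core import NumpySorting
from spikeinterface.comparison.comparisontools import make_match_count_matrix, make_agreement_scores

sampling_frequency = 30000.0
# two toy single-segment sortings: spike frames and the unit label of each spike
sorting1 = NumpySorting.from_times_labels([np.array([100, 200, 300])], [np.array([0, 0, 1])], sampling_frequency)
sorting2 = NumpySorting.from_times_labels([np.array([101, 201, 301])], [np.array([0, 0, 1])], sampling_frequency)

match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames=10)  # pandas DataFrame of counts
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames=10)     # pandas DataFrame of scores
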
4 changes: 1 addition & 3 deletions src/spikeinterface/comparison/paircomparisons.py
@@ -84,9 +84,7 @@ def _do_agreement(self):
self.event_counts2 = do_count_event(self.sorting2)

# matrix of event match count for each pair
self.match_event_count = make_match_count_matrix(
self.sorting1, self.sorting2, self.delta_frames, n_jobs=self.n_jobs
)
self.match_event_count = make_match_count_matrix(self.sorting1, self.sorting2, self.delta_frames)

# agreement matrix score for each pair
self.agreement_scores = make_agreement_scores_from_count(
66 changes: 45 additions & 21 deletions src/spikeinterface/comparison/tests/test_comparisontools.py
@@ -135,6 +135,23 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting():
assert_array_equal(result.to_numpy(), expected_result)


def test_make_match_count_matrix_test_proper_search_in_the_second_train():
"Search exhaustively in the second train, but only within the delta_frames window, do not terminate search early"
frames_spike_train1 = [500, 600, 800]
frames_spike_train2 = [0, 100, 200, 300, 500, 800]
unit_indices1 = [0, 0, 0]
unit_indices2 = [0, 0, 0, 0, 0, 0]
delta_frames = 20

sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)

result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)

expected_result = np.array([[2]])

assert_array_equal(result.to_numpy(), expected_result)
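
A quick cross-check of the expected count in this new test: within delta_frames=20 only the pairs 500<->500 and 800<->800 match, and the early spikes of the second train must be stepped over without ending the search. A hypothetical brute-force count (it equals the matrix entry here only because no spike can be matched twice in this example):

import numpy as np

t1 = np.array([500, 600, 800])
t2 = np.array([0, 100, 200, 300, 500, 800])
pairs_within_delta = np.abs(t1[:, None] - t2[None, :]) <= 20
assert int(pairs_within_delta.sum()) == 2  # 500<->500 and 800<->800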


def test_make_agreement_scores():
delta_frames = 10

@@ -150,15 +167,15 @@ def test_make_agreement_scores():
[0, 0, 5],
)

agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
print(agreement_scores)

ok = np.array([[2 / 3, 0], [0, 1.0]], dtype="float64")

assert_array_equal(agreement_scores.values, ok)

# test if symetric
agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames, n_jobs=1)
agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames)
assert_array_equal(agreement_scores, agreement_scores2.T)


@@ -178,7 +195,7 @@ def test_make_possible_match():
[0, 0, 5],
)

agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)

possible_match_12, possible_match_21 = make_possible_match(agreement_scores, min_accuracy)

@@ -207,7 +224,7 @@ def test_make_best_match():
[0, 0, 5],
)

agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)

best_match_12, best_match_21 = make_best_match(agreement_scores, min_accuracy)

@@ -236,7 +253,7 @@ def test_make_hungarian_match():
[0, 0, 5],
)

agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)

hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy)

@@ -344,8 +361,8 @@ def test_do_confusion_matrix():

event_counts1 = do_count_event(sorting1)
event_counts2 = do_count_event(sorting2)
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy)

confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count)
@@ -363,8 +380,8 @@

event_counts1 = do_count_event(sorting1)
event_counts2 = do_count_event(sorting2)
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy)

confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count)
@@ -391,8 +408,8 @@ def test_do_count_score_and_perf():

event_counts1 = do_count_event(sorting1)
event_counts2 = do_count_event(sorting2)
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1)
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy)

count_score = do_count_score(event_counts1, event_counts2, hungarian_match_12, match_event_count)
@@ -415,13 +432,20 @@

if __name__ == "__main__":
test_make_match_count_matrix()
test_make_agreement_scores()

test_make_possible_match()
test_make_best_match()
test_make_hungarian_match()

test_do_score_labels()
test_compare_spike_trains()
test_do_confusion_matrix()
test_do_count_score_and_perf()
test_make_match_count_matrix_sorting_with_itself_simple()
test_make_match_count_matrix_sorting_with_itself_longer()
test_make_match_count_matrix_with_mismatched_sortings()
test_make_match_count_matrix_no_double_matching()
test_make_match_count_matrix_repeated_matching_but_no_double_counting()
test_make_match_count_matrix_test_proper_search_in_the_second_train()

# test_make_agreement_scores()

# test_make_possible_match()
# test_make_best_match()
# test_make_hungarian_match()

# test_do_score_labels()
# test_compare_spike_trains()
# test_do_confusion_matrix()
# test_do_count_score_and_perf()
