From 329197618a9b48ef876d1e8b8e79f07f4abf5e49 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Thu, 9 Nov 2023 10:33:00 +0100 Subject: [PATCH] Fix compute matching v3 (#2182) * some change to test * another change * another attempt * attempt merge * add condition * add auth * fix test and simpler implementation * small typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * avoid corner cose of doing the matching loop twice * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove n_jobs * Little docs cleanup * Remove internal n_jobs * Remove last internal n_jobs * Apply suggestions from code review * fix test * comment to test * docstring improvements * variable naming * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * new proposal for compute_matching * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Heberto Mayorquin Co-authored-by: Alessio Buccino --- .../comparison/comparisontools.py | 130 +++++++++--------- .../comparison/paircomparisons.py | 4 +- .../comparison/tests/test_comparisontools.py | 66 ++++++--- 3 files changed, 111 insertions(+), 89 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 7a1fb87175..3cd856d662 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -124,12 +124,12 @@ def get_optimized_compute_matching_matrix(): @numba.jit(nopython=True, nogil=True) def compute_matching_matrix( - frames_spike_train1, - frames_spike_train2, + spike_frames_train1, + spike_frames_train2, unit_indices1, unit_indices2, - num_units_sorting1, - num_units_sorting2, + num_units_train1, + num_units_train2, delta_frames, ): """ @@ -137,30 +137,33 @@ def compute_matching_matrix( Given two spike trains, this function finds matching spikes based on a temporal proximity criterion defined by `delta_frames`. The resulting matrix indicates the number of matches between units - in `frames_spike_train1` and `frames_spike_train2`. + in `spike_frames_train1` and `spike_frames_train2`. Parameters ---------- - frames_spike_train1 : ndarray - Array of frames for the first spike train. Should be ordered in ascending order. - frames_spike_train2 : ndarray - Array of frames for the second spike train. Should be ordered in ascending order. + spike_frames_train1 : ndarray + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + spike_frames_train2 : ndarray + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. unit_indices1 : ndarray - Array indicating the unit indices corresponding to each spike in `frames_spike_train1`. + An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. unit_indices2 : ndarray - Array indicating the unit indices corresponding to each spike in `frames_spike_train2`. - num_units_sorting1 : int - Total number of units in the first spike train. - num_units_sorting2 : int - Total number of units in the second spike train. 
+            An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`.
+        num_units_train1 : int
+            The total count of unique units in the first spike train.
+        num_units_train2 : int
+            The total count of unique units in the second spike train.
         delta_frames : int
-            Maximum difference in frames between two spikes to consider them as a match.
+            The inclusive upper limit on the frame difference for which two spikes are considered matching. That is,
+            if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]`
+            and `spike_frames_train2[j]` are considered matching.

         Returns
         -------
         matching_matrix : ndarray
-            A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents
-            the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`.
+            A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents
+            the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`.
+

         Notes
         -----
@@ -168,59 +171,58 @@
         This algorithm identifies matching spikes between two ordered spike trains.
         By iterating through each spike in the first train, it compares them against spikes in the second train,
         determining matches based on the two spikes frames being within `delta_frames` of each other.

-        To avoid redundant comparisons the algorithm maintains a reference, `lower_search_limit_in_second_train`,
+        To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start`,
         which signifies the minimal index in the second spike train that might match the upcoming spike
-        in the first train. This means that the start of the search moves forward in the second train as the
-        matches between the two trains are found decreasing the number of comparisons needed.
+        in the first train.
+
+        The logic can be summarized as follows:
+        1. Iterate through each spike in the first train.
+        2. For each spike, find the first match in the second train.
+        3. Save the index of the first match as the new `second_train_search_start`.
+        4. From that first match onwards, count as many matches as possible for the current spike.

-        An important condition here is thatthe same spike is not matched twice. This is managed by keeping track
-        of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match`
+        An important condition here is that the same spike is not matched twice. This is managed by keeping track
+        of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2`.

         For more details on the rationale behind this approach, refer to the documentation of this module and/or
-        the metrics section in SpikeForest documentation.
+        the metrics section in the SpikeForest documentation.
""" - matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) + matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) # Used to avoid the same spike matching twice - previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64) - previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64) - - lower_search_limit_in_second_train = 0 - - for index1 in range(len(frames_spike_train1)): - # Keeps track of which frame in the second spike train should be used as a search start for matches - index2 = lower_search_limit_in_second_train - frame1 = frames_spike_train1[index1] - - # Determine next_frame1 if current frame is not the last frame - not_in_the_last_loop = index1 < len(frames_spike_train1) - 1 - if not_in_the_last_loop: - next_frame1 = frames_spike_train1[index1 + 1] - - while index2 < len(frames_spike_train2): - frame2 = frames_spike_train2[index2] - not_a_match = abs(frame1 - frame2) > delta_frames - if not_a_match: - # Go to the next frame in the first train + last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64) + last_match_frame2 = -np.ones_like(matching_matrix, dtype=np.int64) + + num_spike_frames_train1 = len(spike_frames_train1) + num_spike_frames_train2 = len(spike_frames_train2) + + # Keeps track of which frame in the second spike train should be used as a search start for matches + second_train_search_start = 0 + for index1 in range(num_spike_frames_train1): + frame1 = spike_frames_train1[index1] + + for index2 in range(second_train_search_start, num_spike_frames_train2): + frame2 = spike_frames_train2[index2] + if frame2 < frame1 - delta_frames: + # no match move the left limit for the next loop + second_train_search_start += 1 + continue + elif frame2 > frame1 + delta_frames: + # no match stop search in train2 and continue increment in train1 break + else: + # match + unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] - # Map the match to a matrix - row, column = unit_indices1[index1], unit_indices2[index2] - - # The same spike cannot be matched twice see the notes in the docstring for more info on this constraint - if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]: - previous_frame1_match[row, column] = frame1 - previous_frame2_match[row, column] = frame2 - - matching_matrix[row, column] += 1 - - index2 += 1 + if ( + frame1 != last_match_frame1[unit_index1, unit_index2] + and frame2 != last_match_frame2[unit_index1, unit_index2] + ): + last_match_frame1[unit_index1, unit_index2] = frame1 + last_match_frame2[unit_index1, unit_index2] = frame2 - # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match - not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames - if not_a_match_with_next: - lower_search_limit_in_second_train = index2 + matching_matrix[unit_index1, unit_index2] += 1 return matching_matrix @@ -230,7 +232,7 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): +def make_match_count_matrix(sorting1, sorting2, delta_frames): num_units_sorting1 = sorting1.get_num_units() num_units_sorting2 = sorting2.get_num_units() matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) @@ -275,7 +277,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): return match_event_counts_df -def make_agreement_scores(sorting1, sorting2, 
delta_frames, n_jobs=1): +def make_agreement_scores(sorting1, sorting2, delta_frames): """ Make the agreement matrix. No threshold (min_score) is applied at this step. @@ -291,8 +293,6 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): The second sorting extractor delta_frames: int Number of frames to consider spikes coincident - n_jobs: int - Number of jobs to run in parallel Returns ------- @@ -309,7 +309,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): event_counts1 = pd.Series(ev_counts1, index=unit1_ids) event_counts2 = pd.Series(ev_counts2, index=unit2_ids) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=n_jobs) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index e2dc30493d..7f21aa657f 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -84,9 +84,7 @@ def _do_agreement(self): self.event_counts2 = do_count_event(self.sorting2) # matrix of event match count for each pair - self.match_event_count = make_match_count_matrix( - self.sorting1, self.sorting2, self.delta_frames, n_jobs=self.n_jobs - ) + self.match_event_count = make_match_count_matrix(self.sorting1, self.sorting2, self.delta_frames) # agreement matrix score for each pair self.agreement_scores = make_agreement_scores_from_count( diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index c6494b04d1..ab24678a1e 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -135,6 +135,23 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting(): assert_array_equal(result.to_numpy(), expected_result) +def test_make_match_count_matrix_test_proper_search_in_the_second_train(): + "Search exhaustively in the second train, but only within the delta_frames window, do not terminate search early" + frames_spike_train1 = [500, 600, 800] + frames_spike_train2 = [0, 100, 200, 300, 500, 800] + unit_indices1 = [0, 0, 0] + unit_indices2 = [0, 0, 0, 0, 0, 0] + delta_frames = 20 + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames) + + expected_result = np.array([[2]]) + + assert_array_equal(result.to_numpy(), expected_result) + + def test_make_agreement_scores(): delta_frames = 10 @@ -150,7 +167,7 @@ def test_make_agreement_scores(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) print(agreement_scores) ok = np.array([[2 / 3, 0], [0, 1.0]], dtype="float64") @@ -158,7 +175,7 @@ def test_make_agreement_scores(): assert_array_equal(agreement_scores.values, ok) # test if symetric - agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames, n_jobs=1) + agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames) assert_array_equal(agreement_scores, agreement_scores2.T) @@ -178,7 +195,7 @@ def test_make_possible_match(): [0, 0, 5], ) - agreement_scores = 
make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) possible_match_12, possible_match_21 = make_possible_match(agreement_scores, min_accuracy) @@ -207,7 +224,7 @@ def test_make_best_match(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) best_match_12, best_match_21 = make_best_match(agreement_scores, min_accuracy) @@ -236,7 +253,7 @@ def test_make_hungarian_match(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) @@ -344,8 +361,8 @@ def test_do_confusion_matrix(): event_counts1 = do_count_event(sorting1) event_counts2 = do_count_event(sorting2) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count) @@ -363,8 +380,8 @@ def test_do_confusion_matrix(): event_counts1 = do_count_event(sorting1) event_counts2 = do_count_event(sorting2) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count) @@ -391,8 +408,8 @@ def test_do_count_score_and_perf(): event_counts1 = do_count_event(sorting1) event_counts2 = do_count_event(sorting2) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) count_score = do_count_score(event_counts1, event_counts2, hungarian_match_12, match_event_count) @@ -415,13 +432,20 @@ def test_do_count_score_and_perf(): if __name__ == "__main__": test_make_match_count_matrix() - test_make_agreement_scores() - - test_make_possible_match() - test_make_best_match() - test_make_hungarian_match() - - test_do_score_labels() - test_compare_spike_trains() - test_do_confusion_matrix() - test_do_count_score_and_perf() + test_make_match_count_matrix_sorting_with_itself_simple() + test_make_match_count_matrix_sorting_with_itself_longer() + test_make_match_count_matrix_with_mismatched_sortings() + test_make_match_count_matrix_no_double_matching() + test_make_match_count_matrix_repeated_matching_but_no_double_counting() + 
test_make_match_count_matrix_test_proper_search_in_the_second_train() + + # test_make_agreement_scores() + + # test_make_possible_match() + # test_make_best_match() + # test_make_hungarian_match() + + # test_do_score_labels() + # test_compare_spike_trains() + # test_do_confusion_matrix() + # test_do_count_score_and_perf()
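
Illustrative sketch (not part of the diff): the counting loop introduced above, re-expressed as plain Python so the four steps in the new docstring can be read next to runnable code. It mirrors the new kernel but drops the numba compilation; the final assert reuses the scenario from test_make_match_count_matrix_test_proper_search_in_the_second_train, where only frames 500 and 800 fall within the +/- 20 frame window.

import numpy as np

def match_count_sketch(spike_frames_train1, spike_frames_train2,
                       unit_indices1, unit_indices2,
                       num_units_train1, num_units_train2, delta_frames):
    # Entry [i, j] counts matched spike pairs between unit i of train1 and unit j of train2.
    matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16)
    # Last matched frame per unit pair, so the same spike is never counted twice.
    last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64)
    last_match_frame2 = -np.ones_like(matching_matrix, dtype=np.int64)

    second_train_search_start = 0
    for index1, frame1 in enumerate(spike_frames_train1):
        for index2 in range(second_train_search_start, len(spike_frames_train2)):
            frame2 = spike_frames_train2[index2]
            if frame2 < frame1 - delta_frames:
                # frame2 is too early to match this or any later frame1: advance the search start.
                second_train_search_start += 1
                continue
            elif frame2 > frame1 + delta_frames:
                # frame2 is past the window: stop searching train2 for this frame1.
                break
            else:
                unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2]
                if (frame1 != last_match_frame1[unit_index1, unit_index2]
                        and frame2 != last_match_frame2[unit_index1, unit_index2]):
                    last_match_frame1[unit_index1, unit_index2] = frame1
                    last_match_frame2[unit_index1, unit_index2] = frame2
                    matching_matrix[unit_index1, unit_index2] += 1
    return matching_matrix

# Same data as the new test: only frames 500 and 800 match within delta_frames=20.
counts = match_count_sketch([500, 600, 800], [0, 100, 200, 300, 500, 800],
                            [0, 0, 0], [0, 0, 0, 0, 0, 0],
                            num_units_train1=1, num_units_train2=1, delta_frames=20)
assert counts[0, 0] == 2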
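
A quick usage sketch of the public helpers after this patch: the n_jobs argument is gone from both make_match_count_matrix and make_agreement_scores, so existing callers simply drop it. The toy spike trains and the expected agreement values come from test_make_agreement_scores above; building the sortings with NumpySorting.from_times_labels and a 30 kHz sampling frequency is an assumption made only to keep the snippet self-contained, not something this patch touches.

import numpy as np
from spikeinterface.core import NumpySorting  # assumed helper for constructing toy sortings
from spikeinterface.comparison.comparisontools import make_match_count_matrix, make_agreement_scores

sampling_frequency = 30000.0  # assumed value; the matching itself only uses frame counts
delta_frames = 10

# Toy data from test_make_agreement_scores: unit 0 of sorting1 shares 2 of its 3 spikes with
# unit 0 of sorting2, and unit 1 of sorting1 shares its single spike with unit 5 of sorting2.
sorting1 = NumpySorting.from_times_labels([np.array([100, 200, 300, 400])],
                                          [np.array([0, 0, 1, 0])], sampling_frequency)
sorting2 = NumpySorting.from_times_labels([np.array([101, 201, 301])],
                                          [np.array([0, 0, 5])], sampling_frequency)

# n_jobs=... is no longer accepted by either call.
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames)
print(agreement_scores)  # expected, as in the test above: [[2/3, 0], [0, 1.0]]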