Skip to content

Commit

Permalink
Sam's suggestions: renaming and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Jun 26, 2024
1 parent 572ddf9 commit d707b69
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 26 deletions.
6 changes: 3 additions & 3 deletions src/spikeinterface/comparison/basecomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ class MixinTemplateComparison:
Mixin for template comparisons to define:
* similarity method
* support
* n_shifts
* num_shifts
"""

def __init__(self, similarity_method="cosine", support="union", n_shifts=0):
def __init__(self, similarity_method="cosine", support="union", num_shifts=0):
self.similarity_method = similarity_method
self.support = support
self.n_shifts = n_shifts
self.num_shifts = num_shifts
6 changes: 4 additions & 2 deletions src/spikeinterface/comparison/multicomparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def __init__(
verbose=False,
similarity_method="cosine",
support="union",
n_shifts=0,
num_shifts=0,
do_matching=True,
):
if name_list is None:
Expand All @@ -348,7 +348,9 @@ def __init__(
chance_score=chance_score,
verbose=verbose,
)
MixinTemplateComparison.__init__(self, similarity_method=similarity_method, support=support, n_shifts=n_shifts)
MixinTemplateComparison.__init__(
self, similarity_method=similarity_method, support=support, num_shifts=num_shifts
)

if do_matching:
self._compute_all()
Expand Down
10 changes: 6 additions & 4 deletions src/spikeinterface/comparison/paircomparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison):
Method for the similarity matrix.
support : "dense" | "union" | "intersection", default: "union"
The support to compute the similarity matrix.
n_shifts : int, default: 0
num_shifts : int, default: 0
Number of shifts to use to shift templates to maximize similarity.
verbose : bool, default: False
If True, output is verbose.
Expand All @@ -731,7 +731,7 @@ def __init__(
chance_score=0.3,
similarity_method="cosine",
support="union",
n_shifts=0,
num_shifts=0,
verbose=False,
):
if name1 is None:
Expand All @@ -748,7 +748,9 @@ def __init__(
chance_score=chance_score,
verbose=verbose,
)
MixinTemplateComparison.__init__(self, similarity_method=similarity_method, support=support, n_shifts=n_shifts)
MixinTemplateComparison.__init__(
self, similarity_method=similarity_method, support=support, num_shifts=num_shifts
)

self.sorting_analyzer_1 = sorting_analyzer_1
self.sorting_analyzer_2 = sorting_analyzer_2
Expand Down Expand Up @@ -782,7 +784,7 @@ def _do_agreement(self):
self.sorting_analyzer_2,
method=self.similarity_method,
support=self.support,
n_shifts=self.n_shifts,
num_shifts=self.num_shifts,
)
import pandas as pd

Expand Down
32 changes: 19 additions & 13 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import numpy as np
import warnings

from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
from ..core.template_tools import get_dense_templates_array
Expand Down Expand Up @@ -47,6 +48,11 @@ def __init__(self, sorting_analyzer):

def _set_params(self, method="cosine", max_lag_ms=0, support="union"):
if method == "cosine_similarity":
warnings.warn(
"The method 'cosine_similarity' is deprecated and will be removed in the next version. Use 'cosine' instead.",
DeprecationWarning,
stacklevel=2,
)
method = "cosine"
params = dict(method=method, max_lag_ms=max_lag_ms, support=support)
return params
Expand All @@ -58,7 +64,7 @@ def _select_extension_data(self, unit_ids):
return dict(similarity=new_similarity)

def _run(self, verbose=False):
n_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
templates_array = get_dense_templates_array(
self.sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled
)
Expand All @@ -67,7 +73,7 @@ def _run(self, verbose=False):
templates_array,
templates_array,
method=self.params["method"],
n_shifts=n_shifts,
num_shifts=num_shifts,
support=self.params["support"],
sparsity=sparsity,
other_sparsity=sparsity,
Expand All @@ -84,7 +90,7 @@ def _get_data(self):


def compute_similarity_with_templates_array(
templates_array, other_templates_array, method, support="union", n_shifts=0, sparsity=None, other_sparsity=None
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
):

import sklearn.metrics.pairwise
Expand Down Expand Up @@ -127,15 +133,15 @@ def compute_similarity_with_templates_array(
overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)}
mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)

assert n_shifts < num_samples, "max_lag is too large"
num_shifts = 2 * n_shifts + 1
distances = np.ones((num_shifts, num_templates, other_num_templates), dtype=np.float32)
assert num_shifts < num_samples, "max_lag is too large"
num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed
for count, shift in enumerate(range(-n_shifts, 1)):
src_sliced_templates = templates_array[:, n_shifts : num_samples - n_shifts]
tgt_sliced_templates = other_templates_array[:, n_shifts + shift : num_samples - n_shifts + shift]
for count, shift in enumerate(range(-num_shifts, 1)):
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in range(num_templates):
src_template = src_sliced_templates[i]
tgt_templates = tgt_sliced_templates[overlapping_templates[i]]
Expand All @@ -161,8 +167,8 @@ def compute_similarity_with_templates_array(
if num_templates == other_num_templates:
distances[count, j, i] = distances[count, i, j]

if n_shifts != 0:
distances[num_shifts - count - 1] = distances[count].T
if num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T

distances = np.min(distances, axis=0)
similarity = 1 - distances
Expand All @@ -171,7 +177,7 @@ def compute_similarity_with_templates_array(


def compute_template_similarity_by_pair(
sorting_analyzer_1, sorting_analyzer_2, method="cosine", support="union", n_shifts=0
sorting_analyzer_1, sorting_analyzer_2, method="cosine", support="union", num_shifts=0
):
templates_array_1 = get_dense_templates_array(sorting_analyzer_1, return_scaled=True)
templates_array_2 = get_dense_templates_array(sorting_analyzer_2, return_scaled=True)
Expand All @@ -182,7 +188,7 @@ def compute_template_similarity_by_pair(
templates_array_2,
method=method,
support=support,
n_shifts=n_shifts,
num_shifts=num_shifts,
sparsity=sparsity_1,
other_sparsity=sparsity_2,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from spikeinterface.postprocessing.tests.common_extension_tests import (
AnalyzerExtensionCommonTestSuite,
)
Expand All @@ -7,10 +9,19 @@

class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite):

def test_extension(self):
self.run_extension_tests(ComputeTemplateSimilarity, params=dict(method="cosine"))
self.run_extension_tests(ComputeTemplateSimilarity, params=dict(method="l2"))
self.run_extension_tests(ComputeTemplateSimilarity, params=dict(method="l1", max_lag_ms=0.2))
@pytest.mark.parametrize(
"params",
[
dict(method="cosine"),
dict(method="l2"),
dict(method="l1", max_lag_ms=0.2),
dict(method="l1", support="intersection"),
dict(method="l2", support="union"),
dict(method="cosine", support="dense"),
],
)
def test_extension(self, params):
self.run_extension_tests(ComputeTemplateSimilarity, params=params)

def test_check_equal_template_with_distribution_overlap(self):
"""
Expand Down

0 comments on commit d707b69

Please sign in to comment.