Template similarity lags (#2941)
Extend template similarity with lags and distance metrics

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Alessio Buccino <[email protected]>
3 people authored Jun 26, 2024
1 parent 99cc04e commit 921ec82
Showing 5 changed files with 179 additions and 45 deletions.
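
A minimal usage sketch of the extended API; the recording/sorting objects and the analyzer pipeline below are assumptions for illustration, not part of the commit:

import spikeinterface.full as si

# assumed inputs: `recording` and `sorting` are already loaded
analyzer = si.create_sorting_analyzer(sorting, recording)
analyzer.compute("random_spikes")
analyzer.compute("templates")

# "cosine" replaces the deprecated "cosine_similarity"; max_lag_ms enables shifted
# comparisons and support selects the channel mask used for sparse templates
similarity = analyzer.compute(
    "template_similarity", method="l2", max_lag_ms=0.2, support="union"
).get_data()
print(similarity.shape)  # (num_units, num_units)
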
8 changes: 5 additions & 3 deletions src/spikeinterface/comparison/basecomparison.py
@@ -313,9 +313,11 @@ class MixinTemplateComparison:
"""
Mixin for template comparisons to define:
* similarity method
* sparsity
* support
* num_shifts
"""

def __init__(self, similarity_method="cosine_similarity", sparsity_dict=None):
def __init__(self, similarity_method="cosine", support="union", num_shifts=0):
self.similarity_method = similarity_method
self.sparsity_dict = sparsity_dict
self.support = support
self.num_shifts = num_shifts
9 changes: 6 additions & 3 deletions src/spikeinterface/comparison/multicomparisons.py
@@ -333,8 +333,9 @@ def __init__(
match_score=0.8,
chance_score=0.3,
verbose=False,
similarity_method="cosine_similarity",
sparsity_dict=None,
similarity_method="cosine",
support="union",
num_shifts=0,
do_matching=True,
):
if name_list is None:
@@ -347,7 +348,9 @@ def __init__(
chance_score=chance_score,
verbose=verbose,
)
MixinTemplateComparison.__init__(self, similarity_method=similarity_method, sparsity_dict=sparsity_dict)
MixinTemplateComparison.__init__(
self, similarity_method=similarity_method, support=support, num_shifts=num_shifts
)

if do_matching:
self._compute_all()
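
A sketch of passing the new keyword arguments to the multi-sorting template comparison; analyzer_1, analyzer_2 and analyzer_3 are assumed SortingAnalyzer objects with templates already computed:

from spikeinterface.comparison import compare_multiple_templates

# similarity_method / support / num_shifts replace the old sparsity_dict argument
mtcmp = compare_multiple_templates(
    [analyzer_1, analyzer_2, analyzer_3],
    similarity_method="cosine",
    support="union",
    num_shifts=0,  # 0 keeps the previous no-lag behaviour
)
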
47 changes: 25 additions & 22 deletions src/spikeinterface/comparison/paircomparisons.py
@@ -697,24 +697,26 @@ class TemplateComparison(BasePairComparison, MixinTemplateComparison):
Parameters
----------
sorting_analyzer_1 : SortingAnalyzer
The first SortingAnalyzer to get templates to compare
The first SortingAnalyzer to get templates to compare.
sorting_analyzer_2 : SortingAnalyzer
The second SortingAnalyzer to get templates to compare
The second SortingAnalyzer to get templates to compare.
unit_ids1 : list, default: None
List of units from sorting_analyzer_1 to compare
List of units from sorting_analyzer_1 to compare.
unit_ids2 : list, default: None
List of units from sorting_analyzer_2 to compare
similarity_method : str, default: "cosine_similarity"
Method for the similaroty matrix
sparsity_dict : dict, default: None
Dictionary for sparsity
List of units from sorting_analyzer_2 to compare.
similarity_method : "cosine" | "l1" | "l2", default: "cosine"
Method for the similarity matrix.
support : "dense" | "union" | "intersection", default: "union"
The support to compute the similarity matrix.
num_shifts : int, default: 0
Number of shifts to use to shift templates to maximize similarity.
verbose : bool, default: False
If True, output is verbose
If True, output is verbose.
Returns
-------
comparison : TemplateComparison
The output TemplateComparison object
The output TemplateComparison object.
"""

def __init__(
@@ -727,8 +729,9 @@ def __init__(
unit_ids2=None,
match_score=0.7,
chance_score=0.3,
similarity_method="cosine_similarity",
sparsity_dict=None,
similarity_method="cosine",
support="union",
num_shifts=0,
verbose=False,
):
if name1 is None:
@@ -745,7 +748,9 @@ def __init__(
chance_score=chance_score,
verbose=verbose,
)
MixinTemplateComparison.__init__(self, similarity_method=similarity_method, sparsity_dict=sparsity_dict)
MixinTemplateComparison.__init__(
self, similarity_method=similarity_method, support=support, num_shifts=num_shifts
)

self.sorting_analyzer_1 = sorting_analyzer_1
self.sorting_analyzer_2 = sorting_analyzer_2
@@ -754,10 +759,9 @@ def __init__(

# two options: all channels are shared or partial channels are shared
if sorting_analyzer_1.recording.get_num_channels() != sorting_analyzer_2.recording.get_num_channels():
raise NotImplementedError
raise ValueError("The two recordings must have the same number of channels")
if np.any([ch1 != ch2 for (ch1, ch2) in zip(channel_ids1, channel_ids2)]):
# TODO: here we can check location and run it on the union. Might be useful for reconfigurable probes
raise NotImplementedError
raise ValueError("The two recordings must have the same channel ids")

self.matches = dict()

@@ -768,11 +772,6 @@ def __init__(
unit_ids2 = sorting_analyzer_2.sorting.get_unit_ids()
self.unit_ids = [unit_ids1, unit_ids2]

if sparsity_dict is not None:
raise NotImplementedError
else:
self.sparsity = None

self._do_agreement()
self._do_matching()

@@ -781,7 +780,11 @@ def _do_agreement(self):
print("Agreement scores...")

agreement_scores = compute_template_similarity_by_pair(
self.sorting_analyzer_1, self.sorting_analyzer_2, method=self.similarity_method
self.sorting_analyzer_1,
self.sorting_analyzer_2,
method=self.similarity_method,
support=self.support,
num_shifts=self.num_shifts,
)
import pandas as pd

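
A pairwise sketch with the updated TemplateComparison arguments; analyzer_1 and analyzer_2 are assumed SortingAnalyzers recorded with the same channel ids:

from spikeinterface.comparison import compare_templates

tcmp = compare_templates(
    analyzer_1,
    analyzer_2,
    similarity_method="l1",
    support="intersection",
    num_shifts=0,
    match_score=0.7,
)
# agreement scores are the pairwise similarities (1 - distance)
print(tcmp.agreement_scores)
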
143 changes: 128 additions & 15 deletions src/spikeinterface/postprocessing/template_similarity.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import numpy as np
import warnings

from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
from ..core.template_tools import get_dense_templates_array
@@ -9,13 +10,26 @@
class ComputeTemplateSimilarity(AnalyzerExtension):
"""Compute similarity between templates with several methods.
Similarity is defined as 1 - distance(T_1, T_2) for two templates T_1, T_2
Parameters
----------
sorting_analyzer: SortingAnalyzer
sorting_analyzer : SortingAnalyzer
The SortingAnalyzer object
method: str, default: "cosine_similarity"
The method to compute the similarity
method : str, default: "cosine"
The method to compute the similarity. Can be in ["cosine", "l2", "l1"]
max_lag_ms : float, default: 0
If specified, the best distance across all lags within max_lag_ms is kept for every pair of templates
support : "dense" | "union" | "intersection", default: "union"
Support that should be considered to compute the distances between the templates, given their sparsities.
Can be either ["dense", "union", "intersection"]
In case of "l1" or "l2", the formula used is:
similarity = 1 - norm(T_1 - T_2)/(norm(T_1) + norm(T_2))
In case of cosine this is:
similarity = sum(T_1.T_2)/(norm(T_1)norm(T_2))
Returns
-------
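
A toy numpy check of the "l2" formula above (made-up values, not from the commit):

import numpy as np

t1 = np.array([1.0, 0.0, 2.0])  # flattened toy template
t2 = np.array([0.5, 0.0, 2.5])
similarity_l2 = 1 - np.linalg.norm(t1 - t2) / (np.linalg.norm(t1) + np.linalg.norm(t2))
print(round(similarity_l2, 3))  # ~0.852
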
@@ -32,8 +46,15 @@ class ComputeTemplateSimilarity(AnalyzerExtension):
def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _set_params(self, method="cosine_similarity"):
params = dict(method=method)
def _set_params(self, method="cosine", max_lag_ms=0, support="union"):
if method == "cosine_similarity":
warnings.warn(
"The method 'cosine_similarity' is deprecated and will be removed in the next version. Use 'cosine' instead.",
DeprecationWarning,
stacklevel=2,
)
method = "cosine"
params = dict(method=method, max_lag_ms=max_lag_ms, support=support)
return params

def _select_extension_data(self, unit_ids):
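
A small sketch of the backward-compatible call, reusing the hypothetical analyzer from the first sketch above:

# the old spelling still works but emits a DeprecationWarning and is remapped
analyzer.compute("template_similarity", method="cosine_similarity")
# preferred spelling going forward
analyzer.compute("template_similarity", method="cosine")
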
@@ -43,11 +64,19 @@ def _select_extension_data(self, unit_ids):
return dict(similarity=new_similarity)

def _run(self, verbose=False):
num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
templates_array = get_dense_templates_array(
self.sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled
)
sparsity = self.sorting_analyzer.sparsity
similarity = compute_similarity_with_templates_array(
templates_array, templates_array, method=self.params["method"]
templates_array,
templates_array,
method=self.params["method"],
num_shifts=num_shifts,
support=self.params["support"],
sparsity=sparsity,
other_sparsity=sparsity,
)
self.data["similarity"] = similarity

@@ -60,25 +89,109 @@ def _get_data(self):
compute_template_similarity = ComputeTemplateSimilarity.function_factory()


def compute_similarity_with_templates_array(templates_array, other_templates_array, method):
def compute_similarity_with_templates_array(
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
):

import sklearn.metrics.pairwise

if method == "cosine_similarity":
assert templates_array.shape[0] == other_templates_array.shape[0]
templates_flat = templates_array.reshape(templates_array.shape[0], -1)
other_templates_flat = templates_array.reshape(other_templates_array.shape[0], -1)
similarity = sklearn.metrics.pairwise.cosine_similarity(templates_flat, other_templates_flat)

method = "cosine"

all_metrics = ["cosine", "l1", "l2"]

if method not in all_metrics:
raise ValueError(f"compute_template_similarity (method {method}) not exists")

assert (
templates_array.shape[1] == other_templates_array.shape[1]
), "The number of samples in the templates should be the same for both arrays"
assert (
templates_array.shape[2] == other_templates_array.shape[2]
), "The number of channels in the templates should be the same for both arrays"
num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
num_channels = templates_array.shape[2]
other_num_templates = other_templates_array.shape[0]

mask = None
if sparsity is not None and other_sparsity is not None:
if support == "intersection":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
elif support == "union":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
units_overlaps = np.sum(mask, axis=2) > 0
mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
mask[~units_overlaps] = False
if mask is not None:
units_overlaps = np.sum(mask, axis=2) > 0
overlapping_templates = {}
for i in range(num_templates):
overlapping_templates[i] = np.flatnonzero(units_overlaps[i])
else:
raise ValueError(f"compute_template_similarity(method {method}) not exists")
# here we make a dense mask and overlapping templates
overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)}
mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)

assert num_shifts < num_samples, "max_lag is too large"
num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed
for count, shift in enumerate(range(-num_shifts, 1)):
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in range(num_templates):
src_template = src_sliced_templates[i]
tgt_templates = tgt_sliced_templates[overlapping_templates[i]]
for gcount, j in enumerate(overlapping_templates[i]):
# symmetric values are handled later
if num_templates == other_num_templates and j < i:
continue
src = src_template[:, mask[i, j]].reshape(1, -1)
tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)

if method == "l1":
norm_i = np.sum(np.abs(src))
norm_j = np.sum(np.abs(tgt))
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1")
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2")
distances[count, i, j] /= norm_i + norm_j
else:
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="cosine")
if num_templates == other_num_templates:
distances[count, j, i] = distances[count, i, j]

if num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T

distances = np.min(distances, axis=0)
similarity = 1 - distances

return similarity


def compute_template_similarity_by_pair(sorting_analyzer_1, sorting_analyzer_2, method="cosine_similarity"):
def compute_template_similarity_by_pair(
sorting_analyzer_1, sorting_analyzer_2, method="cosine", support="union", num_shifts=0
):
templates_array_1 = get_dense_templates_array(sorting_analyzer_1, return_scaled=True)
templates_array_2 = get_dense_templates_array(sorting_analyzer_2, return_scaled=True)
similarity = compute_similarity_with_templates_array(templates_array_1, templates_array_2, method)
sparsity_1 = sorting_analyzer_1.sparsity
sparsity_2 = sorting_analyzer_2.sparsity
similarity = compute_similarity_with_templates_array(
templates_array_1,
templates_array_2,
method=method,
support=support,
num_shifts=num_shifts,
sparsity=sparsity_1,
other_sparsity=sparsity_2,
)
return similarity
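
A standalone sketch calling the helper directly on a dense random template array (purely illustrative data); comparing a set of templates with itself mirrors what the extension does internally:

import numpy as np
from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array

# (num_units, num_samples, num_channels); sparsity=None uses the dense mask branch
templates = np.random.randn(5, 90, 16).astype("float32")

sim = compute_similarity_with_templates_array(
    templates, templates, method="l2", support="union", num_shifts=2
)
print(sim.shape)  # (5, 5); each entry is 1 - the best normalized distance over the tested lags
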


17 changes: 15 additions & 2 deletions src/spikeinterface/postprocessing/tests/test_template_similarity.py
@@ -1,3 +1,5 @@
import pytest

from spikeinterface.postprocessing.tests.common_extension_tests import (
AnalyzerExtensionCommonTestSuite,
)
@@ -7,8 +9,19 @@

class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite):

def test_extension(self):
self.run_extension_tests(ComputeTemplateSimilarity, params=dict(method="cosine_similarity"))
@pytest.mark.parametrize(
"params",
[
dict(method="cosine"),
dict(method="l2"),
dict(method="l1", max_lag_ms=0.2),
dict(method="l1", support="intersection"),
dict(method="l2", support="union"),
dict(method="cosine", support="dense"),
],
)
def test_extension(self, params):
self.run_extension_tests(ComputeTemplateSimilarity, params=params)

def test_check_equal_template_with_distribution_overlap(self):
"""
