From 324752aff01b0a38d027419c0f2ac92bf8bcd10d Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 16 May 2023 08:28:02 +0000
Subject: [PATCH 01/80] some change to test

---
 src/spikeinterface/core/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index a0c61a916f..6377eb695c 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -26,7 +26,7 @@ class BaseExtractor:
     """
 
-    default_missing_property_values = {"f": np.nan, "O": None, "S": "", "U": ""}
+    default_missing_property_values = {'f': np.nan, "O": None, "S": "", "U": ""}
 
     # This replaces the old key_properties
     # These are annotations/properties that always need to be

From 9c5d409daf4a6f10eb593068b6562fcba3387a0a Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 16 May 2023 08:34:28 +0000
Subject: [PATCH 02/80] another change

---
 src/spikeinterface/core/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 6377eb695c..a68ed36e6f 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -26,7 +26,7 @@ class BaseExtractor:
     """
 
-    default_missing_property_values = {'f': np.nan, "O": None, "S": "", "U": ""}
+    default_missing_property_values = {'f': np.nan, "O": None, 'S': "", "U": ""}
 
     # This replaces the old key_properties
     # These are annotations/properties that always need to be

From 5b305517c40bbd6c45787f05d65e2c0dbe66bbb8 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 16 May 2023 08:44:36 +0000
Subject: [PATCH 03/80] another attempt

---
 .github/workflows/pre-commit-post-merge.yml | 1 +
 src/spikeinterface/core/base.py             | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/pre-commit-post-merge.yml b/.github/workflows/pre-commit-post-merge.yml
index e260b1d317..eec3d411f7 100644
--- a/.github/workflows/pre-commit-post-merge.yml
+++ b/.github/workflows/pre-commit-post-merge.yml
@@ -1,6 +1,7 @@
 name: Use pre-commit post merge
 
 on:
+  pull_request:
   push:
     branches: [main]
 
diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index a68ed36e6f..f76cda9143 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -35,7 +35,7 @@ class BaseExtractor:
     _main_properties = []
 
     installed = True
-    installation_mesg = ""
+    installation_mesg = ''
     is_writable = False
 
     def __init__(self, main_ids):

From 298d3076c89b8a5aed4b6e4728bf160913140210 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 16 May 2023 08:56:36 +0000
Subject: [PATCH 04/80] attempt merge

---
 .github/workflows/pre-commit-post-merge.yml | 24 +++++++++++++++------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/pre-commit-post-merge.yml b/.github/workflows/pre-commit-post-merge.yml
index e260b1d317..29dcb80e91 100644
--- a/.github/workflows/pre-commit-post-merge.yml
+++ b/.github/workflows/pre-commit-post-merge.yml
@@ -8,10 +8,20 @@ jobs:
   pre-commit-post-merge:
     runs-on: ubuntu-latest
     steps:
-    - uses: actions/checkout@v3
-    - uses: actions/setup-python@v4
-      with:
-        python-version: '3.10'
-    - uses: pre-commit/action@v3.0.0
-    - uses: pre-commit-ci/lite-action@v1.0.1
-      if: always()
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install pre-commit
+        run: |
+          python -m pip install pre-commit
+      - name: Run pre-commit hooks
+        run: |
+          pre-commit run --all-files
+      - name: Commit changes
+        env:
+          GITHUB_TOKEN: ${{ secrets.PRE_COMMIT_TOKEN }}
+        run: |
+          git config --local user.email "action@github.com"
+          git config --local user.name "GitHub Action"
+          git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git push)
\ No newline at end of file

From daf7feaeeddd8b1c37f04b854a6f7cd02c89f24f Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 16 May 2023 09:02:20 +0000
Subject: [PATCH 05/80] add condition

---
 .github/workflows/pre-commit-post-merge.yml | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/pre-commit-post-merge.yml b/.github/workflows/pre-commit-post-merge.yml
index 131cc6584a..19d39b6ad2 100644
--- a/.github/workflows/pre-commit-post-merge.yml
+++ b/.github/workflows/pre-commit-post-merge.yml
@@ -18,11 +18,12 @@ jobs:
           python -m pip install pre-commit
       - name: Run pre-commit hooks
         run: |
-          pre-commit run --all-files
+          pre-commit run --all-files || true
       - name: Commit changes
         env:
           GITHUB_TOKEN: ${{ secrets.PRE_COMMIT_TOKEN }}
         run: |
           git config --local user.email "action@github.com"
           git config --local user.name "GitHub Action"
-          git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git push)
\ No newline at end of file
+          git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git push)
+          if: always()
\ No newline at end of file

From 798b73282946d2b8e372fce79770c7919d295864 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 16 May 2023 09:17:57 +0000
Subject: [PATCH 06/80] add auth

---
 .github/workflows/pre-commit-post-merge.yml | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/pre-commit-post-merge.yml b/.github/workflows/pre-commit-post-merge.yml
index 19d39b6ad2..d24d8a3205 100644
--- a/.github/workflows/pre-commit-post-merge.yml
+++ b/.github/workflows/pre-commit-post-merge.yml
@@ -1,7 +1,6 @@
 name: Use pre-commit post merge
 
 on:
-  pull_request:
   push:
     branches: [main]
 
@@ -25,5 +24,5 @@ jobs:
         run: |
           git config --local user.email "action@github.com"
           git config --local user.name "GitHub Action"
-          git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git push)
-          if: always()
\ No newline at end of file
+          git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git remote set-url origin https://x-access-token:${{ secrets.PRE_COMMIT_TOKEN }}@github.com/h-mayorquin/spikeinterface.git; git push)
+          if: always()

From 513a344f1fb8ef942467fd05056f1e6e4a8fa541 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 12 Sep 2023 12:20:52 +0200
Subject: [PATCH 07/80] add basic instance and numpy behavior

---
 .../core/tests/test_template_class.py | 60 +++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 src/spikeinterface/core/tests/test_template_class.py

diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
new file mode 100644
index 0000000000..defad77e00
--- /dev/null
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -0,0 +1,60 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from spikeinterface.core.template import Templates
+
+
+def test_dense_template_instance():
+    num_units = 2
+    num_samples = 4
+    num_channels = 3
+    templates_shape = (num_units, num_samples, num_channels)
+    templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)
+
+    templates = Templates(templates_array=templates_array)
+
+    assert np.array_equal(templates.templates_array, templates_array)
+    assert templates.sparsity is None
+    assert templates.num_units == num_units
+    assert templates.num_samples == num_samples
+    assert templates.num_channels == num_channels
+
+
+def test_numpy_like_behavior():
+    num_units = 2
+    num_samples = 4
+    num_channels = 3
+    templates_shape = (num_units, num_samples, num_channels)
+    templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)
+
+    templates = Templates(templates_array=templates_array)
+
+    # Test that slicing works as in numpy
+    assert np.array_equal(templates[:], templates_array[:])
+    assert np.array_equal(templates[0], templates_array[0])
+    assert np.array_equal(templates[0, :], templates_array[0, :])
+    assert np.array_equal(templates[0, :, :], templates_array[0, :, :])
+    assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2])
+
+    # Test unary ufuncs
+    assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array))
+    assert np.array_equal(np.abs(templates), np.abs(templates_array))
+    assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0))
+
+    # Test binary ufuncs
+    other_array = np.random.rand(*templates_shape)
+    other_template = Templates(templates_array=other_array)
+
+    assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array))
+    assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array))
+
+    # Test chaining of operations
+    chained_result = np.mean(np.multiply(templates, other_template), axis=0)
+    chained_expected = np.mean(np.multiply(templates_array, other_array), axis=0)
+    assert np.array_equal(chained_result, chained_expected)
+
+    # Test ufuncs that return non-ndarray results
+    assert np.all(np.greater(templates, -1))
+    assert not np.any(np.less(templates, 0))

From c242446086bca416766912fab079e17db100bd03 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 12 Sep 2023 12:24:27 +0200
Subject: [PATCH 08/80] add pickability

---
 .../core/tests/test_template_class.py | 22 ++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
index defad77e00..b864421824 100644
--- a/src/spikeinterface/core/tests/test_template_class.py
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-
+import pickle
 import numpy as np
 import pytest
 
@@ -58,3 +58,23 @@ def test_numpy_like_behavior():
     # Test ufuncs that return non-ndarray results
     assert np.all(np.greater(templates, -1))
     assert not np.any(np.less(templates, 0))
+
+
+def test_pickle():
+    num_units = 2
+    num_samples = 4
+    num_channels = 3
+    templates_shape = (num_units, num_samples, num_channels)
+    templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)
+
+    templates = Templates(templates_array=templates_array)
+
+    # Serialize and deserialize the object
+    serialized = pickle.dumps(templates)
+    deserialized = pickle.loads(serialized)
+
+    assert np.array_equal(templates.templates_array, deserialized.templates_array)
+    assert templates.sparsity == deserialized.sparsity
+    assert templates.num_units == deserialized.num_units
+    assert templates.num_samples == deserialized.num_samples
+    assert templates.num_channels == deserialized.num_channels

From 4ca9ec6d504adb83561ce8bb57189ff3e9ec7a8e Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 12 Sep 2023 12:40:45 +0200
Subject: [PATCH 09/80] add json test

---
 src/spikeinterface/core/template.py   | 59 +++++++++++++++++++
 .../core/tests/test_template_class.py | 54 ++++++++++-------
 2 files changed, 93 insertions(+), 20 deletions(-)
 create mode 100644 src/spikeinterface/core/template.py

diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
new file mode 100644
index 0000000000..3542879df4
--- /dev/null
+++ b/src/spikeinterface/core/template.py
@@ -0,0 +1,59 @@
+import json
+from dataclasses import dataclass, field
+
+import numpy as np
+
+from spikeinterface.core.sparsity import ChannelSparsity
+
+
+@dataclass
+class Templates:
+    templates_array: np.ndarray
+    sparsity: ChannelSparsity = None
+    num_units: int = field(init=False)
+    num_samples: int = field(init=False)
+    num_channels: int = field(init=False)
+
+    def __post_init__(self):
+        self.num_units, self.num_samples, self.num_channels = self.templates_array.shape
+
+    # Implementing the slicing/indexing behavior as numpy
+    def __getitem__(self, index):
+        return self.templates_array[index]
+
+    def __array__(self):
+        return self.templates_array
+
+    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray:
+        # Replace any Templates instances with their ndarray representation
+        inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs)
+
+        # Apply the ufunc on the transformed inputs
+        result = getattr(ufunc, method)(*inputs, **kwargs)
+
+        return result
+
+    def to_dict(self):
+        sparsity = self.sparsity.to_dict() if self.sparsity is not None else None
+        return {
+            "templates_array": self.templates_array.tolist(),
+            "sparsity": sparsity,
+            "num_units": self.num_units,
+            "num_samples": self.num_samples,
+            "num_channels": self.num_channels,
+        }
+
+    # Construct the object from a dictionary
+    @classmethod
+    def from_dict(cls, data):
+        return cls(
+            templates_array=np.array(data["templates_array"]),
+            sparsity=ChannelSparsity(data["sparsity"]),  # Assuming you can reconstruct a ChannelSparsity from a string
+        )
+
+    def to_json(self):
+        return json.dumps(self.to_dict())
+
+    @classmethod
+    def from_json(cls, json_str):
+        return cls.from_dict(json.loads(json_str))
diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
index b864421824..e9e3f60730 100644
--- a/src/spikeinterface/core/tests/test_template_class.py
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -1,19 +1,28 @@
 from pathlib import Path
 import pickle
+import json
+
 import numpy as np
 import pytest
 
 from spikeinterface.core.template import Templates
 
 
-def test_dense_template_instance():
+@pytest.fixture
+def dense_templates():
     num_units = 2
     num_samples = 4
     num_channels = 3
     templates_shape = (num_units, num_samples, num_channels)
     templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)
 
-    templates = Templates(templates_array=templates_array)
+    return Templates(templates_array=templates_array)
+
+
+def test_dense_template_instance(dense_templates):
+    templates = dense_templates
+    templates_array = templates.templates_array
+    num_units, num_samples, num_channels = templates_array.shape
 
     assert np.array_equal(templates.templates_array, templates_array)
     assert templates.sparsity is None
@@ -22,14 +31,9 @@ def test_dense_template_instance():
     assert templates.num_channels == num_channels
 
 
-def test_numpy_like_behavior():
-    num_units = 2
-    num_samples = 4
-    num_channels = 3
-    templates_shape = (num_units, num_samples, num_channels)
-    templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)
-
-    templates = Templates(templates_array=templates_array)
+def test_numpy_like_behavior(dense_templates):
+    templates = dense_templates
+    templates_array = templates.templates_array
 
     # Test that slicing works as in numpy
     assert np.array_equal(templates[:], templates_array[:])
@@ -44,7 +48,7 @@ def test_numpy_like_behavior():
     assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0))
 
     # Test binary ufuncs
-    other_array = np.random.rand(*templates_shape)
+    other_array = np.random.rand(*templates_array.shape)
     other_template = Templates(templates_array=other_array)
 
     assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array))
     assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array))
@@ -60,19 +64,29 @@ def test_numpy_like_behavior():
     assert not np.any(np.less(templates, 0))
 
 
-def test_pickle():
-    num_units = 2
-    num_samples = 4
-    num_channels = 3
-    templates_shape = (num_units, num_samples, num_channels)
-    templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)
-
-    templates = Templates(templates_array=templates_array)
+def test_pickle(dense_templates):
+    templates = dense_templates
 
     # Serialize and deserialize the object
     serialized = pickle.dumps(templates)
-    deserialized = pickle.loads(serialized)
+    deserialized_templates = pickle.loads(serialized)
+
+    assert np.array_equal(templates.templates_array, deserialized_templates.templates_array)
+    assert templates.sparsity == deserialized_templates.sparsity
+    assert templates.num_units == deserialized_templates.num_units
+    assert templates.num_samples == deserialized_templates.num_samples
+    assert templates.num_channels == deserialized_templates.num_channels
+
+
+def test_jsonification(dense_templates):
+    templates = dense_templates
+    # Serialize to JSON string
+    serialized = templates.to_json()
+
+    # Deserialize back to object
+    deserialized = Templates.from_json(serialized)
 
+    # Check if deserialized object matches original
     assert np.array_equal(templates.templates_array, deserialized.templates_array)
     assert templates.sparsity == deserialized.sparsity
     assert templates.num_units == deserialized.num_units

From 107bdf96ec26368dfa0666c2e7ac9bd5dd596563 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 12 Sep 2023 17:20:31 +0200
Subject: [PATCH 10/80] test fancy indices

---
 src/spikeinterface/core/template.py                   |  3 ++-
 src/spikeinterface/core/tests/test_template_class.py  | 10 +++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 3542879df4..2906692902 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -46,9 +46,10 @@ def to_dict(self):
     # Construct the object from a dictionary
     @classmethod
     def from_dict(cls, data):
+        sparsity = ChannelSparsity.from_dict(data["sparsity"]) if data["sparsity"] is not None else None
         return cls(
             templates_array=np.array(data["templates_array"]),
-            sparsity=ChannelSparsity(data["sparsity"]),  # Assuming you can reconstruct a ChannelSparsity from a string
+            sparsity=sparsity,
         )
 
     def to_json(self):
diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
index e9e3f60730..62673906ab 100644
--- a/src/spikeinterface/core/tests/test_template_class.py
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -6,6 +6,7 @@
 import pytest
 
 from spikeinterface.core.template import Templates
+from spikeinterface.core.sparsity import ChannelSparsity
 
 
 @pytest.fixture
@@ -41,6 +42,14 @@ def test_numpy_like_behavior(dense_templates):
     assert np.array_equal(templates[0, :], templates_array[0, :])
     assert np.array_equal(templates[0, :, :], templates_array[0, :, :])
     assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2])
+    # Test fancy indexing
+    indices = np.array([0, 1])
+    assert np.array_equal(templates[indices], templates_array[indices])
+    row_indices = np.array([0, 1])
+    col_indices = np.array([2, 3])
+    assert np.array_equal(templates[row_indices, col_indices], templates_array[row_indices, col_indices])
+    mask = templates_array > 0.5
+    assert np.array_equal(templates[mask], templates_array[mask])
 
     # Test unary ufuncs
     assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array))
@@ -50,7 +59,6 @@ def test_numpy_like_behavior(dense_templates):
     # Test binary ufuncs
     other_array = np.random.rand(*templates_array.shape)
     other_template = Templates(templates_array=other_array)
-
     assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array))
     assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array))

From d2c6ec727c01c47ec9635a5279eb3d432ce370ce Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Wed, 20 Sep 2023 22:39:01 +0200
Subject: [PATCH 11/80] alessio and samuel requests

---
 src/spikeinterface/core/template.py   |  90 ++++++---
 .../core/tests/test_template_class.py | 178 +++++++++---------
 2 files changed, 144 insertions(+), 124 deletions(-)

diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 2906692902..4b923db2df 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -1,55 +1,67 @@
+import numpy as np
 import json
 from dataclasses import dataclass, field
+from .sparsity import ChannelSparsity
-import numpy as np
-
-from spikeinterface.core.sparsity import ChannelSparsity
-
-@dataclass
+@dataclass(kw_only=True)
 class Templates:
     templates_array: np.ndarray
-    sparsity: ChannelSparsity = None
+    sampling_frequency: float
+    nbefore: int
+
+    sparsity_mask: np.ndarray = None
+    channel_ids: np.ndarray = None
+    unit_ids: np.ndarray = None
+
     num_units: int = field(init=False)
     num_samples: int = field(init=False)
     num_channels: int = field(init=False)
 
-    def __post_init__(self):
-        self.num_units, self.num_samples, self.num_channels = self.templates_array.shape
-
-    # Implementing the slicing/indexing behavior as numpy
-    def __getitem__(self, index):
-        return self.templates_array[index]
-
-    def __array__(self):
-        return self.templates_array
-
-    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray:
-        # Replace any Templates instances with their ndarray representation
-        inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs)
+    nafter: int = field(init=False)
+    ms_before: float = field(init=False)
+    ms_after: float = field(init=False)
+    sparsity: ChannelSparsity = field(init=False)
 
-        # Apply the ufunc on the transformed inputs
-        result = getattr(ufunc, method)(*inputs, **kwargs)
-
-        return result
+    def __post_init__(self):
+        self.num_units, self.num_samples = self.templates_array.shape[:2]
+        if self.sparsity_mask is None:
+            self.num_channels = self.templates_array.shape[2]
+        else:
+            self.num_channels = self.sparsity_mask.shape[1]
+        self.nafter = self.num_samples - self.nbefore - 1
+        self.ms_before = self.nbefore / self.sampling_frequency * 1000
+        self.ms_after = self.nafter / self.sampling_frequency * 1000
+        if self.channel_ids is None:
+            self.channel_ids = np.arange(self.num_channels)
+        if self.unit_ids is None:
+            self.unit_ids = np.arange(self.num_units)
+        if self.sparsity_mask is not None:
+            self.sparsity = ChannelSparsity(
+                mask=self.sparsity_mask,
+                unit_ids=self.unit_ids,
+                channel_ids=self.channel_ids,
+            )
 
     def to_dict(self):
-        sparsity = self.sparsity.to_dict() if self.sparsity is not None else None
         return {
             "templates_array": self.templates_array.tolist(),
-            "sparsity": sparsity,
-            "num_units": self.num_units,
-            "num_samples": self.num_samples,
-            "num_channels": self.num_channels,
+            "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask.tolist(),
+            "channel_ids": self.channel_ids.tolist(),
+            "unit_ids": self.unit_ids.tolist(),
+            "sampling_frequency": self.sampling_frequency,
+            "nbefore": self.nbefore,
         }
 
-    # Construct the object from a dictionary
     @classmethod
     def from_dict(cls, data):
-        sparsity = ChannelSparsity.from_dict(data["sparsity"]) if data["sparsity"] is not None else None
         return cls(
             templates_array=np.array(data["templates_array"]),
-            sparsity=sparsity,
+            sparsity_mask=None if data["sparsity_mask"] is None else np.array(data["sparsity_mask"]),
+            channel_ids=np.array(data["channel_ids"]),
+            unit_ids=np.array(data["unit_ids"]),
+            sampling_frequency=data["sampling_frequency"],
+            nbefore=data["nbefore"],
        )
 
     def to_json(self):
         return json.dumps(self.to_dict())
 
     @classmethod
     def from_json(cls, json_str):
         return cls.from_dict(json.loads(json_str))
+
+    # Implementing the slicing/indexing behavior as numpy
+    def __getitem__(self, index):
+        return self.templates_array[index]
+
+    def __array__(self):
+        return self.templates_array
+
+    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray:
+        # Replace any Templates instances with their ndarray representation
+        inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs)
+
+        # Apply the ufunc on the transformed inputs
+        result = getattr(ufunc, method)(*inputs, **kwargs)
+
+        return result
diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
index 62673906ab..5fc997c6bf 100644
--- a/src/spikeinterface/core/tests/test_template_class.py
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -1,102 +1,94 @@
-from pathlib import Path
-import pickle
-import json
-
-import numpy as np
 import pytest
-
+import numpy as np
+import pickle
 from spikeinterface.core.template import Templates
-from spikeinterface.core.sparsity import ChannelSparsity
 
 
-@pytest.fixture
-def dense_templates():
+@pytest.mark.parametrize("template_obj", ["dense", "sparse"])
+def get_template_object(template_obj):
     num_units = 2
-    num_samples = 4
+    num_samples = 5
     num_channels = 3
     templates_shape = (num_units, num_samples, num_channels)
     templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)
 
-    return Templates(templates_array=templates_array)
-
-
-def test_dense_template_instance(dense_templates):
-    templates = dense_templates
-    templates_array = templates.templates_array
-    num_units, num_samples, num_channels = templates_array.shape
-
-    assert np.array_equal(templates.templates_array, templates_array)
-    assert templates.sparsity is None
-    assert templates.num_units == num_units
-    assert templates.num_samples == num_samples
-    assert templates.num_channels == num_channels
-
-
-def test_numpy_like_behavior(dense_templates):
-    templates = dense_templates
-    templates_array = templates.templates_array
-
-    # Test that slicing works as in numpy
-    assert np.array_equal(templates[:], templates_array[:])
-    assert np.array_equal(templates[0], templates_array[0])
-    assert np.array_equal(templates[0, :], templates_array[0, :])
-    assert np.array_equal(templates[0, :, :], templates_array[0, :, :])
-    assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2])
-    # Test fancy indexing
-    indices = np.array([0, 1])
-    assert np.array_equal(templates[indices], templates_array[indices])
-    row_indices = np.array([0, 1])
-    col_indices = np.array([2, 3])
-    assert np.array_equal(templates[row_indices, col_indices], templates_array[row_indices, col_indices])
-    mask = templates_array > 0.5
-    assert np.array_equal(templates[mask], templates_array[mask])
-
-    # Test unary ufuncs
-    assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array))
-    assert np.array_equal(np.abs(templates), np.abs(templates_array))
-    assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0))
-
-    # Test binary ufuncs
-    other_array = np.random.rand(*templates_array.shape)
-    other_template = Templates(templates_array=other_array)
-    assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array))
-    assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array))
-
-    # Test chaining of operations
-    chained_result = np.mean(np.multiply(templates, other_template), axis=0)
-    chained_expected = np.mean(np.multiply(templates_array, other_array), axis=0)
-    assert np.array_equal(chained_result, chained_expected)
-
-    # Test ufuncs that return non-ndarray results
-    assert np.all(np.greater(templates, -1))
-    assert not np.any(np.less(templates, 0))
-
-
-def test_pickle(dense_templates):
-    templates = dense_templates
-
-    # Serialize and deserialize the object
-    serialized = pickle.dumps(templates)
-    deserialized_templates = pickle.loads(serialized)
-
-    assert np.array_equal(templates.templates_array, deserialized_templates.templates_array)
-    assert templates.sparsity == deserialized_templates.sparsity
-    assert templates.num_units == deserialized_templates.num_units
-    assert templates.num_samples == deserialized_templates.num_samples
-    assert templates.num_channels == deserialized_templates.num_channels
-
-
-def test_jsonification(dense_templates):
-    templates = dense_templates
-    # Serialize to JSON string
-    serialized = templates.to_json()
-
-    # Deserialize back to object
-    deserialized = Templates.from_json(serialized)
-
-    # Check if deserialized object matches original
-    assert np.array_equal(templates.templates_array, deserialized.templates_array)
-    assert templates.sparsity == deserialized.sparsity
-    assert templates.num_units == deserialized.num_units
-    assert templates.num_samples == deserialized.num_samples
-    assert templates.num_channels == deserialized.num_channels
+    sampling_frequency = 30_000
+    nbefore = 2
+
+    if template_obj == "dense":
+        return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore)
+    else:  # sparse
+        sparsity_mask = np.array([[True, False, True], [False, True, False]])
+        return Templates(
+            templates_array=templates_array,
+            sparsity_mask=sparsity_mask,
+            sampling_frequency=sampling_frequency,
+            nbefore=nbefore,
+        )
+
+
+@pytest.mark.parametrize("template_obj", ["dense", "sparse"])
+def test_pickle_serialization(template_obj, tmp_path):
+    obj = get_template_object(template_obj)
+
+    # Dump to pickle
+    pkl_path = tmp_path / "templates.pkl"
+    with open(pkl_path, "wb") as f:
+        pickle.dump(obj, f)
+
+    # Load from pickle
+    with open(pkl_path, "rb") as f:
+        loaded_obj = pickle.load(f)
+
+    assert np.array_equal(obj.templates_array, loaded_obj.templates_array)
+
+
+@pytest.mark.parametrize("template_obj", ["dense", "sparse"])
+def test_json_serialization(template_obj):
+    obj = get_template_object(template_obj)
+
+    json_str = obj.to_json()
+    loaded_obj_from_json = Templates.from_json(json_str)
+
+    assert np.array_equal(obj.templates_array, loaded_obj_from_json.templates_array)
+
+
+# @pytest.fixture
+# def dense_templates():
+#     num_units = 2
+#     num_samples = 4
+#     num_channels = 3
#     templates_shape = (num_units, num_samples, num_channels)
+#     templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)

+#     return Templates(templates_array=templates_array)


+# def test_pickle(dense_templates):
+#     templates = dense_templates

+#     # Serialize and deserialize the object
+#     serialized = pickle.dumps(templates)
+#     deserialized_templates = pickle.loads(serialized)

+#     assert np.array_equal(templates.templates_array, deserialized_templates.templates_array)
+#     assert templates.sparsity == deserialized_templates.sparsity
+#     assert templates.num_units == deserialized_templates.num_units
+#     assert templates.num_samples == deserialized_templates.num_samples
+#     assert templates.num_channels == deserialized_templates.num_channels


+# def test_jsonification(dense_templates):
+#     templates = dense_templates
+#     # Serialize to JSON string
+#     serialized = templates.to_json()

+#     # Deserialize back to object
+#     deserialized = Templates.from_json(serialized)

+#     # Check if deserialized object matches original
+#     assert np.array_equal(templates.templates_array, deserialized.templates_array)
+#     assert templates.sparsity == deserialized.sparsity
+#     assert templates.num_units == deserialized.num_units
+#     assert templates.num_samples == deserialized.num_samples
+#     assert templates.num_channels == deserialized.num_channels

From 961d26979cf339b8799d9f045f20f477589171c4 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Wed, 20 Sep 2023 22:39:38 +0200
Subject: [PATCH 12/80] remove slicing

---
 src/spikeinterface/core/template.py   | 20 ---------
 .../core/tests/test_template_class.py | 41 -------------------
 2 files changed, 61 deletions(-)

diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 4b923db2df..6dbfb881f6 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -66,23 +66,3 @@ def from_dict(cls, data):
 
     def to_json(self):
         return json.dumps(self.to_dict())
-
-    @classmethod
-    def from_json(cls, json_str):
-        return cls.from_dict(json.loads(json_str))
-
-    # Implementing the slicing/indexing behavior as numpy
-    def __getitem__(self, index):
-        return self.templates_array[index]
-
-    def __array__(self):
-        return self.templates_array
-
-    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray:
-        # Replace any Templates instances with their ndarray representation
-        inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs)
-
-        # Apply the ufunc on the transformed inputs
-        result = getattr(ufunc, method)(*inputs, **kwargs)
-
-        return result
diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
index 5fc997c6bf..2b6b4c9744 100644
--- a/src/spikeinterface/core/tests/test_template_class.py
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -51,44 +51,3 @@ def test_json_serialization(template_obj):
     loaded_obj_from_json = Templates.from_json(json_str)
 
     assert np.array_equal(obj.templates_array, loaded_obj_from_json.templates_array)
-
-
-# @pytest.fixture
-# def dense_templates():
-#     num_units = 2
-#     num_samples = 4
-#     num_channels = 3
-#     templates_shape = (num_units, num_samples, num_channels)
-#     templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)
-
-#     return Templates(templates_array=templates_array)
-
-
-# def test_pickle(dense_templates):
-#     templates = dense_templates
-
-#     # Serialize and deserialize the object
-#     serialized = pickle.dumps(templates)
-#     deserialized_templates = pickle.loads(serialized)
-
-#     assert np.array_equal(templates.templates_array, deserialized_templates.templates_array)
-#     assert templates.sparsity == deserialized_templates.sparsity
-#     assert templates.num_units == deserialized_templates.num_units
-#     assert templates.num_samples == deserialized_templates.num_samples
-#     assert templates.num_channels == deserialized_templates.num_channels
-
-
-# def test_jsonification(dense_templates):
-#     templates = dense_templates
-#     # Serialize to JSON string
-#     serialized = templates.to_json()
-
-#     # Deserialize back to object
-#     deserialized = Templates.from_json(serialized)
-
-#     # Check if deserialized object matches original
-#     assert np.array_equal(templates.templates_array, deserialized.templates_array)
-#     assert templates.sparsity == deserialized.sparsity
-#     assert templates.num_units == deserialized.num_units
-#     assert templates.num_samples == deserialized.num_samples
-#     assert templates.num_channels == deserialized.num_channels

From 9ee3a1de6741b156f2a56ca5f7455dd2ebf3b768 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Wed, 20 Sep 2023 23:15:08 +0200
Subject: [PATCH 13/80] passing tests

---
 src/spikeinterface/core/template.py   | 41 +++++++++++++-
 .../core/tests/test_template_class.py | 55 ++++++++++++++-----
 2 files changed, 79 insertions(+), 17 deletions(-)

diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 6dbfb881f6..8926281dfe 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -1,6 +1,6 @@
 import numpy as np
 import json
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, astuple
 from .sparsity import ChannelSparsity
 
@@ -21,7 +21,7 @@ class Templates:
     nafter: int = field(init=False)
     ms_before: float = field(init=False)
     ms_after: float = field(init=False)
-    sparsity: ChannelSparsity = field(init=False)
+    sparsity: ChannelSparsity = field(init=False, default=None)
 
     def __post_init__(self):
         self.num_units, self.num_samples = self.templates_array.shape[:2]
@@ -66,3 +66,40 @@ def from_dict(cls, data):
 
     def to_json(self):
         return json.dumps(self.to_dict())
+
+    @classmethod
+    def from_json(cls, json_str):
+        return cls.from_dict(json.loads(json_str))
+
+    def __eq__(self, other):
+        """Necessary to compare arrays"""
+        if not isinstance(other, Templates):
+            return False
+
+        # Convert the instances to tuples
+        self_tuple = astuple(self)
+        other_tuple = astuple(other)
+
+        # Compare each field
+        for s_field, o_field in zip(self_tuple, other_tuple):
+            if isinstance(s_field, np.ndarray):
+                if not np.array_equal(s_field, o_field):
+                    return False
+
+            elif isinstance(s_field, ChannelSparsity):
+                if not isinstance(o_field, ChannelSparsity):
+                    return False
+
+                # (maybe ChannelSparsity should have its own __eq__ method)
+                # Compare ChannelSparsity by its mask, unit_ids and channel_ids
+                if not np.array_equal(s_field.mask, o_field.mask):
+                    return False
+                if not np.array_equal(s_field.unit_ids, o_field.unit_ids):
+                    return False
+                if not np.array_equal(s_field.channel_ids, o_field.channel_ids):
+                    return False
+            else:
+                if s_field != o_field:
+                    return False
+
+        return True
diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
index 2b6b4c9744..b395f82d49 100644
--- a/src/spikeinterface/core/tests/test_template_class.py
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -4,8 +4,33 @@
 from spikeinterface.core.template import Templates
 
 
+def compare_instances(obj1, obj2):
+    if not isinstance(obj1, Templates) or not isinstance(obj2, Templates):
+        raise ValueError("Both objects must be instances of the Templates class")
+
+    for attr, value1 in obj1.__dict__.items():
+        value2 = getattr(obj2, attr, None)
+
+        # Comparing numpy arrays
+        if isinstance(value1, np.ndarray):
+            if not np.array_equal(value1, value2):
+                print(f"Attribute '{attr}' is not equal!")
+                print(f"Value from obj1:\n{value1}")
+                print(f"Value from obj2:\n{value2}")
+                return False
+        # Comparing other types
+        elif value1 != value2:
+            print(f"Attribute '{attr}' is not equal!")
+            print(f"Value from obj1: {value1}")
+            print(f"Value from obj2: {value2}")
+            return False
+
+    print("All attributes are equal!")
+    return True
+
+
-@pytest.mark.parametrize("template_obj", ["dense", "sparse"])
-def get_template_object(template_obj):
+@pytest.mark.parametrize("template_object", ["dense", "sparse"])
+def generate_template_fixture(template_object):
     num_units = 2
     num_samples = 5
     num_channels = 3
@@ -15,7 +40,7 @@ def get_template_object(template_obj):
     sampling_frequency = 30_000
     nbefore = 2
 
-    if template_obj == "dense":
+    if template_object == "dense":
         return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore)
     else:  # sparse
         sparsity_mask = np.array([[True, False, True], [False, True, False]])
@@ -27,27 +52,27 @@ def get_template_object(template_obj):
         )
 
 
-@pytest.mark.parametrize("template_obj", ["dense", "sparse"])
-def test_pickle_serialization(template_obj, tmp_path):
-    obj = get_template_object(template_obj)
+@pytest.mark.parametrize("template_object", ["dense", "sparse"])
+def test_pickle_serialization(template_object, tmp_path):
+    template = generate_template_fixture(template_object)
 
     # Dump to pickle
     pkl_path = tmp_path / "templates.pkl"
     with open(pkl_path, "wb") as f:
-        pickle.dump(obj, f)
+        pickle.dump(template, f)
 
     # Load from pickle
     with open(pkl_path, "rb") as f:
-        loaded_obj = pickle.load(f)
+        template_reloaded = pickle.load(f)
 
-    assert np.array_equal(obj.templates_array, loaded_obj.templates_array)
+    assert template == template_reloaded
 
 
-@pytest.mark.parametrize("template_obj", ["dense", "sparse"])
-def test_json_serialization(template_obj):
-    obj = get_template_object(template_obj)
+@pytest.mark.parametrize("template_object", ["dense", "sparse"])
+def test_json_serialization(template_object):
+    template = generate_template_fixture(template_object)
 
-    json_str = obj.to_json()
-    loaded_obj_from_json = Templates.from_json(json_str)
+    json_str = template.to_json()
+    template_reloaded_from_json = Templates.from_json(json_str)
 
-    assert np.array_equal(obj.templates_array, loaded_obj_from_json.templates_array)
+    assert template == template_reloaded_from_json

From 9d7c9ac134151cb688ec4316c9bc2ff4af9f62ae Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Wed, 20 Sep 2023 23:21:13 +0200
Subject: [PATCH 14/80] add densification and sparsification methods

---
 src/spikeinterface/core/template.py   | 12 +++++++++
 .../core/tests/test_template_class.py | 25 -------------------
 2 files changed, 12 insertions(+), 25 deletions(-)

diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 8926281dfe..70c7d90527 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -64,6 +64,18 @@ def from_dict(cls, data):
             nbefore=data["nbefore"],
         )
 
+    def get_dense_templates(self) -> np.ndarray:
+        if self.sparsity is None:
+            return self.templates_array
+        else:
+            self.sparsity.to_dense(self.templates_array)
+
+    def get_sparse_templates(self) -> np.ndarray:
+        if self.sparsity is None:
+            raise ValueError("Can't return sparse templates without passing a sparsity mask")
+        else:
+            self.sparsity.to_sparse(self.templates_array)
+
     def to_json(self):
         return json.dumps(self.to_dict())
 
diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
index b395f82d49..f92e636d93 100644
--- a/src/spikeinterface/core/tests/test_template_class.py
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -4,31 +4,6 @@
 from spikeinterface.core.template import Templates
 
 
-def compare_instances(obj1, obj2):
-    if not isinstance(obj1, Templates) or not isinstance(obj2, Templates):
-        raise ValueError("Both objects must be instances of the Templates class")
-
-    for attr, value1 in obj1.__dict__.items():
-        value2 = getattr(obj2, attr, None)
-
-        # Comparing numpy arrays
-        if isinstance(value1, np.ndarray):
-            if not np.array_equal(value1, value2):
-                print(f"Attribute '{attr}' is not equal!")
-                print(f"Value from obj1:\n{value1}")
-                print(f"Value from obj2:\n{value2}")
-                return False
-        # Comparing other types
-        elif value1 != value2:
-            print(f"Attribute '{attr}' is not equal!")
-            print(f"Value from obj1: {value1}")
-            print(f"Value from obj2: {value2}")
-            return False
-
-    print("All attributes are equal!")
-    return True
-
-
 @pytest.mark.parametrize("template_object", ["dense", "sparse"])
 def generate_template_fixture(template_object):
     num_units = 2

From 6e1027bb75abc9ae1977ec0a6b87c7f951588576 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Mon, 25 Sep 2023 10:51:11 +0200
Subject: [PATCH 15/80] adding tests for sparsity and density

---
 src/spikeinterface/core/sparsity.py   |  3 +-
 src/spikeinterface/core/template.py   | 53 ++++++++++++--
 .../core/tests/test_template_class.py | 71 ++++++++++++++++---
 3 files changed, 109 insertions(+), 18 deletions(-)

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index 455edcfc80..70b412d487 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -191,7 +191,7 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray:
         assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg
 
         densified_shape = waveforms.shape[:-1] + (self.num_channels,)
-        densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype)
+        densified_waveforms = np.zeros(shape=densified_shape, dtype=waveforms.dtype)
         densified_waveforms[..., non_zero_indices] = waveforms
 
         return densified_waveforms
@@ -202,6 +202,7 @@ def are_waveforms_dense(self, waveforms: np.ndarray) -> bool:
     def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str) -> bool:
         non_zero_indices = self.unit_id_to_channel_indices[unit_id]
         num_active_channels = len(non_zero_indices)
+
         return waveforms.shape[-1] == num_active_channels
 
     @classmethod
diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 70c7d90527..54070053eb 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -65,16 +65,52 @@ def from_dict(cls, data):
         )
 
     def get_dense_templates(self) -> np.ndarray:
+        # Assumes an object without a sparsity mask already has dense templates
         if self.sparsity is None:
             return self.templates_array
-        else:
-            self.sparsity.to_dense(self.templates_array)
+
+        dense_waveforms = np.zeros(shape=(self.num_units, self.num_samples, self.num_channels))
+        for unit_index, unit_id in enumerate(self.unit_ids):
+            num_active_channels = self.sparsity.mask[unit_index].sum()
+            waveforms = self.templates_array[unit_index, :, :num_active_channels]
+            dense_waveforms[unit_index, ...] = self.sparsity.densify_waveforms(waveforms=waveforms, unit_id=unit_id)
+
+        return dense_waveforms
 
     def get_sparse_templates(self) -> np.ndarray:
+        # Objects without sparsity mask don't have sparsity and therefore can't return sparse templates
         if self.sparsity is None:
             raise ValueError("Can't return sparse templates without passing a sparsity mask")
-        else:
-            self.sparsity.to_sparse(self.templates_array)
+
+        # Waveforms are already sparse
+        if not self.sparsity.are_waveforms_dense(self.templates_array):
+            return self.templates_array
+
+        max_num_active_channels = self.sparsity.max_num_active_channels
+        sparse_waveforms = np.zeros(shape=(self.num_units, self.num_samples, max_num_active_channels))
+        for unit_index, unit_id in enumerate(self.unit_ids):
+            waveforms = self.templates_array[unit_index, ...]
+            sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id)
+
+        return sparse_waveforms
+
+    def are_templates_sparse(self) -> bool:
+        if self.sparsity is None:
+            return False
+
+        if self.templates_array.shape[-1] == self.num_channels:
+            return False
+
+        unit_is_sparse = True
+        for unit_index, unit_id in enumerate(self.unit_ids):
+            non_zero_indices = self.sparsity.unit_id_to_channel_indices[unit_id]
+            num_active_channels = len(non_zero_indices)
+            waveforms = self.templates_array[unit_index, :, :num_active_channels]
+            unit_is_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id)
+            if not unit_is_sparse:
+                return False
+
+        return unit_is_sparse
 
     def to_json(self):
         return json.dumps(self.to_dict())
@@ -84,7 +120,11 @@ def from_json(cls, json_str):
         return cls.from_dict(json.loads(json_str))
 
     def __eq__(self, other):
-        """Necessary to compare arrays"""
+        """
+        Necessary to compare templates because they naturally compare objects by equality of their fields
+        which is not possible for numpy arrays so we override the __eq__ method to compare each numpy arrays
+        with np.array_equal
+        """
         if not isinstance(other, Templates):
             return False
 
@@ -97,12 +137,11 @@ def __eq__(self, other):
             if isinstance(s_field, np.ndarray):
                 if not np.array_equal(s_field, o_field):
                     return False
-
+            # Compare ChannelSparsity by its mask, unit_ids and channel_ids. Maybe ChannelSparsity should have its own __eq__ method
             elif isinstance(s_field, ChannelSparsity):
                 if not isinstance(o_field, ChannelSparsity):
                     return False
 
-                # (maybe ChannelSparsity should have its own __eq__ method)
                 # Compare ChannelSparsity by its mask, unit_ids and channel_ids
                 if not np.array_equal(s_field.mask, o_field.mask):
                     return False
diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
index f92e636d93..cf0dffe532 100644
--- a/src/spikeinterface/core/tests/test_template_class.py
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -2,10 +2,10 @@
 import numpy as np
 import pickle
 from spikeinterface.core.template import Templates
+from spikeinterface.core.sparsity import ChannelSparsity
 
 
-@pytest.mark.parametrize("template_object", ["dense", "sparse"])
-def generate_template_fixture(template_object):
+def generate_test_template(template_type):
     num_units = 2
     num_samples = 5
     num_channels = 3
@@ -15,10 +15,29 @@ def generate_test_template(template_type):
     sampling_frequency = 30_000
     nbefore = 2
 
-    if template_object == "dense":
+    if template_type == "dense":
         return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore)
-    else:  # sparse
+    elif template_type == "sparse":  # sparse with sparse templates
         sparsity_mask = np.array([[True, False, True], [False, True, False]])
+        sparsity = ChannelSparsity(
+            mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels)
+        )
+
+        sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels))
+        for unit_index in range(num_units):
+            template = templates_array[unit_index, ...]
+            sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_index)
+            sparse_templates_array[unit_index, :, : sparse_template.shape[1]] = sparse_template
+
+        return Templates(
+            templates_array=sparse_templates_array,
+            sparsity_mask=sparsity_mask,
+            sampling_frequency=sampling_frequency,
+            nbefore=nbefore,
+        )
+    elif template_type == "sparse_with_dense_templates":  # sparse with dense templates
+        sparsity_mask = np.array([[True, False, True], [False, True, False]])
+
         return Templates(
             templates_array=templates_array,
             sparsity_mask=sparsity_mask,
             sampling_frequency=sampling_frequency,
             nbefore=nbefore,
         )
 
 
-@pytest.mark.parametrize("template_object", ["dense", "sparse"])
-def test_pickle_serialization(template_object, tmp_path):
-    template = generate_template_fixture(template_object)
+@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"])
+def test_pickle_serialization(template_type, tmp_path):
+    template = generate_test_template(template_type)
 
     # Dump to pickle
     pkl_path = tmp_path / "templates.pkl"
 
-@pytest.mark.parametrize("template_object", ["dense", "sparse"])
-def test_json_serialization(template_object):
-    template = generate_template_fixture(template_object)
+@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"])
+def test_json_serialization(template_type):
+    template = generate_test_template(template_type)
 
     json_str = template.to_json()
     template_reloaded_from_json = Templates.from_json(json_str)
 
     assert template == template_reloaded_from_json
+
+
+@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"])
+def test_get_dense_templates(template_type):
+    template = generate_test_template(template_type)
+    dense_templates = template.get_dense_templates()
+    assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels)
+
+
+@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"])
+def test_get_sparse_templates(template_type):
+    template = generate_test_template(template_type)
+
+    if template_type == "dense":
+        with pytest.raises(ValueError):
+            sparse_templates = template.get_sparse_templates()
+    elif template_type == "sparse":
+        sparse_templates = template.get_sparse_templates()
+        assert sparse_templates.shape == (
+            template.num_units,
+            template.num_samples,
+            template.sparsity.max_num_active_channels,
+        )
+        assert template.are_templates_sparse()
+    elif template_type == "sparse_with_dense_templates":
+        sparse_templates = template.get_sparse_templates()
+        assert sparse_templates.shape == (
+            template.num_units,
+            template.num_samples,
+            template.sparsity.max_num_active_channels,
+        )
+        assert not template.are_templates_sparse()

From d05e67da479716c7aac619dc0e4d3080a0a6d4a5 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Thu, 28 Sep 2023 11:24:45 +0200
Subject: [PATCH 16/80] prohibit dense templates when passing sparsity mask

---
 src/spikeinterface/core/sparsity.py   | 29 ++++++-----
 src/spikeinterface/core/template.py   | 50 ++++++++++---------
 .../core/tests/test_template_class.py | 23 ++++-----
 3 files changed, 55 insertions(+), 47 deletions(-)

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index 5bc2e51e8a..f2da16b757 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -150,11 +150,8 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nd
             or a single sparsified waveform (template) with shape (num_samples, num_active_channels).
         """
-        assert_msg = (
-            "Waveforms must be dense to sparsify them. "
-            f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}"
-        )
-        assert self.are_waveforms_dense(waveforms=waveforms), assert_msg
+        if self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id):
+            return waveforms
 
         non_zero_indices = self.unit_id_to_channel_indices[unit_id]
         sparsified_waveforms = waveforms[..., non_zero_indices]
@@ -185,16 +182,20 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nda
         """
 
         non_zero_indices = self.unit_id_to_channel_indices[unit_id]
+        num_active_channels = len(non_zero_indices)
 
-        assert_msg = (
-            "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is "
-            f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels."
-        )
-        assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg
+        if not self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id):
+            error_message = (
+                "Waveforms do not seem to be in the sparsity shape of this unit_id. The number of active channels is "
+                f"{num_active_channels} but the waveform has non zero values outside of those active channels: \n"
+                f"{waveforms[..., num_active_channels:]}"
+            )
+            raise ValueError(error_message)
 
         densified_shape = waveforms.shape[:-1] + (self.num_channels,)
         densified_waveforms = np.zeros(shape=densified_shape, dtype=waveforms.dtype)
-        densified_waveforms[..., non_zero_indices] = waveforms
+        # Maps the active channels to their original indices
+        densified_waveforms[..., non_zero_indices] = waveforms[..., :num_active_channels]
 
         return densified_waveforms
 
@@ -205,7 +206,11 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo
         non_zero_indices = self.unit_id_to_channel_indices[unit_id]
         num_active_channels = len(non_zero_indices)
 
-        return waveforms.shape[-1] == num_active_channels
+        # If any channel is non-zero outside of the active channels, then the waveforms are not sparse
+        excess_zeros = waveforms[..., num_active_channels:].sum()
+        are_sparse = excess_zeros == 0
+
+        return are_sparse
 
     @classmethod
     def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids):
diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 54070053eb..c0d4869d5e 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -32,6 +32,8 @@ def __post_init__(self):
         self.nafter = self.num_samples - self.nbefore - 1
         self.ms_before = self.nbefore / self.sampling_frequency * 1000
         self.ms_after = self.nafter / self.sampling_frequency * 1000
+
+        # Initialize sparsity object
         if self.channel_ids is None:
             self.channel_ids = np.arange(self.num_channels)
         if self.unit_ids is None:
@@ -43,6 +45,10 @@ def __post_init__(self):
                 channel_ids=self.channel_ids,
             )
 
+        # Test that the templates are sparse if a sparsity mask is passed
+        if not self._are_passed_templates_sparse():
+            raise ValueError("Sparsity mask passed but the templates are not sparse")
+
     def to_dict(self):
         return {
             "templates_array": self.templates_array.tolist(),
@@ -69,10 +75,11 @@ def get_dense_templates(self) -> np.ndarray:
         if self.sparsity is None:
             return self.templates_array
 
-        dense_waveforms = np.zeros(shape=(self.num_units, self.num_samples, self.num_channels))
+        dense_shape = (self.num_units, self.num_samples, self.num_channels)
+        dense_waveforms = np.zeros(dense=dense_shape, dtype=self.templates_array.dtype)
+
         for unit_index, unit_id in enumerate(self.unit_ids):
-            num_active_channels = self.sparsity.mask[unit_index].sum()
-            waveforms = self.templates_array[unit_index, :, :num_active_channels]
+            waveforms = self.templates_array[unit_index, ...]
             dense_waveforms[unit_index, ...] = self.sparsity.densify_waveforms(waveforms=waveforms, unit_id=unit_id)
 
         return dense_waveforms
@@ -82,12 +89,9 @@ def get_sparse_templates(self) -> np.ndarray:
         if self.sparsity is None:
             raise ValueError("Can't return sparse templates without passing a sparsity mask")
 
-        # Waveforms are already sparse
-        if not self.sparsity.are_waveforms_dense(self.templates_array):
-            return self.templates_array
-
         max_num_active_channels = self.sparsity.max_num_active_channels
-        sparse_waveforms = np.zeros(shape=(self.num_units, self.num_samples, max_num_active_channels))
+        sparse_shape = (self.num_units, self.num_samples, max_num_active_channels)
+        sparse_waveforms = np.zeros(shape=sparse_shape, dtype=self.templates_array.dtype)
         for unit_index, unit_id in enumerate(self.unit_ids):
             waveforms = self.templates_array[unit_index, ...]
             sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id)
 
         return sparse_waveforms
 
     def are_templates_sparse(self) -> bool:
-        if self.sparsity is None:
-            return False
-
-        if self.templates_array.shape[-1] == self.num_channels:
-            return False
+        return self.sparsity is not None
 
-        unit_is_sparse = True
+    def _are_passed_templates_sparse(self) -> bool:
+        """
+        Tests if the templates passed to the init constructor are sparse
+        """
+        are_templates_sparse = True
         for unit_index, unit_id in enumerate(self.unit_ids):
-            non_zero_indices = self.sparsity.unit_id_to_channel_indices[unit_id]
-            num_active_channels = len(non_zero_indices)
-            waveforms = self.templates_array[unit_index, :, :num_active_channels]
-            unit_is_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id)
-            if not unit_is_sparse:
+            waveforms = self.templates_array[unit_index, ...]
+            are_templates_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id)
+            if not are_templates_sparse:
                 return False
 
-        return unit_is_sparse
+        return are_templates_sparse
 
     def to_json(self):
         return json.dumps(self.to_dict())
@@ -122,8 +124,8 @@ def from_json(cls, json_str):
 
     def __eq__(self, other):
         """
         Necessary to compare templates because they naturally compare objects by equality of their fields
-        which is not possible for numpy arrays so we override the __eq__ method to compare each numpy arrays
-        with np.array_equal
+        which is not possible for numpy arrays. Therefore, we override the __eq__ method to compare each numpy arrays
+        using np.array_equal instead
         """
         if not isinstance(other, Templates):
             return False
@@ -137,7 +139,9 @@ def __eq__(self, other):
             if isinstance(s_field, np.ndarray):
                 if not np.array_equal(s_field, o_field):
                     return False
-            # Compare ChannelSparsity by its mask, unit_ids and channel_ids. Maybe ChannelSparsity should have its own __eq__ method
+
+            # Compare ChannelSparsity by its mask, unit_ids and channel_ids.
+            # Maybe ChannelSparsity should have its own __eq__ method
             elif isinstance(s_field, ChannelSparsity):
                 if not isinstance(o_field, ChannelSparsity):
                     return False
diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py
index cf0dffe532..b1244ab0d1 100644
--- a/src/spikeinterface/core/tests/test_template_class.py
+++ b/src/spikeinterface/core/tests/test_template_class.py
@@ -23,6 +23,7 @@ def generate_test_template(template_type):
             mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels)
         )
 
+        # Create sparse templates
         sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels))
         for unit_index in range(num_units):
             template = templates_array[unit_index, ...]
@@ -35,6 +36,7 @@ def generate_test_template(template_type): sampling_frequency=sampling_frequency, nbefore=nbefore, ) + elif template_type == "sparse_with_dense_templates": # sparse with dense templates sparsity_mask = np.array([[True, False, True], [False, True, False]]) @@ -46,7 +48,7 @@ def generate_test_template(template_type): ) -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_pickle_serialization(template_type, tmp_path): template = generate_test_template(template_type) @@ -62,7 +64,7 @@ def test_pickle_serialization(template_type, tmp_path): assert template == template_reloaded -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_json_serialization(template_type): template = generate_test_template(template_type) @@ -72,14 +74,14 @@ def test_json_serialization(template_type): assert template == template_reloaded_from_json -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_get_dense_templates(template_type): template = generate_test_template(template_type) dense_templates = template.get_dense_templates() assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_get_sparse_templates(template_type): template = generate_test_template(template_type) @@ -94,11 +96,8 @@ def test_get_sparse_templates(template_type): template.sparsity.max_num_active_channels, ) assert template.are_templates_sparse() - elif template_type == "sparse_with_dense_templates": - sparse_templates = template.get_sparse_templates() - assert sparse_templates.shape == ( - template.num_units, - template.num_samples, - template.sparsity.max_num_active_channels, - ) - assert not template.are_templates_sparse() + + +def test_initialization_fail_with_dense_templates(): + with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"): + template = generate_test_template(template_type="sparse_with_dense_templates") From cc8a5236e00072985cac399e078877c597b87ecd Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:27:42 +0200 Subject: [PATCH 17/80] add docstring --- src/spikeinterface/core/template.py | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index c0d4869d5e..dc6e0a5070 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -6,6 +6,41 @@ @dataclass(kw_only=True) class Templates: + """ + A class to represent spike templates, which can be either dense or sparse. + + Attributes + ---------- + templates_array : np.ndarray + Array containing the templates data. + sampling_frequency : float + Sampling frequency of the templates. + nbefore : int + Number of samples before the spike peak. + sparsity_mask : np.ndarray, optional + Binary array indicating the sparsity pattern of the templates. + If `None`, the templates are considered dense. + channel_ids : np.ndarray, optional + Array of channel IDs. If `None`, defaults to an array of increasing integers. 
+ unit_ids : np.ndarray, optional + Array of unit IDs. If `None`, defaults to an array of increasing integers. + num_units : int + Number of units in the templates. Automatically determined from `templates_array`. + num_samples : int + Number of samples per template. Automatically determined from `templates_array`. + num_channels : int + Number of channels in the templates. Automatically determined from `templates_array` or `sparsity_mask`. + nafter : int + Number of samples after the spike peak. Calculated as `num_samples - nbefore - 1`. + ms_before : float + Milliseconds before the spike peak. Calculated from `nbefore` and `sampling_frequency`. + ms_after : float + Milliseconds after the spike peak. Calculated from `nafter` and `sampling_frequency`. + sparsity : ChannelSparsity, optional + Object representing the sparsity pattern of the templates. Calculated from `sparsity_mask`. + If `None`, the templates are considered dense. + """ + templates_array: np.ndarray sampling_frequency: float nbefore: int From 73e95627b9015d20addd78e7b59f697ce5acb335 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:34:13 +0200 Subject: [PATCH 18/80] alessio remark about nafter definition --- src/spikeinterface/core/sparsity.py | 3 +-- src/spikeinterface/core/template.py | 12 +++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index f2da16b757..1593b6c9e4 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -208,9 +208,8 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo # If any channel is non-zero outside of the active channels, then the waveforms are not sparse excess_zeros = waveforms[..., num_active_channels:].sum() - are_sparse = excess_zeros == 0 - return are_sparse + return int(excess_zeros) == 0 @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index dc6e0a5070..bc4f7bae80 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -64,7 +64,9 @@ def __post_init__(self): self.num_channels = self.templates_array.shape[2] else: self.num_channels = self.sparsity_mask.shape[1] - self.nafter = self.num_samples - self.nbefore - 1 + + # Time and frames domain information + self.nafter = self.num_samples - self.nbefore self.ms_before = self.nbefore / self.sampling_frequency * 1000 self.ms_after = self.nafter / self.sampling_frequency * 1000 @@ -110,8 +112,8 @@ def get_dense_templates(self) -> np.ndarray: if self.sparsity is None: return self.templates_array - dense_shape = (self.num_units, self.num_samples, self.num_channels) - dense_waveforms = np.zeros(dense=dense_shape, dtype=self.templates_array.dtype) + densified_shape = (self.num_units, self.num_samples, self.num_channels) + dense_waveforms = np.zeros(dense=densified_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] 
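To make the redefined time/frame bookkeeping above concrete, a small worked example (numbers invented for illustration) of the quantities derived in __post_init__ after this change:

# With the redefinition nafter = num_samples - nbefore:
num_samples = 90          # samples per template
nbefore = 30              # samples before the peak
sampling_frequency = 30_000.0

nafter = num_samples - nbefore                    # 60 samples (was 59 with the old "- 1")
ms_before = nbefore / sampling_frequency * 1000   # 1.0 ms
ms_after = nafter / sampling_frequency * 1000     # 2.0 ms
# nbefore + nafter now accounts for every sample in the template.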
@@ -125,8 +127,8 @@ def get_sparse_templates(self) -> np.ndarray: raise ValueError("Can't return sparse templates without passing a sparsity mask") max_num_active_channels = self.sparsity.max_num_active_channels - sparse_shape = (self.num_units, self.num_samples, max_num_active_channels) - sparse_waveforms = np.zeros(shape=sparse_shape, dtype=self.templates_array.dtype) + sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_waveforms = np.zeros(shape=sparisfied_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) From 52c333b1a5bdcc5975b2c67c2341f02e80ca74de Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:50:25 +0200 Subject: [PATCH 19/80] fix mistake --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index bc4f7bae80..e8c0f83f50 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -113,7 +113,7 @@ def get_dense_templates(self) -> np.ndarray: return self.templates_array densified_shape = (self.num_units, self.num_samples, self.num_channels) - dense_waveforms = np.zeros(dense=densified_shape, dtype=self.templates_array.dtype) + dense_waveforms = np.zeros(shape=densified_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] From 2fb79ccf7fe3ad6a1efe5675171a82e54894aedd Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 09:31:01 +0200 Subject: [PATCH 20/80] Update src/spikeinterface/core/sparsity.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/sparsity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 1593b6c9e4..0a8c165ba5 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -186,8 +186,8 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nda if not self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): error_message = ( - "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " - f"{num_active_channels} but the waveform has non zero values outsies of those active channels: \n" + "Waveforms do not seem to be in the sparsity shape for this unit_id. 
The number of active channels is " + f"{num_active_channels}, but the waveform has non-zero values outsies of those active channels: \n" f"{waveforms[..., num_active_channels:]}" ) raise ValueError(error_message) From 437695c8b5c2f04331b0eac3cc7c876697b0709a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 10:01:47 +0200 Subject: [PATCH 21/80] changes --- src/spikeinterface/core/sparsity.py | 10 ++++ src/spikeinterface/core/template.py | 71 +++++++++++++---------------- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 1593b6c9e4..990687ca04 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -211,6 +211,16 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo return int(excess_zeros) == 0 + def sparisfy_templates(self, templates_array: np.ndarray) -> np.ndarray: + max_num_active_channels = self.max_num_active_channels + sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_templates = np.zeros(shape=sparisfied_shape, dtype=templates_array.dtype) + for unit_index, unit_id in enumerate(self.unit_ids): + template = templates_array[unit_index, ...] + sparse_templates[unit_index, ...] = self.sparsify_waveforms(waveforms=template, unit_id=unit_id) + + return sparse_templates + @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): """ diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index e8c0f83f50..ed71b6d2ea 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -4,7 +4,7 @@ from .sparsity import ChannelSparsity -@dataclass(kw_only=True) +@dataclass class Templates: """ A class to represent spike templates, which can be either dense or sparse. @@ -18,7 +18,7 @@ class Templates: nbefore : int Number of samples before the spike peak. sparsity_mask : np.ndarray, optional - Binary array indicating the sparsity pattern of the templates. + Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. channel_ids : np.ndarray, optional Array of channel IDs. If `None`, defaults to an array of increasing integers. 
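As a complement to the reworked docstring above, a small construction sketch for the dense case, which needs no sparsity_mask (illustrative only; the array sizes and values are invented, and it assumes the defaults described in the docstring):

import numpy as np
from spikeinterface.core.template import Templates

# Toy dense templates: 2 units, 4 samples, 3 channels.
templates_array = np.random.default_rng(0).normal(size=(2, 4, 3))
templates = Templates(
    templates_array=templates_array,
    sampling_frequency=30_000.0,
    nbefore=2,
)

assert not templates.are_templates_sparse()                 # no sparsity mask was passed
assert templates.get_dense_templates().shape == (2, 4, 3)   # already dense, returned as-is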
@@ -49,6 +49,8 @@ class Templates: channel_ids: np.ndarray = None unit_ids: np.ndarray = None + check_template_array_and_sparsity_mask_are_consistentency: bool = True + num_units: int = field(init=False) num_samples: int = field(init=False) num_channels: int = field(init=False) @@ -83,29 +85,9 @@ def __post_init__(self): ) # Test that the templates are sparse if a sparsity mask is passed - if not self._are_passed_templates_sparse(): - raise ValueError("Sparsity mask passed but the templates are not sparse") - - def to_dict(self): - return { - "templates_array": self.templates_array.tolist(), - "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask.tolist(), - "channel_ids": self.channel_ids.tolist(), - "unit_ids": self.unit_ids.tolist(), - "sampling_frequency": self.sampling_frequency, - "nbefore": self.nbefore, - } - - @classmethod - def from_dict(cls, data): - return cls( - templates_array=np.array(data["templates_array"]), - sparsity_mask=None if data["sparsity_mask"] is None else np.array(data["sparsity_mask"]), - channel_ids=np.array(data["channel_ids"]), - unit_ids=np.array(data["unit_ids"]), - sampling_frequency=data["sampling_frequency"], - nbefore=data["nbefore"], - ) + if self.check_template_array_and_sparsity_mask_are_consistentency: + if not self._are_passed_templates_sparse(): + raise ValueError("Sparsity mask passed but the templates are not sparse") def get_dense_templates(self) -> np.ndarray: # Assumes and object without a sparsity mask already has dense templates @@ -121,20 +103,6 @@ def get_dense_templates(self) -> np.ndarray: return dense_waveforms - def get_sparse_templates(self) -> np.ndarray: - # Objects without sparsity mask don't have sparsity and therefore can't return sparse templates - if self.sparsity is None: - raise ValueError("Can't return sparse templates without passing a sparsity mask") - - max_num_active_channels = self.sparsity.max_num_active_channels - sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) - sparse_waveforms = np.zeros(shape=sparisfied_shape, dtype=self.templates_array.dtype) - for unit_index, unit_id in enumerate(self.unit_ids): - waveforms = self.templates_array[unit_index, ...] - sparse_waveforms[unit_index, ...] 
= self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) - - return sparse_waveforms - def are_templates_sparse(self) -> bool: return self.sparsity is not None @@ -151,8 +119,31 @@ def _are_passed_templates_sparse(self) -> bool: return are_templates_sparse + def to_dict(self): + return { + "templates_array": self.templates_array, + "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask, + "channel_ids": self.channel_ids, + "unit_ids": self.unit_ids, + "sampling_frequency": self.sampling_frequency, + "nbefore": self.nbefore, + } + + @classmethod + def from_dict(cls, data): + return cls( + templates_array=np.asarray(data["templates_array"]), + sparsity_mask=None if data["sparsity_mask"] is None else np.asarray(data["sparsity_mask"]), + channel_ids=np.asarray(data["channel_ids"]), + unit_ids=np.asarray(data["unit_ids"]), + sampling_frequency=data["sampling_frequency"], + nbefore=data["nbefore"], + ) + def to_json(self): - return json.dumps(self.to_dict()) + from spikeinterface.core.core_tools import SIJsonEncoder + + return json.dumps(self.to_dict(), cls=SIJsonEncoder) @classmethod def from_json(cls, json_str): From 600f20f9b7465f7f398097207d036e8ee4ff8d92 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 10:29:30 +0200 Subject: [PATCH 22/80] modify docstring --- src/spikeinterface/core/template.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index ed71b6d2ea..909d47acfc 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -9,7 +9,7 @@ class Templates: """ A class to represent spike templates, which can be either dense or sparse. - Attributes + It is constructed with the following parameters: ---------- templates_array : np.ndarray Array containing the templates data. @@ -17,13 +17,21 @@ class Templates: Sampling frequency of the templates. nbefore : int Number of samples before the spike peak. - sparsity_mask : np.ndarray, optional + sparsity_mask : np.ndarray, optional (default=None) Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. - channel_ids : np.ndarray, optional + channel_ids : np.ndarray, optional (default=None) Array of channel IDs. If `None`, defaults to an array of increasing integers. - unit_ids : np.ndarray, optional + unit_ids : np.ndarray, optional (default=None) Array of unit IDs. If `None`, defaults to an array of increasing integers. + check_for_consistent_sparsity : bool, optional (default=True) + When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the + structure fo the sparsity_masl. + + The following attributes are avaialble after construction: + + Attributes + ---------- num_units : int Number of units in the templates. Automatically determined from `templates_array`. 
num_samples : int @@ -49,7 +57,7 @@ class Templates: channel_ids: np.ndarray = None unit_ids: np.ndarray = None - check_template_array_and_sparsity_mask_are_consistentency: bool = True + check_for_consistent_sparsity: bool = True num_units: int = field(init=False) num_samples: int = field(init=False) @@ -85,7 +93,7 @@ def __post_init__(self): ) # Test that the templates are sparse if a sparsity mask is passed - if self.check_template_array_and_sparsity_mask_are_consistentency: + if self.check_for_consistent_sparsity: if not self._are_passed_templates_sparse(): raise ValueError("Sparsity mask passed but the templates are not sparse") From a1e6eaec457a55c6043bedcc1349176aa4e57f0c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 10:34:21 +0200 Subject: [PATCH 23/80] remove tests for get_sparse_templates --- .../core/tests/test_template_class.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index b1244ab0d1..40bb3f2b34 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -81,23 +81,6 @@ def test_get_dense_templates(template_type): assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) -@pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_get_sparse_templates(template_type): - template = generate_test_template(template_type) - - if template_type == "dense": - with pytest.raises(ValueError): - sparse_templates = template.get_sparse_templates() - elif template_type == "sparse": - sparse_templates = template.get_sparse_templates() - assert sparse_templates.shape == ( - template.num_units, - template.num_samples, - template.sparsity.max_num_active_channels, - ) - assert template.are_templates_sparse() - - def test_initialization_fail_with_dense_templates(): with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"): template = generate_test_template(template_type="sparse_with_dense_templates") From aa08f1ba85aea5eec5835de04618061aa9df244b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 12:18:09 +0200 Subject: [PATCH 24/80] Update src/spikeinterface/core/template.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 909d47acfc..8ebd0a75f5 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -28,7 +28,7 @@ class Templates: When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the structure fo the sparsity_masl. 
- The following attributes are avaialble after construction: + The following attributes are available after construction: Attributes ---------- From ea2a8a03c43b3ca444e02e50b837b1d6b7a51bd9 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 12:18:32 +0200 Subject: [PATCH 25/80] Update src/spikeinterface/core/template.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 8ebd0a75f5..e6556a68f7 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -9,7 +9,7 @@ class Templates: """ A class to represent spike templates, which can be either dense or sparse. - It is constructed with the following parameters: + Parameters ---------- templates_array : np.ndarray Array containing the templates data. From 1be2ce144e6f7f574a278dcf7775f3a11090a2a3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 30 Oct 2023 18:57:51 +0100 Subject: [PATCH 26/80] Add a minimum distance in generate_unit_locations. --- src/spikeinterface/core/generate.py | 46 ++++++++++++++++--- .../core/tests/test_generate.py | 36 ++++++++++++++- 2 files changed, 74 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 44ea02d32c..003b9cb5b5 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1333,15 +1333,49 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): return channel_locations -def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, seed=None): +def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, + minimum_distance=20., max_iteration=100, distance_strict=False, seed=None): rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") - for dim in (0, 1): - lim0 = np.min(channel_locations[:, dim]) - margin_um - lim1 = np.max(channel_locations[:, dim]) + margin_um - units_locations[:, dim] = rng.uniform(lim0, lim1, size=num_units) + + minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um + minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um + + units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units) + units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units) units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units) + if minimum_distance is not None: + solution_found = False + renew_inds = None + for i in range(max_iteration): + distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2) + inds0, inds1 = np.nonzero(distances < minimum_distance) + mask = inds0 != inds1 + inds0 = inds0[mask] + inds1 = inds1[mask] + + if inds0.size > 0: + if renew_inds is None: + renew_inds = np.unique(inds0) + else: + # random only bad ones in the previous set + renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] + + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) + units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) + units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) 
+ else: + solution_found = True + break + + if not solution_found: + if distance_strict: + raise ValueError(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " + "You can use distance_strict=False or reduce minimum distance") + else: + warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") + return units_locations @@ -1366,7 +1400,7 @@ def generate_ground_truth_recording( upsample_vector=None, generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), - generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), + generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), generate_templates_kwargs=dict(), dtype="float32", seed=None, diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 9a9c61766f..582120ac51 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -4,6 +4,8 @@ import numpy as np from spikeinterface.core import load_extractor, extract_waveforms + +from probeinterface import generate_multi_columns_probe from spikeinterface.core.generate import ( generate_recording, generate_sorting, @@ -289,6 +291,35 @@ def test_generate_single_fake_waveform(): # plt.show() +def test_generate_unit_locations(): + + seed = 0 + + probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) + channel_locations = probe.contact_positions + + num_units = 100 + minimum_distance = 20. + unit_locations = generate_unit_locations(num_units, channel_locations, + margin_um=20.0, minimum_z=5.0, maximum_z=40.0, + minimum_distance=minimum_distance, max_iteration=500, + distance_strict=False, seed=seed) + distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) + dist_flat = np.triu(distances, k=1).flatten() + dist_flat = dist_flat[dist_flat>0] + assert np.all(dist_flat > minimum_distance) + + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.hist(dist_flat, bins = np.arange(0, 400, 10)) + # fig, ax = plt.subplots() + # from probeinterface.plotting import plot_probe + # plot_probe(probe, ax=ax) + # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20) + # plt.show() + + def test_generate_templates(): seed = 0 @@ -297,7 +328,7 @@ def test_generate_templates(): num_units = 10 margin_um = 15.0 channel_locations = generate_channel_locations(num_chans, num_columns, 20.0) - unit_locations = generate_unit_locations(num_units, channel_locations, margin_um, seed) + unit_locations = generate_unit_locations(num_units, channel_locations, margin_um=margin_um, seed=seed) sampling_frequency = 30000.0 ms_before = 1.0 @@ -436,7 +467,8 @@ def test_generate_ground_truth_recording(): # test_noise_generator_consistency_after_dump(strategy, None) # test_generate_recording() # test_generate_single_fake_waveform() + test_generate_unit_locations() # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() - test_generate_sorting_with_spikes_on_borders() + # test_generate_sorting_with_spikes_on_borders() From 4f8bd73fd08798671fd252ad9a23ea40e12ef7e3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 2 Nov 2023 11:08:11 +0100 Subject: [PATCH 27/80] Update src/spikeinterface/core/template.py Co-authored-by: Zach McKenzie 
<92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index e6556a68f7..8beb6b46b1 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -17,7 +17,7 @@ class Templates: Sampling frequency of the templates. nbefore : int Number of samples before the spike peak. - sparsity_mask : np.ndarray, optional (default=None) + sparsity_mask : np.ndarray or None, default: None Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. channel_ids : np.ndarray, optional (default=None) From e52dd28d7fbc512e98bdbb28f4db32930b1cd73f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 2 Nov 2023 11:09:00 +0100 Subject: [PATCH 28/80] docstring compliance --- src/spikeinterface/core/template.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 8beb6b46b1..e6372c7082 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -20,11 +20,11 @@ class Templates: sparsity_mask : np.ndarray or None, default: None Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. - channel_ids : np.ndarray, optional (default=None) + channel_ids : np.ndarray, optional default: None Array of channel IDs. If `None`, defaults to an array of increasing integers. - unit_ids : np.ndarray, optional (default=None) + unit_ids : np.ndarray, optional default: None Array of unit IDs. If `None`, defaults to an array of increasing integers. - check_for_consistent_sparsity : bool, optional (default=True) + check_for_consistent_sparsity : bool, optional default: None When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the structure fo the sparsity_masl. From 5c2cb9b45f7c6a3cd4a308214e289f5673067101 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 2 Nov 2023 15:51:22 +0100 Subject: [PATCH 29/80] use python methods instead of parsing --- src/spikeinterface/core/base.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index b51bace55f..b737358eef 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -366,27 +366,24 @@ def to_dict( new_kwargs[name] = transform_extractors_to_dict(value) kwargs = new_kwargs - class_name = str(type(self)).replace("", "") + + module_import_path = self.__class__.__module__ + class_name_no_path = self.__class__.__name__ + class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 
'spikeinterface.core.generate.AClass' module = class_name.split(".")[0] - imported_module = importlib.import_module(module) - try: - version = imported_module.__version__ - except AttributeError: - version = "unknown" + imported_module = importlib.import_module(module) + spike_interface_version = getattr(imported_module, "__version__", "unknown") dump_dict = { "class": class_name, "module": module, "kwargs": kwargs, - "version": version, + "version": spike_interface_version, "relative_paths": (relative_to is not None), } - try: - dump_dict["version"] = imported_module.__version__ - except AttributeError: - dump_dict["version"] = "unknown" + dump_dict["version"] = spike_interface_version if include_annotations: dump_dict["annotations"] = self._annotations @@ -805,7 +802,7 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): * explicit sub-folder, implicit base-folder : `extractor.save(name="extarctor_name")` * generated: `extractor.save()` - The second option saves to subfolder "extarctor_name" in + The second option saves to subfolder "extractor_name" in "get_global_tmp_folder()". You can set the global tmp folder with: "set_global_tmp_folder("path-to-global-folder")" From 3015725a24dfba872c2448409ce9585dbcc392ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Thu, 2 Nov 2023 16:56:33 +0100 Subject: [PATCH 30/80] `WaveformExtractor.is_extension` --> `has_extension` --- src/spikeinterface/core/waveform_extractor.py | 17 +++++++++++------ src/spikeinterface/exporters/report.py | 8 ++++---- src/spikeinterface/exporters/to_phy.py | 8 ++++---- .../postprocessing/principal_component.py | 2 +- .../tests/common_extension_tests.py | 2 +- .../tests/test_principal_component.py | 2 +- .../qualitymetrics/misc_metrics.py | 10 +++++----- .../qualitymetrics/quality_metric_calculator.py | 8 ++++---- .../tests/test_quality_metric_calculator.py | 2 +- src/spikeinterface/widgets/base.py | 2 +- src/spikeinterface/widgets/unit_summary.py | 10 +++++----- 11 files changed, 38 insertions(+), 33 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index c97a727340..491fdd5ebe 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -506,7 +506,7 @@ def get_recording_property(self, key) -> np.ndarray: def get_sorting_property(self, key) -> np.ndarray: return self.sorting.get_property(key) - def get_extension_class(self, extension_name): + def get_extension_class(self, extension_name: str): """ Get extension class from name and check if registered. @@ -525,7 +525,7 @@ def get_extension_class(self, extension_name): ext_class = extensions_dict[extension_name] return ext_class - def is_extension(self, extension_name) -> bool: + def has_extension(self, extension_name: str) -> bool: """ Check if the extension exists in memory or in the folder. @@ -556,7 +556,12 @@ def is_extension(self, extension_name) -> bool: and "params" in self._waveforms_root[extension_name].attrs.keys() ) - def load_extension(self, extension_name): + def is_extension(self, extension_name) -> bool: + warn("WaveformExtractor.is_extension is deprecated! Use `has_extension` instead.") + assert False + return self.has_extension(extension_name) + + def load_extension(self, extension_name: str): """ Load an extension from its name. The module of the extension must be loaded and registered. 
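Downstream callers switch from is_extension to the new name; a short sketch of the intended usage (assuming `we` is an existing WaveformExtractor with the extension already computed, mirroring the exporter code changed in this same patch):

if we.has_extension("spike_amplitudes"):
    amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit")

# The old spelling keeps working for now but emits a deprecation warning:
we.is_extension("spike_amplitudes")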
@@ -572,7 +577,7 @@ def load_extension(self, extension_name): The loaded instance of the extension """ if self.folder is not None and extension_name not in self._loaded_extensions: - if self.is_extension(extension_name): + if self.has_extension(extension_name): ext_class = self.get_extension_class(extension_name) ext = ext_class.load(self.folder, self) if extension_name not in self._loaded_extensions: @@ -588,7 +593,7 @@ def delete_extension(self, extension_name) -> None: extension_name: str The extension name. """ - assert self.is_extension(extension_name), f"The extension {extension_name} is not available" + assert self.has_extension(extension_name), f"The extension {extension_name} is not available" del self._loaded_extensions[extension_name] if self.folder is not None and (self.folder / extension_name).is_dir(): shutil.rmtree(self.folder / extension_name) @@ -610,7 +615,7 @@ def get_available_extension_names(self): """ extension_names_in_folder = [] for extension_class in self.extensions: - if self.is_extension(extension_class.extension_name): + if self.has_extension(extension_class.extension_name): extension_names_in_folder.append(extension_class.extension_name) return extension_names_in_folder diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 57a5ab0166..8b14930859 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -51,7 +51,7 @@ def export_report( unit_ids = sorting.unit_ids # load or compute spike_amplitudes - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): spike_amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit") elif force_computation: spike_amplitudes = compute_spike_amplitudes(we, peak_sign=peak_sign, outputs="by_unit", **job_kwargs) @@ -62,7 +62,7 @@ def export_report( ) # load or compute quality_metrics - if we.is_extension("quality_metrics"): + if we.has_extension("quality_metrics"): metrics = we.load_extension("quality_metrics").get_data() elif force_computation: metrics = compute_quality_metrics(we) @@ -73,7 +73,7 @@ def export_report( ) # load or compute correlograms - if we.is_extension("correlograms"): + if we.has_extension("correlograms"): correlograms, bins = we.load_extension("correlograms").get_data() elif force_computation: correlograms, bins = compute_correlograms(we, window_ms=100.0, bin_ms=1.0) @@ -84,7 +84,7 @@ def export_report( ) # pre-compute unit locations if not done - if not we.is_extension("unit_locations"): + if not we.has_extension("unit_locations"): unit_locations = compute_unit_locations(we) output_folder = Path(output_folder).absolute() diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ecc5b316ec..607aa3e846 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -196,7 +196,7 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - if waveform_extractor.is_extension("similarity"): + if waveform_extractor.has_extension("similarity"): tmc = waveform_extractor.load_extension("similarity") template_similarity = tmc.get_data() else: @@ -219,7 +219,7 @@ def export_to_phy( np.save(str(output_folder / "channel_groups.npy"), channel_groups) if compute_amplitudes: - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): sac = 
waveform_extractor.load_extension("spike_amplitudes") amplitudes = sac.get_data(outputs="concatenated") else: @@ -231,7 +231,7 @@ def export_to_phy( np.save(str(output_folder / "amplitudes.npy"), amplitudes) if compute_pc_features: - if waveform_extractor.is_extension("principal_components"): + if waveform_extractor.has_extension("principal_components"): pc = waveform_extractor.load_extension("principal_components") else: pc = compute_principal_components( @@ -264,7 +264,7 @@ def export_to_phy( channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups}) channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False) - if waveform_extractor.is_extension("quality_metrics"): + if waveform_extractor.has_extension("quality_metrics"): qm = waveform_extractor.load_extension("quality_metrics") qm_data = qm.get_data() for column_name in qm_data.columns: diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index cf32e79b25..effd87007f 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -750,7 +750,7 @@ def compute_principal_components( >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ - if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name): + if load_if_exists and waveform_extractor.has_extension(WaveformPrincipalComponent.extension_name): pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name) else: pc = WaveformPrincipalComponent.create(waveform_extractor) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b539bbd5d4..2bef246bc2 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -143,7 +143,7 @@ def _test_extension_folder(self, we, in_memory=False): # reload as an extension from we assert self.extension_class.extension_name in we.get_available_extension_names() - assert we.is_extension(self.extension_class.extension_name) + assert we.has_extension(self.extension_class.extension_name) ext = we.load_extension(self.extension_class.extension_name) assert isinstance(ext, self.extension_class) for ext_name in self.extension_data_names: diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 49591d9b89..f5e315b18f 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -135,7 +135,7 @@ def test_project_new(self): from sklearn.decomposition import IncrementalPCA we = self.we1 - if we.is_extension("principal_components"): + if we.has_extension("principal_components"): we.delete_extension("principal_components") we_cp = we.select_units(we.unit_ids, self.cache_folder / "toy_waveforms_1seg_cp") diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 5c734b9100..9dab06124b 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -201,7 +201,7 @@ def compute_snrs( snrs : dict Computed signal to noise ratio for each unit. 
""" - if waveform_extractor.is_extension("noise_levels"): + if waveform_extractor.has_extension("noise_levels"): noise_levels = waveform_extractor.load_extension("noise_levels").get_data() else: if random_chunk_kwargs_dict is None: @@ -687,7 +687,7 @@ def compute_amplitude_cv_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.is_extension(amplitude_extension): + if waveform_extractor.has_extension(amplitude_extension): sac = waveform_extractor.load_extension(amplitude_extension) amps = sac.get_data(outputs="concatenated") if amplitude_extension == "spike_amplitudes": @@ -803,7 +803,7 @@ def compute_amplitude_cutoffs( spike_amplitudes = None invert_amplitudes = False - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): amp_calculator = waveform_extractor.load_extension("spike_amplitudes") spike_amplitudes = amp_calculator.get_data(outputs="by_unit") if amp_calculator._params["peak_sign"] == "pos": @@ -881,7 +881,7 @@ def compute_amplitude_medians(waveform_extractor, peak_sign="neg", unit_ids=None extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign) spike_amplitudes = None - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): amp_calculator = waveform_extractor.load_extension("spike_amplitudes") spike_amplitudes = amp_calculator.get_data(outputs="by_unit") @@ -974,7 +974,7 @@ def compute_drift_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.is_extension("spike_locations"): + if waveform_extractor.has_extension("spike_locations"): locs_calculator = waveform_extractor.load_extension("spike_locations") spike_locations = locs_calculator.get_data(outputs="concatenated") spike_locations_by_unit = locs_calculator.get_data(outputs="by_unit") diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 53309db282..90a1f7206e 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -42,14 +42,14 @@ def _set_params( if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list - if self.waveform_extractor.is_extension("principal_components"): + if self.waveform_extractor.has_extension("principal_components"): # by default 'nearest_neightbor' is removed because too slow pc_metrics = _possible_pc_metric_names.copy() pc_metrics.remove("nn_isolation") pc_metrics.remove("nn_noise_overlap") metric_names += pc_metrics # if spike_locations are not available, drift is removed from the list - if not self.waveform_extractor.is_extension("spike_locations"): + if not self.waveform_extractor.has_extension("spike_locations"): if "drift" in metric_names: metric_names.remove("drift") @@ -130,7 +130,7 @@ def _run(self, verbose, **job_kwargs): # metrics based on PCs pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self._params["skip_pc_metrics"]: - if not self.waveform_extractor.is_extension("principal_components"): + if not self.waveform_extractor.has_extension("principal_components"): raise ValueError("waveform_principal_component must be provied") pc_extension = self.waveform_extractor.load_extension("principal_components") pc_metrics = calculate_pc_metrics( @@ 
-216,7 +216,7 @@ def compute_quality_metrics( metrics: pandas.DataFrame Data frame with the computed metrics """ - if load_if_exists and waveform_extractor.is_extension(QualityMetricCalculator.extension_name): + if load_if_exists and waveform_extractor.has_extension(QualityMetricCalculator.extension_name): qmc = waveform_extractor.load_extension(QualityMetricCalculator.extension_name) else: qmc = QualityMetricCalculator(waveform_extractor) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index eb8317e4df..b601e5d6d8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -88,7 +88,7 @@ def test_metrics(self): we = self.we_long # avoid NaNs - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): we.delete_extension("spike_amplitudes") # without PC diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index a5d3cb2429..6ff837065b 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -107,7 +107,7 @@ def check_extensions(waveform_extractor, extensions): error_msg = "" raise_error = False for extension in extensions: - if not waveform_extractor.is_extension(extension): + if not waveform_extractor.has_extension(extension): raise_error = True error_msg += ( f"The {extension} waveform extension is required for this widget. " diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 35fde07326..aa280ad658 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -80,13 +80,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 3 - if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"): + if we.has_extension("correlograms") or we.has_extension("spike_amplitudes"): ncols += 1 - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) - if we.is_extension("unit_locations"): + if we.has_extension("unit_locations"): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( @@ -129,7 +129,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) ax3.set_ylabel(None) - if we.is_extension("correlograms"): + if we.has_extension("correlograms"): ax4 = fig.add_subplot(gs[:2, 3]) AutoCorrelogramsWidget( we, @@ -142,7 +142,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_title(None) ax4.set_yticks([]) - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): ax5 = fig.add_subplot(gs[2, :3]) ax6 = fig.add_subplot(gs[2, 3]) axes = np.array([ax5, ax6]) From 89387f006812487f3b12f8bbb1c6d42e7cab07a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Thu, 2 Nov 2023 17:34:24 +0100 Subject: [PATCH 31/80] Removing assert --- src/spikeinterface/core/waveform_extractor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 491fdd5ebe..1453b4156f 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -558,7 +558,6 @@ def has_extension(self, extension_name: str) -> bool: def 
is_extension(self, extension_name) -> bool: warn("WaveformExtractor.is_extension is deprecated! Use `has_extension` instead.") - assert False return self.has_extension(extension_name) def load_extension(self, extension_name: str): From 9fecf89ab6e2a3669310e2bfad49b8b6955c6036 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 3 Nov 2023 12:00:20 +0100 Subject: [PATCH 32/80] (test) fix maxwell tests --- .github/workflows/full-test-with-codecov.yml | 2 ++ .github/workflows/full-test.yml | 2 ++ src/spikeinterface/extractors/tests/test_neoextractors.py | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index a5561c2ffc..08e1ee6e1a 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -52,6 +52,8 @@ jobs: - name: Shows installed packages by pip, git-annex and cached testing files uses: ./.github/actions/show-test-environment - name: run tests + env: + HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell run: | source ${{ github.workspace }}/test_env/bin/activate pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index dad42e021b..2f5fa02a0f 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -132,6 +132,8 @@ jobs: - name: Test core run: ./.github/run_tests.sh core - name: Test extractors + env: + HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell if: ${{ steps.modules-changed.outputs.EXTRACTORS_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh "extractors and not streaming_extractors" - name: Test preprocessing diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 64c6499767..f52d3d52bc 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -278,7 +278,7 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] -@pytest.mark.skipif(ON_GITHUB, reason="Maxwell plugin not installed on GitHub") +# @pytest.mark.skipif(ON_GITHUB, reason="Maxwell plugin not installed on GitHub") class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MaxwellRecordingExtractor downloads = ["maxwell"] From 634537e58ea3bb742984f2ab8f2a552cfa88334e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 12:02:20 +0100 Subject: [PATCH 33/80] improve docstring --- src/spikeinterface/core/base.py | 73 +++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index b737358eef..4a05c8fb1a 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -316,32 +316,63 @@ def to_dict( recursive: bool = False, ) -> dict: """ - Make a nested serialized dictionary out of the extractor. The dictionary produced can be used to re-initialize - an extractor using load_extractor_from_dict(dump_dict) + Construct a nested dictionary representation of the extractor. + + This method facilitates the serialization of the extractor instance by converting it + to a dictionary. 
The resulting dictionary can be used to re-initialize the extractor + through the `load_extractor_from_dict` function. + + Examples + -------- + >>> dump_dict = original_extractor.to_dict() + >>> reloaded_extractor = load_extractor_from_dict(dump_dict) Parameters ---------- - include_annotations: bool, default: False - If True, all annotations are added to the dict - include_properties: bool, default: False - If True, all properties are added to the dict - relative_to: str, Path, or None, default: None - If not None, files and folders are serialized relative to this path - Used in waveform extractor to maintain relative paths to binary files even if the - containing folder / diretory is moved - folder_metadata: str, Path, or None - Folder with numpy `npy` files containing additional information (e.g. probe in BaseRecording) and properties. - recursive: bool, default: False - If True, all dicitionaries in the kwargs are expanded with `to_dict` as well + include_annotations : bool, optional + Whether to include all annotations in the dictionary, by default False. + include_properties : bool, optional + Whether to include all properties in the dictionary, by default False. + relative_to : Union[str, Path, None], optional + If provided, file and folder paths will be made relative to this path, + enabling portability in folder formats such as the waveform extractor, + by default None. + folder_metadata : Union[str, Path, None], optional + Path to a folder containing additional metadata files (e.g., probe information in BaseRecording) + in numpy `npy` format, by default None. + recursive : bool, optional + If True, recursively apply `to_dict` to dictionaries within the kwargs, by default False. + + Raises + ------ + ValueError + If `relative_to` is specified while `recursive` is False. Returns ------- - dump_dict: dict - A dictionary representation of the extractor. + dict + A dictionary representation of the extractor, with the following structure: + { + "class": , + "module": , (e.g. 'spikeinterface'), + "kwargs": , + "version": , + "relative_paths": , + "annotations": , + "properties": , + "folder_metadata": + } + + Notes + ----- + - The `relative_to` argument only has an effect if `recursive` is set to True. + - The `folder_metadata` argument will be made relative to `relative_to` if both are specified. + - The `version` field in the resulting dictionary reflects the version of the module + from which the extractor class originates. + - The full class attribute above is the full import of the class, e.g. + 'spikeinterface.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor' """ - kwargs = self._kwargs - if relative_to and not recursive: raise ValueError("`relative_to` is only possible when `recursive=True`") @@ -373,17 +404,17 @@ def to_dict( module = class_name.split(".")[0] imported_module = importlib.import_module(module) - spike_interface_version = getattr(imported_module, "__version__", "unknown") + module_version = getattr(imported_module, "__version__", "unknown") dump_dict = { "class": class_name, "module": module, "kwargs": kwargs, - "version": spike_interface_version, + "version": module_version, "relative_paths": (relative_to is not None), } - dump_dict["version"] = spike_interface_version + dump_dict["version"] = module_version # Can be spikeinterface, spikefores, etc. 
if include_annotations: dump_dict["annotations"] = self._annotations From 0544c0ac0888d5a60c6f43ac5030eb4d3d8cbbfa Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 12:03:35 +0100 Subject: [PATCH 34/80] improve docstring --- src/spikeinterface/core/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 4a05c8fb1a..867184e63b 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -371,6 +371,8 @@ def to_dict( from which the extractor class originates. - The full class attribute above is the full import of the class, e.g. 'spikeinterface.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor' + - The module is usually 'spikeinterface', but can be different for custom extractors such as those of + SpikeForest or any other project that inherits the Extractor class from spikeinterface. """ if relative_to and not recursive: From 8b68729bf403700825efc7c842252d9cd5d210ec Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 12:05:41 +0100 Subject: [PATCH 35/80] alessio default arguments style --- src/spikeinterface/core/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 867184e63b..a229b31b14 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -329,18 +329,18 @@ def to_dict( Parameters ---------- - include_annotations : bool, optional + include_annotations : bool, default: False Whether to include all annotations in the dictionary, by default False. - include_properties : bool, optional + include_properties : bool, default: False Whether to include all properties in the dictionary, by default False. - relative_to : Union[str, Path, None], optional + relative_to : Union[str, Path, None], default: None If provided, file and folder paths will be made relative to this path, enabling portability in folder formats such as the waveform extractor, by default None. - folder_metadata : Union[str, Path, None], optional + folder_metadata : Union[str, Path, None], default: None Path to a folder containing additional metadata files (e.g., probe information in BaseRecording) in numpy `npy` format, by default None. - recursive : bool, optional + recursive : bool, default: False If True, recursively apply `to_dict` to dictionaries within the kwargs, by default False. 
Raises From 6fd40a841c8b6b26d801ab5976ef51e1c4a19052 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 3 Nov 2023 12:10:39 +0100 Subject: [PATCH 36/80] Remove comment --- src/spikeinterface/extractors/tests/test_neoextractors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index f52d3d52bc..14f94eb20b 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -278,7 +278,6 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] -# @pytest.mark.skipif(ON_GITHUB, reason="Maxwell plugin not installed on GitHub") class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MaxwellRecordingExtractor downloads = ["maxwell"] From 242b1a0ce458acd371a4e8b285f4f88ed994f214 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 12:11:02 +0100 Subject: [PATCH 37/80] clean_refractory_period --- src/spikeinterface/core/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index a229b31b14..d2d947d299 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -378,6 +378,7 @@ def to_dict( if relative_to and not recursive: raise ValueError("`relative_to` is only possible when `recursive=True`") + kwargs = self._kwargs if recursive: to_dict_kwargs = dict( include_annotations=include_annotations, From fc6f03c19c774e6941b9010a28f80b5efcf25c5a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 13:03:31 +0100 Subject: [PATCH 38/80] Update src/spikeinterface/core/base.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index d2d947d299..2347594119 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -330,7 +330,7 @@ def to_dict( Parameters ---------- include_annotations : bool, default: False - Whether to include all annotations in the dictionary, by default False. + Whether to include all annotations in the dictionary include_properties : bool, default: False Whether to include all properties in the dictionary, by default False. 
relative_to : Union[str, Path, None], default: None From e96b5a0d1cc39c6a2357657495dad323a6732ad1 Mon Sep 17 00:00:00 2001 From: h-mayorquin Date: Sun, 5 Nov 2023 10:24:52 +0100 Subject: [PATCH 39/80] add tests to stream --- .../extractors/tests/test_nwb_s3_extractor.py | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py index 71a19f30d3..e7d148fd94 100644 --- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py +++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py @@ -1,9 +1,11 @@ from pathlib import Path +import pickle import pytest import numpy as np import h5py +from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor if hasattr(pytest, "global_test_folder"): @@ -15,7 +17,7 @@ @pytest.mark.ros3_test @pytest.mark.streaming_extractors @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_recording_s3_nwb_ros3(): +def test_recording_s3_nwb_ros3(tmp_path): file_path = ( "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) @@ -40,9 +42,18 @@ def test_recording_s3_nwb_ros3(): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" + tmp_file = tmp_path / "test_ros3_recording.pkl" + with open(tmp_file, 'wb') as f: + pickle.dump(rec, f) + + with open(tmp_file, 'rb') as f: + reloaded_recording = pickle.load(f) + + check_recordings_equal(rec, reloaded_recording) + @pytest.mark.streaming_extractors -def test_recording_s3_nwb_fsspec(): +def test_recording_s3_nwb_fsspec(tmp_path): file_path = ( "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) @@ -68,10 +79,21 @@ def test_recording_s3_nwb_fsspec(): assert trace_scaled.dtype == "float32" + tmp_file = tmp_path / "test_fsspec_recording.pkl" + with open(tmp_file, 'wb') as f: + pickle.dump(rec, f) + + with open(tmp_file, 'rb') as f: + reloaded_recording = pickle.load(f) + + check_recordings_equal(rec, reloaded_recording) + + + @pytest.mark.ros3_test @pytest.mark.streaming_extractors @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_sorting_s3_nwb_ros3(): +def test_sorting_s3_nwb_ros3(tmp_path): file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" # we provide the 'sampling_frequency' because the NWB file does not the electrical series sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3") @@ -90,9 +112,17 @@ def test_sorting_s3_nwb_ros3(): assert spike_train.dtype == "int64" assert np.all(spike_train >= 0) + tmp_file = tmp_path / "test_ros3_sorting.pkl" + with open(tmp_file, 'wb') as f: + pickle.dump(sort, f) + + with open(tmp_file, 'rb') as f: + reloaded_sorting = pickle.load(f) + + check_sortings_equal(reloaded_sorting, sort) @pytest.mark.streaming_extractors -def test_sorting_s3_nwb_fsspec(): +def test_sorting_s3_nwb_fsspec(tmp_path): file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" # we provide the 'sampling_frequency' because the NWB file does not the electrical series sort = NwbSortingExtractor( @@ -113,6 +143,15 @@ def 
test_sorting_s3_nwb_fsspec(): assert spike_train.dtype == "int64" assert np.all(spike_train >= 0) + tmp_file = tmp_path / "test_fsspec_sorting.pkl" + with open(tmp_file, 'wb') as f: + pickle.dump(sort, f) + + with open(tmp_file, 'rb') as f: + reloaded_sorting = pickle.load(f) + + check_sortings_equal(reloaded_sorting, sort) + if __name__ == "__main__": test_recording_s3_nwb_ros3() From 6c863c3dc81920b12b2ffadab76913ef6c0888d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Nov 2023 10:26:11 +0000 Subject: [PATCH 40/80] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/base.py | 2 +- .../extractors/tests/test_nwb_s3_extractor.py | 35 +++++++++---------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 88b7f12783..1a8674697a 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -27,7 +27,7 @@ class BaseExtractor: """ - default_missing_property_values = {'f': np.nan, "O": None, 'S': "", "U": ""} + default_missing_property_values = {"f": np.nan, "O": None, "S": "", "U": ""} # This replaces the old key_properties # These are annotations/properties that always need to be diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py index e7d148fd94..253ca2e4ce 100644 --- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py +++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py @@ -43,12 +43,12 @@ def test_recording_s3_nwb_ros3(tmp_path): assert trace_scaled.dtype == "float32" tmp_file = tmp_path / "test_ros3_recording.pkl" - with open(tmp_file, 'wb') as f: + with open(tmp_file, "wb") as f: pickle.dump(rec, f) - - with open(tmp_file, 'rb') as f: + + with open(tmp_file, "rb") as f: reloaded_recording = pickle.load(f) - + check_recordings_equal(rec, reloaded_recording) @@ -78,16 +78,14 @@ def test_recording_s3_nwb_fsspec(tmp_path): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" - tmp_file = tmp_path / "test_fsspec_recording.pkl" - with open(tmp_file, 'wb') as f: + with open(tmp_file, "wb") as f: pickle.dump(rec, f) - - with open(tmp_file, 'rb') as f: + + with open(tmp_file, "rb") as f: reloaded_recording = pickle.load(f) - - check_recordings_equal(rec, reloaded_recording) + check_recordings_equal(rec, reloaded_recording) @pytest.mark.ros3_test @@ -113,14 +111,15 @@ def test_sorting_s3_nwb_ros3(tmp_path): assert np.all(spike_train >= 0) tmp_file = tmp_path / "test_ros3_sorting.pkl" - with open(tmp_file, 'wb') as f: + with open(tmp_file, "wb") as f: pickle.dump(sort, f) - - with open(tmp_file, 'rb') as f: + + with open(tmp_file, "rb") as f: reloaded_sorting = pickle.load(f) - + check_sortings_equal(reloaded_sorting, sort) + @pytest.mark.streaming_extractors def test_sorting_s3_nwb_fsspec(tmp_path): file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" @@ -144,12 +143,12 @@ def test_sorting_s3_nwb_fsspec(tmp_path): assert np.all(spike_train >= 0) tmp_file = tmp_path / "test_fsspec_sorting.pkl" - with open(tmp_file, 'wb') as f: + with open(tmp_file, "wb") as f: pickle.dump(sort, f) - - with open(tmp_file, 'rb') as f: + + with open(tmp_file, "rb") as f: reloaded_sorting = pickle.load(f) - + 
check_sortings_equal(reloaded_sorting, sort) From 85f9b050bdeec8405d05798ff8605a7ebbd5c5d0 Mon Sep 17 00:00:00 2001 From: h-mayorquin Date: Sun, 5 Nov 2023 10:46:04 +0100 Subject: [PATCH 41/80] trigger tests --- src/spikeinterface/extractors/nwbextractors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index f7b445cdb9..010b22975c 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -566,12 +566,12 @@ def get_unit_spike_train( start_frame = 0 if end_frame is None: end_frame = np.inf - times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] + spike_times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] if self._timestamps is not None: - frames = np.searchsorted(times, self.timestamps).astype("int64") + frames = np.searchsorted(spike_times, self.timestamps).astype("int64") else: - frames = np.round(times * self._sampling_frequency).astype("int64") + frames = np.round(spike_times * self._sampling_frequency).astype("int64") return frames[(frames >= start_frame) & (frames < end_frame)] From 2b820bf7244f5d8d4b43c87df03d23dfe5828922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 6 Nov 2023 12:30:00 +0100 Subject: [PATCH 42/80] Improvement to `get_empty_units()` --- src/spikeinterface/core/basesorting.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 94b08d8cc3..319fc7cb12 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -364,16 +364,12 @@ def remove_empty_units(self): return self.select_units(non_empty_units) def get_non_empty_unit_ids(self): - non_empty_units = [] - for segment_index in range(self.get_num_segments()): - for unit in self.get_unit_ids(): - if len(self.get_unit_spike_train(unit, segment_index=segment_index)) > 0: - non_empty_units.append(unit) - non_empty_units = np.unique(non_empty_units) - return non_empty_units + num_spikes_per_unit = self.count_num_spikes_per_unit() + + return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] == 0]) def get_empty_unit_ids(self): - unit_ids = self.get_unit_ids() + unit_ids = self.unit_ids empty_units = unit_ids[~np.isin(unit_ids, self.get_non_empty_unit_ids())] return empty_units From 23d2d89858b70cd046f8a3a043b93a11ff2234cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 6 Nov 2023 12:37:43 +0100 Subject: [PATCH 43/80] oops --- src/spikeinterface/core/basesorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 319fc7cb12..87507f799e 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -366,7 +366,7 @@ def remove_empty_units(self): def get_non_empty_unit_ids(self): num_spikes_per_unit = self.count_num_spikes_per_unit() - return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] == 0]) + return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] != 0]) def get_empty_unit_ids(self): unit_ids = self.unit_ids From 2253a45a9dff06cb157600a453bbb235fadc489c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: 
Mon, 6 Nov 2023 13:27:45 +0100 Subject: [PATCH 44/80] Update src/spikeinterface/core/waveform_extractor.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/waveform_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 1453b4156f..cc17aea62f 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -557,7 +557,7 @@ def has_extension(self, extension_name: str) -> bool: ) def is_extension(self, extension_name) -> bool: - warn("WaveformExtractor.is_extension is deprecated! Use `has_extension` instead.") + warn("WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.") return self.has_extension(extension_name) def load_extension(self, extension_name: str): From 7be79220b12a9ba475adbe36c3341dd9cee32cc1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 12:29:30 +0000 Subject: [PATCH 45/80] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/waveform_extractor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index cc17aea62f..dd7c3043c3 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -557,7 +557,9 @@ def has_extension(self, extension_name: str) -> bool: ) def is_extension(self, extension_name) -> bool: - warn("WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.") + warn( + "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead." + ) return self.has_extension(extension_name) def load_extension(self, extension_name: str): From 3712da0ab5bfe22cf53598dc04c47c37ef554c02 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 6 Nov 2023 21:35:01 +0000 Subject: [PATCH 46/80] Handle start / stop frame default `None`. --- src/spikeinterface/sortingcomponents/motion_interpolation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index a81212897c..86485aa25d 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -387,6 +387,11 @@ def get_traces(self, start_frame, end_frame, channel_indices): ) # times = np.asarray(self.time_vector[start_frame:end_frame]) else: + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = self.get_num_samples() + times = np.arange((end_frame or self.get_num_samples()) - (start_frame or 0), dtype="float64") times /= self.sampling_frequency t0 = start_frame / self.sampling_frequency From 266be6f2861490bd3e8a3d5e3d3a9c49527f6f50 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 6 Nov 2023 21:36:50 +0000 Subject: [PATCH 47/80] Remove redundant `else` statement. 
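
The intent of this and the previous commit can be summarised with a small, self-contained sketch (the function name and values are illustrative only; they mirror the logic in the diff below rather than quote it):

    import numpy as np

    def resolve_frames(start_frame, end_frame, num_samples):
        # When callers pass None, fall back to the full extent of the segment,
        # then build the time vector from the resolved bounds.
        if start_frame is None:
            start_frame = 0
        if end_frame is None:
            end_frame = num_samples
        times = np.arange(end_frame - start_frame, dtype="float64")
        return start_frame, end_frame, times

    print(resolve_frames(None, None, 10)[:2])   # -> (0, 10)
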
--- .../sortingcomponents/motion_interpolation.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 86485aa25d..ba046db85f 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -386,18 +386,18 @@ def get_traces(self, start_frame, end_frame, channel_indices): "time_vector for InterpolateMotionRecording do not work because temporal_bins start from 0" ) # times = np.asarray(self.time_vector[start_frame:end_frame]) - else: - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - - times = np.arange((end_frame or self.get_num_samples()) - (start_frame or 0), dtype="float64") - times /= self.sampling_frequency - t0 = start_frame / self.sampling_frequency - # if self.t_start is not None: - # t0 = t0 + self.t_start - times += t0 + + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = self.get_num_samples() + + times = np.arange((end_frame or self.get_num_samples()) - (start_frame or 0), dtype="float64") + times /= self.sampling_frequency + t0 = start_frame / self.sampling_frequency + # if self.t_start is not None: + # t0 = t0 + self.t_start + times += t0 traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices=slice(None)) From a4ce54a212b7d5a946d125328b0f0760cfd8bd85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 8 Nov 2023 13:33:01 +0100 Subject: [PATCH 48/80] Better docstring for empty units --- src/spikeinterface/core/basesorting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 87507f799e..a43eaf6090 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -353,7 +353,8 @@ def remove_units(self, remove_unit_ids): def remove_empty_units(self): """ - Removes units with empty spike trains + Removes units with empty spike trains. + For multi-segments, a unit is considered empty if it contains no spikes in all segments. 
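A hedged illustration of that rule (the spike-count dictionary below is made up; in the code it comes from `count_num_spikes_per_unit()`, as in the previous commits):

    import numpy as np

    # Total spike counts per unit, summed over all segments.
    num_spikes_per_unit = {"u0": 12, "u1": 0, "u2": 3}

    empty_unit_ids = np.array([u for u, n in num_spikes_per_unit.items() if n == 0])
    non_empty_unit_ids = np.array([u for u, n in num_spikes_per_unit.items() if n != 0])
    print(empty_unit_ids)   # ['u1'] -> only units with zero spikes in every segment are dropped
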
Returns ------- From dcce25a397e7b365a1bd0de22032b32d48128a4c Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Wed, 8 Nov 2023 12:48:16 +0000 Subject: [PATCH 49/80] Update src/spikeinterface/sortingcomponents/motion_interpolation.py Co-authored-by: Alessio Buccino --- src/spikeinterface/sortingcomponents/motion_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index ba046db85f..93a8ce62c8 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -392,7 +392,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): if end_frame is None: end_frame = self.get_num_samples() - times = np.arange((end_frame or self.get_num_samples()) - (start_frame or 0), dtype="float64") + times = np.arange(end_frame - start_frame, dtype="float64") times /= self.sampling_frequency t0 = start_frame / self.sampling_frequency # if self.t_start is not None: From d2c14375770b1dec6f6b29546bb319726176dffa Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 8 Nov 2023 14:58:50 -0500 Subject: [PATCH 50/80] add read_npz_sorting function --- src/spikeinterface/extractors/extractorlist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index 235dd705dc..f8198c3d18 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -13,6 +13,7 @@ ZarrRecordingExtractor, read_binary, read_zarr, + read_npz_sorting, ) # sorting/recording/event from neo From 329197618a9b48ef876d1e8b8e79f07f4abf5e49 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Thu, 9 Nov 2023 10:33:00 +0100 Subject: [PATCH 51/80] Fix compute matching v3 (#2182) * some change to test * another change * another attempt * attempt merge * add condition * add auth * fix test and simpler implementation * small typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * avoid corner cose of doing the matching loop twice * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove n_jobs * Little docs cleanup * Remove internal n_jobs * Remove last internal n_jobs * Apply suggestions from code review * fix test * comment to test * docstring improvements * variable naming * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * new proposal for compute_matching * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Heberto Mayorquin Co-authored-by: Alessio Buccino --- .../comparison/comparisontools.py | 130 +++++++++--------- .../comparison/paircomparisons.py | 4 +- .../comparison/tests/test_comparisontools.py | 66 ++++++--- 3 files changed, 111 insertions(+), 89 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 7a1fb87175..3cd856d662 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -124,12 +124,12 @@ def 
get_optimized_compute_matching_matrix(): @numba.jit(nopython=True, nogil=True) def compute_matching_matrix( - frames_spike_train1, - frames_spike_train2, + spike_frames_train1, + spike_frames_train2, unit_indices1, unit_indices2, - num_units_sorting1, - num_units_sorting2, + num_units_train1, + num_units_train2, delta_frames, ): """ @@ -137,30 +137,33 @@ def compute_matching_matrix( Given two spike trains, this function finds matching spikes based on a temporal proximity criterion defined by `delta_frames`. The resulting matrix indicates the number of matches between units - in `frames_spike_train1` and `frames_spike_train2`. + in `spike_frames_train1` and `spike_frames_train2`. Parameters ---------- - frames_spike_train1 : ndarray - Array of frames for the first spike train. Should be ordered in ascending order. - frames_spike_train2 : ndarray - Array of frames for the second spike train. Should be ordered in ascending order. + spike_frames_train1 : ndarray + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + spike_frames_train2 : ndarray + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. unit_indices1 : ndarray - Array indicating the unit indices corresponding to each spike in `frames_spike_train1`. + An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. unit_indices2 : ndarray - Array indicating the unit indices corresponding to each spike in `frames_spike_train2`. - num_units_sorting1 : int - Total number of units in the first spike train. - num_units_sorting2 : int - Total number of units in the second spike train. + An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. + num_units_train1 : int + The total count of unique units in the first spike train. + num_units_train2 : int + The total count of unique units in the second spike train. delta_frames : int - Maximum difference in frames between two spikes to consider them as a match. + The inclusive upper limit on the frame difference for which two spikes are considered matching. That is + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` + and `spike_frames_train2[j]` are considered matching. Returns ------- matching_matrix : ndarray - A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents - the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`. + A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. + Notes ----- @@ -168,59 +171,58 @@ def compute_matching_matrix( By iterating through each spike in the first train, it compares them against spikes in the second train, determining matches based on the two spikes frames being within `delta_frames` of each other. - To avoid redundant comparisons the algorithm maintains a reference, `lower_search_limit_in_second_train`, + To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `, which signifies the minimal index in the second spike train that might match the upcoming spike - in the first train. 
This means that the start of the search moves forward in the second train as the - matches between the two trains are found decreasing the number of comparisons needed. + in the first train. + + The logic can be summarized as follows: + 1. Iterate through each spike in the first train + 2. For each spike, find the first match in the second train. + 3. Save the index of the first match as the new `second_train_search_start ` + 3. For each match, find as many matches as possible from the first match onwards. - An important condition here is thatthe same spike is not matched twice. This is managed by keeping track - of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match` + An important condition here is that the same spike is not matched twice. This is managed by keeping track + of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` For more details on the rationale behind this approach, refer to the documentation of this module and/or - the metrics section in SpikeForest documentation. + the metrics section in SpikeForest documentation. """ - matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) + matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) # Used to avoid the same spike matching twice - previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64) - previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64) - - lower_search_limit_in_second_train = 0 - - for index1 in range(len(frames_spike_train1)): - # Keeps track of which frame in the second spike train should be used as a search start for matches - index2 = lower_search_limit_in_second_train - frame1 = frames_spike_train1[index1] - - # Determine next_frame1 if current frame is not the last frame - not_in_the_last_loop = index1 < len(frames_spike_train1) - 1 - if not_in_the_last_loop: - next_frame1 = frames_spike_train1[index1 + 1] - - while index2 < len(frames_spike_train2): - frame2 = frames_spike_train2[index2] - not_a_match = abs(frame1 - frame2) > delta_frames - if not_a_match: - # Go to the next frame in the first train + last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64) + last_match_frame2 = -np.ones_like(matching_matrix, dtype=np.int64) + + num_spike_frames_train1 = len(spike_frames_train1) + num_spike_frames_train2 = len(spike_frames_train2) + + # Keeps track of which frame in the second spike train should be used as a search start for matches + second_train_search_start = 0 + for index1 in range(num_spike_frames_train1): + frame1 = spike_frames_train1[index1] + + for index2 in range(second_train_search_start, num_spike_frames_train2): + frame2 = spike_frames_train2[index2] + if frame2 < frame1 - delta_frames: + # no match move the left limit for the next loop + second_train_search_start += 1 + continue + elif frame2 > frame1 + delta_frames: + # no match stop search in train2 and continue increment in train1 break + else: + # match + unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] - # Map the match to a matrix - row, column = unit_indices1[index1], unit_indices2[index2] - - # The same spike cannot be matched twice see the notes in the docstring for more info on this constraint - if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]: - previous_frame1_match[row, column] = frame1 - previous_frame2_match[row, column] = frame2 - - matching_matrix[row, column] += 1 - - index2 += 1 + 
if ( + frame1 != last_match_frame1[unit_index1, unit_index2] + and frame2 != last_match_frame2[unit_index1, unit_index2] + ): + last_match_frame1[unit_index1, unit_index2] = frame1 + last_match_frame2[unit_index1, unit_index2] = frame2 - # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match - not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames - if not_a_match_with_next: - lower_search_limit_in_second_train = index2 + matching_matrix[unit_index1, unit_index2] += 1 return matching_matrix @@ -230,7 +232,7 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): +def make_match_count_matrix(sorting1, sorting2, delta_frames): num_units_sorting1 = sorting1.get_num_units() num_units_sorting2 = sorting2.get_num_units() matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) @@ -275,7 +277,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): return match_event_counts_df -def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): +def make_agreement_scores(sorting1, sorting2, delta_frames): """ Make the agreement matrix. No threshold (min_score) is applied at this step. @@ -291,8 +293,6 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): The second sorting extractor delta_frames: int Number of frames to consider spikes coincident - n_jobs: int - Number of jobs to run in parallel Returns ------- @@ -309,7 +309,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): event_counts1 = pd.Series(ev_counts1, index=unit1_ids) event_counts2 = pd.Series(ev_counts2, index=unit2_ids) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=n_jobs) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index e2dc30493d..7f21aa657f 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -84,9 +84,7 @@ def _do_agreement(self): self.event_counts2 = do_count_event(self.sorting2) # matrix of event match count for each pair - self.match_event_count = make_match_count_matrix( - self.sorting1, self.sorting2, self.delta_frames, n_jobs=self.n_jobs - ) + self.match_event_count = make_match_count_matrix(self.sorting1, self.sorting2, self.delta_frames) # agreement matrix score for each pair self.agreement_scores = make_agreement_scores_from_count( diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index c6494b04d1..ab24678a1e 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -135,6 +135,23 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting(): assert_array_equal(result.to_numpy(), expected_result) +def test_make_match_count_matrix_test_proper_search_in_the_second_train(): + "Search exhaustively in the second train, but only within the delta_frames window, do not terminate search early" + frames_spike_train1 = [500, 600, 800] + frames_spike_train2 = [0, 100, 200, 300, 500, 800] + unit_indices1 = [0, 0, 0] + unit_indices2 = [0, 0, 0, 0, 0, 0] + 
delta_frames = 20 + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames) + + expected_result = np.array([[2]]) + + assert_array_equal(result.to_numpy(), expected_result) + + def test_make_agreement_scores(): delta_frames = 10 @@ -150,7 +167,7 @@ def test_make_agreement_scores(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) print(agreement_scores) ok = np.array([[2 / 3, 0], [0, 1.0]], dtype="float64") @@ -158,7 +175,7 @@ def test_make_agreement_scores(): assert_array_equal(agreement_scores.values, ok) # test if symetric - agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames, n_jobs=1) + agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames) assert_array_equal(agreement_scores, agreement_scores2.T) @@ -178,7 +195,7 @@ def test_make_possible_match(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) possible_match_12, possible_match_21 = make_possible_match(agreement_scores, min_accuracy) @@ -207,7 +224,7 @@ def test_make_best_match(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) best_match_12, best_match_21 = make_best_match(agreement_scores, min_accuracy) @@ -236,7 +253,7 @@ def test_make_hungarian_match(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) @@ -344,8 +361,8 @@ def test_do_confusion_matrix(): event_counts1 = do_count_event(sorting1) event_counts2 = do_count_event(sorting2) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count) @@ -363,8 +380,8 @@ def test_do_confusion_matrix(): event_counts1 = do_count_event(sorting1) event_counts2 = do_count_event(sorting2) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count) @@ -391,8 +408,8 @@ def test_do_count_score_and_perf(): event_counts1 = do_count_event(sorting1) event_counts2 = do_count_event(sorting2) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - agreement_scores = 
make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) count_score = do_count_score(event_counts1, event_counts2, hungarian_match_12, match_event_count) @@ -415,13 +432,20 @@ def test_do_count_score_and_perf(): if __name__ == "__main__": test_make_match_count_matrix() - test_make_agreement_scores() - - test_make_possible_match() - test_make_best_match() - test_make_hungarian_match() - - test_do_score_labels() - test_compare_spike_trains() - test_do_confusion_matrix() - test_do_count_score_and_perf() + test_make_match_count_matrix_sorting_with_itself_simple() + test_make_match_count_matrix_sorting_with_itself_longer() + test_make_match_count_matrix_with_mismatched_sortings() + test_make_match_count_matrix_no_double_matching() + test_make_match_count_matrix_repeated_matching_but_no_double_counting() + test_make_match_count_matrix_test_proper_search_in_the_second_train() + + # test_make_agreement_scores() + + # test_make_possible_match() + # test_make_best_match() + # test_make_hungarian_match() + + # test_do_score_labels() + # test_compare_spike_trains() + # test_do_confusion_matrix() + # test_do_count_score_and_perf() From 57cfbd2461fb865836590aed59222f9c4fdbc7ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 10 Nov 2023 10:52:31 +0100 Subject: [PATCH 52/80] Fix filtering rounding error --- src/spikeinterface/preprocessing/filter.py | 4 ++++ src/spikeinterface/preprocessing/filter_gaussian.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 1d6947be79..172c666d62 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -153,6 +153,10 @@ def get_traces(self, start_frame, end_frame, channel_indices): filtered_traces = filtered_traces[left_margin:-right_margin, :] else: filtered_traces = filtered_traces[left_margin:, :] + + if np.issubdtype(self.dtype, np.integer): + filtered_traces = filtered_traces.round() + return filtered_traces.astype(self.dtype) diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py index 79b5ba5bc3..325ce82074 100644 --- a/src/spikeinterface/preprocessing/filter_gaussian.py +++ b/src/spikeinterface/preprocessing/filter_gaussian.py @@ -74,6 +74,9 @@ def get_traces( filtered_fft = traces_fft * (gauss_high - gauss_low)[:, None] filtered_traces = np.real(np.fft.ifft(filtered_fft, axis=0)) + if np.issubdtype(dtype, np.integer): + filtered_traces = filtered_traces.round() + if right_margin > 0: return filtered_traces[left_margin:-right_margin, :].astype(dtype) else: From 05c9ff1796c89fb1c0d64612f004e4b5fa1ed00e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 12:13:30 +0100 Subject: [PATCH 53/80] Fix corner case in make_match_count_matrix() Add symetric option and propagate in SymmetricSortingComparison/GroundTruthComparison --- .../comparison/comparisontools.py | 135 ++++++++++-------- .../comparison/paircomparisons.py | 9 +- .../comparison/tests/test_comparisontools.py | 53 ++++++- 3 files changed, 132 insertions(+), 65 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 
3cd856d662..aa9adfcb5c 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -132,61 +132,6 @@ def compute_matching_matrix( num_units_train2, delta_frames, ): - """ - Compute a matrix representing the matches between two spike trains. - - Given two spike trains, this function finds matching spikes based on a temporal proximity criterion - defined by `delta_frames`. The resulting matrix indicates the number of matches between units - in `spike_frames_train1` and `spike_frames_train2`. - - Parameters - ---------- - spike_frames_train1 : ndarray - An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. - spike_frames_train2 : ndarray - An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. - unit_indices1 : ndarray - An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. - unit_indices2 : ndarray - An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. - num_units_train1 : int - The total count of unique units in the first spike train. - num_units_train2 : int - The total count of unique units in the second spike train. - delta_frames : int - The inclusive upper limit on the frame difference for which two spikes are considered matching. That is - if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` - and `spike_frames_train2[j]` are considered matching. - - Returns - ------- - matching_matrix : ndarray - A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents - the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. - - - Notes - ----- - This algorithm identifies matching spikes between two ordered spike trains. - By iterating through each spike in the first train, it compares them against spikes in the second train, - determining matches based on the two spikes frames being within `delta_frames` of each other. - - To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `, - which signifies the minimal index in the second spike train that might match the upcoming spike - in the first train. - - The logic can be summarized as follows: - 1. Iterate through each spike in the first train - 2. For each spike, find the first match in the second train. - 3. Save the index of the first match as the new `second_train_search_start ` - 3. For each match, find as many matches as possible from the first match onwards. - - An important condition here is that the same spike is not matched twice. This is managed by keeping track - of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` - - For more details on the rationale behind this approach, refer to the documentation of this module and/or - the metrics section in SpikeForest documentation. - """ matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) @@ -232,7 +177,61 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames): +def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): + """ + Compute a matrix representing the matches between two Sorting objects. 
+ + Given two spike trains, this function finds matching spikes based on a temporal proximity criterion + defined by `delta_frames`. The resulting matrix indicates the number of matches between units + in `spike_frames_train1` and `spike_frames_train2` for each pair of units. + + Note that this algo is not symetric and biased toward sorting1 is the ground truth. + + Parameters + ---------- + sorting1 : Sorting + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + sorting2 : Sorting + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. + delta_frames : int + The inclusive upper limit on the frame difference for which two spikes are considered matching. That is + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` + and `spike_frames_train2[j]` are considered matching. + symetric: bool, dfault False + If symetric, the this the algos is run two times by switching sorting1 and sorting2 the minimum of the two + results is taken. + Returns + ------- + matching_matrix : ndarray + A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. + + Notes + ----- + This algorithm identifies matching spikes between two ordered spike trains. + By iterating through each spike in the first train, it compares them against spikes in the second train, + determining matches based on the two spikes frames being within `delta_frames` of each other. + + To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `, + which signifies the minimal index in the second spike train that might match the upcoming spike + in the first train. + + The logic can be summarized as follows: + 1. Iterate through each spike in the first train + 2. For each spike, find the first match in the second train. + 3. Save the index of the first match as the new `second_train_search_start ` + 3. For each match, find as many matches as possible from the first match onwards. + + An important condition here is that the same spike is not matched twice. This is managed by keeping track + of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` + There are corner cases where a spike can be counted twice in the the spiketrain 2 in case of bursting situations + (below delta_frames) in the spiketrain 1. To ensure that the number of match do not exceed the number of spike, + we applied a final clip. + + For more details on the rationale behind this approach, refer to the documentation of this module and/or + the metrics section in SpikeForest documentation. 
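A small usage sketch of this function (construction of the sortings follows the pattern used in the tests; the exact import paths are an assumption, and numba must be installed for the optimized kernel):

    import numpy as np
    from spikeinterface.core import NumpySorting
    from spikeinterface.comparison.comparisontools import make_match_count_matrix

    fs = 30000.0
    # Spike frames (sample indices) and unit labels for two single-unit sortings.
    sorting1 = NumpySorting.from_times_labels(np.array([100, 200, 300]), np.array([0, 0, 0]), fs)
    sorting2 = NumpySorting.from_times_labels(np.array([101, 202, 600]), np.array([0, 0, 0]), fs)

    # Spikes closer than delta_frames are counted as matches; here 2 of the 3 spikes match.
    match_counts = make_match_count_matrix(sorting1, sorting2, delta_frames=10)
    print(match_counts)   # pandas DataFrame with one row/column per unit
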
+ """ + num_units_sorting1 = sorting1.get_num_units() num_units_sorting2 = sorting2.get_num_units() matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) @@ -257,7 +256,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames): unit_indices1_sorted = spike_vector1["unit_index"] unit_indices2_sorted = spike_vector2["unit_index"] - matching_matrix += get_optimized_compute_matching_matrix()( + matching_matrix_seg = get_optimized_compute_matching_matrix()( sample_frames1_sorted, sample_frames2_sorted, unit_indices1_sorted, @@ -267,6 +266,28 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames): delta_frames, ) + if symetric: + matching_matrix_seg_switch = get_optimized_compute_matching_matrix()( + sample_frames2_sorted, + sample_frames1_sorted, + unit_indices2_sorted, + unit_indices1_sorted, + num_units_sorting2, + num_units_sorting1, + delta_frames, + ) + matching_matrix_seg = np.maximum(matching_matrix_seg, matching_matrix_seg_switch.T) + + + matching_matrix += matching_matrix_seg + + + # ensure the number of match do not exceed the number of spike in train 2 + # this is a simple way to handle corner cases for bursting in sorting1 + spike_count2 = np.array(list(sorting2.count_num_spikes_per_unit().values())) + spike_count2 = spike_count2[np.newaxis, :] + matching_matrix = np.clip(matching_matrix, None, spike_count2) + # Build a data frame from the matching matrix import pandas as pd diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 7f21aa657f..e57f2c047a 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -28,6 +28,7 @@ def __init__( delta_time=0.4, match_score=0.5, chance_score=0.1, + symetric=False, n_jobs=1, verbose=False, ): @@ -55,6 +56,8 @@ def __init__( self.unit1_ids = self.sorting1.get_unit_ids() self.unit2_ids = self.sorting2.get_unit_ids() + self.symetric = symetric + self._do_agreement() self._do_matching() @@ -84,7 +87,9 @@ def _do_agreement(self): self.event_counts2 = do_count_event(self.sorting2) # matrix of event match count for each pair - self.match_event_count = make_match_count_matrix(self.sorting1, self.sorting2, self.delta_frames) + self.match_event_count = make_match_count_matrix( + self.sorting1, self.sorting2, self.delta_frames, symetric=self.symetric + ) # agreement matrix score for each pair self.agreement_scores = make_agreement_scores_from_count( @@ -151,6 +156,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, + symetric=True, n_jobs=n_jobs, verbose=verbose, ) @@ -283,6 +289,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, + symetric=False, n_jobs=n_jobs, verbose=verbose, ) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index ab24678a1e..137d4cff05 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -135,6 +135,43 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting(): assert_array_equal(result.to_numpy(), expected_result) +def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): + # More challenging condition, this was failing with the previous approach that used np.where and np.diff + # This actual implementation should fail but the "clip protection" by number of spike 
make the solution. + # This is cheating but acceptable for really corner cases (burst in the ground truth). + frames_spike_train1 = [100, 105, 110] + frames_spike_train2 = [100, 105, ] + unit_indices1 = [0, 0, 0] + unit_indices2 = [0, 0,] + delta_frames = 20 # long enough, so all frames in both sortings are within each other reach + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + # this is easy because it is sorting2 centric + result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symetric=False) + expected_result = np.array([[2]]) + assert_array_equal(result.to_numpy(), expected_result) + + # this work only because we protect by clipping + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symetric=False) + expected_result = np.array([[2]]) + assert_array_equal(result.to_numpy(), expected_result) + +def test_make_match_count_matrix_symetric(): + frames_spike_train1 = [100, 102, 105, 120, 1000, ] + unit_indices1 = [0, 2, 1, 0, 0] + frames_spike_train2 = [101, 150, 1000] + unit_indices2 = [0, 1, 0] + delta_frames = 100 + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symetric=True) + result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symetric=True) + + assert_array_equal(result.T, result_T) + + def test_make_match_count_matrix_test_proper_search_in_the_second_train(): "Search exhaustively in the second train, but only within the delta_frames window, do not terminate search early" frames_spike_train1 = [500, 600, 800] @@ -431,13 +468,15 @@ def test_do_count_score_and_perf(): if __name__ == "__main__": - test_make_match_count_matrix() - test_make_match_count_matrix_sorting_with_itself_simple() - test_make_match_count_matrix_sorting_with_itself_longer() - test_make_match_count_matrix_with_mismatched_sortings() - test_make_match_count_matrix_no_double_matching() - test_make_match_count_matrix_repeated_matching_but_no_double_counting() - test_make_match_count_matrix_test_proper_search_in_the_second_train() + # test_make_match_count_matrix() + # test_make_match_count_matrix_sorting_with_itself_simple() + # test_make_match_count_matrix_sorting_with_itself_longer() + # test_make_match_count_matrix_with_mismatched_sortings() + # test_make_match_count_matrix_no_double_matching() + # test_make_match_count_matrix_repeated_matching_but_no_double_counting() + # test_make_match_count_matrix_repeated_matching_but_no_double_counting_2() + # test_make_match_count_matrix_test_proper_search_in_the_second_train() + test_make_match_count_matrix_symetric() # test_make_agreement_scores() From 3aff2255d5a6426b91fd667be13abad8d795df86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Nov 2023 11:16:08 +0000 Subject: [PATCH 54/80] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comparison/comparisontools.py | 5 +---- .../comparison/tests/test_comparisontools.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index aa9adfcb5c..c8a6edb577 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -132,7 
+132,6 @@ def compute_matching_matrix( num_units_train2, delta_frames, ): - matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) # Used to avoid the same spike matching twice @@ -225,7 +224,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): An important condition here is that the same spike is not matched twice. This is managed by keeping track of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` There are corner cases where a spike can be counted twice in the the spiketrain 2 in case of bursting situations - (below delta_frames) in the spiketrain 1. To ensure that the number of match do not exceed the number of spike, + (below delta_frames) in the spiketrain 1. To ensure that the number of match do not exceed the number of spike, we applied a final clip. For more details on the rationale behind this approach, refer to the documentation of this module and/or @@ -278,10 +277,8 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): ) matching_matrix_seg = np.maximum(matching_matrix_seg, matching_matrix_seg_switch.T) - matching_matrix += matching_matrix_seg - # ensure the number of match do not exceed the number of spike in train 2 # this is a simple way to handle corner cases for bursting in sorting1 spike_count2 = np.array(list(sorting2.count_num_spikes_per_unit().values())) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index 137d4cff05..5a6f18f0f9 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -140,9 +140,15 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): # This actual implementation should fail but the "clip protection" by number of spike make the solution. # This is cheating but acceptable for really corner cases (burst in the ground truth). 
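    # Illustrative aside (values mirror this test, not part of the applied patch): train1 bursts at
    # [100, 105, 110] while train2 only has [100, 105]; with delta_frames = 20 every train1 spike can
    # reach both train2 spikes, and because the guard only tracks the last matched index per unit pair,
    # the raw count would reach 3. The final clip against the number of spikes in train2 caps it at 2.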
frames_spike_train1 = [100, 105, 110] - frames_spike_train2 = [100, 105, ] + frames_spike_train2 = [ + 100, + 105, + ] unit_indices1 = [0, 0, 0] - unit_indices2 = [0, 0,] + unit_indices2 = [ + 0, + 0, + ] delta_frames = 20 # long enough, so all frames in both sortings are within each other reach sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) @@ -157,8 +163,15 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) + def test_make_match_count_matrix_symetric(): - frames_spike_train1 = [100, 102, 105, 120, 1000, ] + frames_spike_train1 = [ + 100, + 102, + 105, + 120, + 1000, + ] unit_indices1 = [0, 2, 1, 0, 0] frames_spike_train2 = [101, 150, 1000] unit_indices2 = [0, 1, 0] From c5263f7e7a4fe50bee987aa4947d865c094ef067 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 12:20:22 +0100 Subject: [PATCH 55/80] get_optimized_compute_matching_matrix: protect with index instead of frame --- src/spikeinterface/comparison/comparisontools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index aa9adfcb5c..0a142ed878 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -161,11 +161,11 @@ def compute_matching_matrix( unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] if ( - frame1 != last_match_frame1[unit_index1, unit_index2] - and frame2 != last_match_frame2[unit_index1, unit_index2] + index1 != last_match_frame1[unit_index1, unit_index2] + and index2 != last_match_frame2[unit_index1, unit_index2] ): - last_match_frame1[unit_index1, unit_index2] = frame1 - last_match_frame2[unit_index1, unit_index2] = frame2 + last_match_frame1[unit_index1, unit_index2] = index1 + last_match_frame2[unit_index1, unit_index2] = index2 matching_matrix[unit_index1, unit_index2] += 1 From a823a08dde2d2acf0c055d6c0eb93b80bd09212b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 13:14:24 +0100 Subject: [PATCH 56/80] oups --- src/spikeinterface/comparison/paircomparisons.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index e57f2c047a..d6d40c8d8c 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -28,7 +28,7 @@ def __init__( delta_time=0.4, match_score=0.5, chance_score=0.1, - symetric=False, + symmetric=False, n_jobs=1, verbose=False, ): @@ -56,7 +56,7 @@ def __init__( self.unit1_ids = self.sorting1.get_unit_ids() self.unit2_ids = self.sorting2.get_unit_ids() - self.symetric = symetric + self.symmetric = symmetric self._do_agreement() self._do_matching() @@ -88,7 +88,7 @@ def _do_agreement(self): # matrix of event match count for each pair self.match_event_count = make_match_count_matrix( - self.sorting1, self.sorting2, self.delta_frames, symetric=self.symetric + self.sorting1, self.sorting2, self.delta_frames, symmetric=self.symmetric ) # agreement matrix score for each pair @@ -156,7 +156,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, - symetric=True, + symmetric=True, n_jobs=n_jobs, verbose=verbose, ) @@ -289,7 +289,7 @@ def __init__( delta_time=delta_time, 
match_score=match_score, chance_score=chance_score, - symetric=False, + symmetric=False, n_jobs=n_jobs, verbose=verbose, ) From b03febe478aab1af44f82f1e7357c29771c3fbbc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 13:15:51 +0100 Subject: [PATCH 57/80] oups --- src/spikeinterface/comparison/comparisontools.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 26f220ef73..9dcda06ada 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -176,7 +176,7 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): +def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): """ Compute a matrix representing the matches between two Sorting objects. @@ -184,7 +184,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): defined by `delta_frames`. The resulting matrix indicates the number of matches between units in `spike_frames_train1` and `spike_frames_train2` for each pair of units. - Note that this algo is not symetric and biased toward sorting1 is the ground truth. + Note that this algo is not symmetric and biased toward sorting1 is the ground truth. Parameters ---------- @@ -196,8 +196,8 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): The inclusive upper limit on the frame difference for which two spikes are considered matching. That is if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` and `spike_frames_train2[j]` are considered matching. - symetric: bool, dfault False - If symetric, the this the algos is run two times by switching sorting1 and sorting2 the minimum of the two + symmetric: bool, dfault False + If symmetric, the this the algos is run two times by switching sorting1 and sorting2 the minimum of the two results is taken. 
Returns ------- @@ -265,7 +265,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): delta_frames, ) - if symetric: + if symmetric: matching_matrix_seg_switch = get_optimized_compute_matching_matrix()( sample_frames2_sorted, sample_frames1_sorted, From dc55dfc97a66ae8681a501951bbdaa1be8342be9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 16:25:10 +0100 Subject: [PATCH 58/80] oups --- .../comparison/tests/test_comparisontools.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index 5a6f18f0f9..b6cd3fc3b4 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -154,17 +154,17 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) # this is easy because it is sorting2 centric - result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symetric=False) + result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symmetric=False) expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) # this work only because we protect by clipping - result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symetric=False) + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symmetric=False) expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) -def test_make_match_count_matrix_symetric(): +def test_make_match_count_matrix_symmetric(): frames_spike_train1 = [ 100, 102, @@ -179,8 +179,8 @@ def test_make_match_count_matrix_symetric(): sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) - result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symetric=True) - result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symetric=True) + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symmetric=True) + result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symmetric=True) assert_array_equal(result.T, result_T) @@ -224,7 +224,7 @@ def test_make_agreement_scores(): assert_array_equal(agreement_scores.values, ok) - # test if symetric + # test if symmetric agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames) assert_array_equal(agreement_scores, agreement_scores2.T) @@ -489,7 +489,7 @@ def test_do_count_score_and_perf(): # test_make_match_count_matrix_repeated_matching_but_no_double_counting() # test_make_match_count_matrix_repeated_matching_but_no_double_counting_2() # test_make_match_count_matrix_test_proper_search_in_the_second_train() - test_make_match_count_matrix_symetric() + test_make_match_count_matrix_symmetric() # test_make_agreement_scores() From 5591593d569f23fb9b71e6f67b2a17ad1eb70dae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 10 Nov 2023 17:37:23 +0100 Subject: [PATCH 59/80] Documented rounding in pre-processing --- doc/modules/preprocessing.rst | 2 ++ src/spikeinterface/preprocessing/normalize_scale.py | 4 ++++ src/spikeinterface/preprocessing/phase_shift.py | 2 ++ 3 files changed, 8 insertions(+) diff --git 
a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 67f1e52011..e95edb968c 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -74,6 +74,8 @@ dtype (unless specified otherwise): Some scaling pre-processors, such as :code:`whiten()` or :code:`zscore()`, will force the output to :code:`float32`. +When converting from a :code:`float` to an :code:`int`, the value will first be rounded to the nearest integer. + Available preprocessing ----------------------- diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 03afada380..f24aff6e79 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -20,6 +20,10 @@ def __init__(self, parent_recording_segment, gain, offset, dtype): def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) scaled_traces = traces * self.gain[:, channel_indices] + self.offset[:, channel_indices] + + if np.issubdtype(self._dtype, np.integer): + scaled_traces = scaled_traces.round() + return scaled_traces.astype(self._dtype) diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 570ce48a5d..0734dad784 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -103,6 +103,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces_shift = traces_shift[left_margin:-right_margin, :] if self.tmp_dtype is not None: + if np.issubdtype(self.dtype, np.integer): + traces_shift = traces_shift.round() traces_shift = traces_shift.astype(self.dtype) return traces_shift From b55991ca59cec7f40e9eca5c6383bef8f156d2ac Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 13 Nov 2023 10:10:17 +0100 Subject: [PATCH 60/80] Update src/spikeinterface/comparison/comparisontools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 9dcda06ada..1b503629bd 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -221,7 +221,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): 3. Save the index of the first match as the new `second_train_search_start ` 3. For each match, find as many matches as possible from the first match onwards. - An important condition here is that the same spike is not matched twice. This is managed by keeping track + An important condition is that the same spike is not matched twice. This is managed by keeping track of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` There are corner cases where a spike can be counted twice in the the spiketrain 2 in case of bursting situations (below delta_frames) in the spiketrain 1. 
To ensure that the number of match do not exceed the number of spike, From 866469addcab659552166047a4524976e4fe0687 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 13 Nov 2023 10:10:35 +0100 Subject: [PATCH 61/80] Update src/spikeinterface/comparison/comparisontools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 1b503629bd..2f7ed61427 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -184,7 +184,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): defined by `delta_frames`. The resulting matrix indicates the number of matches between units in `spike_frames_train1` and `spike_frames_train2` for each pair of units. - Note that this algo is not symmetric and biased toward sorting1 is the ground truth. + Note that this algo is not symmetric and is biased with `sorting1` representing ground truth for the comparison Parameters ---------- From 2e315392934a5af6239636791c3eee1c3bf45787 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 13 Nov 2023 10:10:42 +0100 Subject: [PATCH 62/80] Update src/spikeinterface/comparison/comparisontools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 2f7ed61427..1ad1c59499 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -178,7 +178,7 @@ def compute_matching_matrix( def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): """ - Compute a matrix representing the matches between two Sorting objects. + Computes a matrix representing the matches between two Sorting objects. Given two spike trains, this function finds matching spikes based on a temporal proximity criterion defined by `delta_frames`. The resulting matrix indicates the number of matches between units From 5b59a506c4a5a6b7863fc2c309c0e697af966279 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 13 Nov 2023 10:29:54 +0100 Subject: [PATCH 63/80] Prepare release 0.99.1 --- README.md | 2 +- doc/releases/0.99.1.rst | 14 ++++++++++++++ doc/whatisnew.rst | 7 +++++++ pyproject.toml | 10 +++++----- 4 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 doc/releases/0.99.1.rst diff --git a/README.md b/README.md index d51f372848..977cf6eba4 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ With SpikeInterface, users can: ## Documentation -Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.99.0). +Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.99.1). Detailed documentation of the development version of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/latest). diff --git a/doc/releases/0.99.1.rst b/doc/releases/0.99.1.rst new file mode 100644 index 0000000000..00acbbbd75 --- /dev/null +++ b/doc/releases/0.99.1.rst @@ -0,0 +1,14 @@ +.. 
_release0.99.1: + +SpikeInterface 0.99.1 release notes +----------------------------------- + +14th November 2023 + +Minor release with some bug fixes. + +* Fix crash when default start / end frame arguments on motion interpolation are used (#2176) +* Fix bug in `make_match_count_matrix()` when computing matching events (#2182) +* Fix corner case in `make_match_count_matrix()` and make it symmetric (#2191) +* Fix maxwell tests by setting HDF5_PLUGIN_PATH env in action (#2161) +* Add read_npz_sorting to extractors module (#2183) diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 33735a47fd..2232173e5a 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. toctree:: :maxdepth: 1 + releases/0.99.1.rst releases/0.99.0.rst releases/0.98.2.rst releases/0.98.1.rst @@ -32,6 +33,12 @@ Release notes releases/0.9.1.rst +Version 0.99.1 +============== + +* Minor release with some bug fixes + + Version 0.99.0 ============== diff --git a/pyproject.toml b/pyproject.toml index 658703b25c..954a6dbc8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.100.0.dev0" +version = "0.99.1" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, @@ -146,8 +146,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ @@ -164,8 +164,8 @@ docs = [ "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] From 3e04bd383579aa858e73e6b57b8fea34f911bf44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 13 Nov 2023 12:06:36 +0100 Subject: [PATCH 64/80] Use `uint64` for match count (#2196) * Use `uint64` for match count `uint16` leads to overflow and erroneous count. * Fixes the bug --- src/spikeinterface/comparison/comparisontools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 3cd856d662..1c3685c666 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -188,7 +188,7 @@ def compute_matching_matrix( the metrics section in SpikeForest documentation. 
""" - matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) + matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint64) # Used to avoid the same spike matching twice last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64) @@ -235,7 +235,7 @@ def compute_matching_matrix( def make_match_count_matrix(sorting1, sorting2, delta_frames): num_units_sorting1 = sorting1.get_num_units() num_units_sorting2 = sorting2.get_num_units() - matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) + matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint64) spike_vector1_segments = sorting1.to_spike_vector(concatenated=False) spike_vector2_segments = sorting2.to_spike_vector(concatenated=False) From 4fa5a0cde87e64c818a9bceb5e2ebcd1268a5289 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 13 Nov 2023 13:54:15 +0100 Subject: [PATCH 65/80] symmetric > ensure_symmetry --- .../comparison/comparisontools.py | 55 +++++++++++++++---- .../comparison/paircomparisons.py | 10 ++-- .../comparison/tests/test_comparisontools.py | 50 ++++++++--------- 3 files changed, 73 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 1ad1c59499..d9ab2e685d 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -132,6 +132,36 @@ def compute_matching_matrix( num_units_train2, delta_frames, ): + """ + Internal function used by `make_match_count_matrix()`. + This function is for one segment only. + The llop over segment is done in `make_match_count_matrix()` + + Parameters + ---------- + spike_frames_train1 : ndarray + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + spike_frames_train2 : ndarray + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. + unit_indices1 : ndarray + An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. + unit_indices2 : ndarray + An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. + num_units_train1 : int + The total count of unique units in the first spike train. + num_units_train2 : int + The total count of unique units in the second spike train. + delta_frames : int + The inclusive upper limit on the frame difference for which two spikes are considered matching. That is + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` + and `spike_frames_train2[j]` are considered matching. + + Returns + ------- + matching_matrix : ndarray + A 2D numpy array of shape `(num_units_train1, num_units_train2)` + + """ matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) # Used to avoid the same spike matching twice @@ -176,7 +206,7 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): +def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=False): """ Computes a matrix representing the matches between two Sorting objects. 
@@ -194,11 +224,11 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. delta_frames : int The inclusive upper limit on the frame difference for which two spikes are considered matching. That is - if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` - and `spike_frames_train2[j]` are considered matching. - symmetric: bool, dfault False - If symmetric, the this the algos is run two times by switching sorting1 and sorting2 the minimum of the two - results is taken. + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at + `spike_frames_train1[i]` and `spike_frames_train2[j]` are considered matching. + ensure_symmetry: bool, default False + If ensure_symmetry=True, then the algo is run two times by switching sorting1 and sorting2. + And the minimum of the two results is taken. Returns ------- matching_matrix : ndarray @@ -221,11 +251,12 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): 3. Save the index of the first match as the new `second_train_search_start ` 3. For each match, find as many matches as possible from the first match onwards. - An important condition is that the same spike is not matched twice. This is managed by keeping track + An important condition here is that the same spike is not matched twice. This is managed by keeping track of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` - There are corner cases where a spike can be counted twice in the the spiketrain 2 in case of bursting situations - (below delta_frames) in the spiketrain 1. To ensure that the number of match do not exceed the number of spike, - we applied a final clip. + There are corner cases where a spike can be counted twice in the spiketrain 2 if there are bouts of bursting activity + (below delta_frames) in the spiketrain 1. To ensure that the number of matches does not exceed the number of spikes, + we apply a final clip. + For more details on the rationale behind this approach, refer to the documentation of this module and/or the metrics section in SpikeForest documentation. 
@@ -265,7 +296,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): delta_frames, ) - if symmetric: + if ensure_symmetry: matching_matrix_seg_switch = get_optimized_compute_matching_matrix()( sample_frames2_sorted, sample_frames1_sorted, @@ -327,7 +358,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames): event_counts1 = pd.Series(ev_counts1, index=unit1_ids) event_counts2 = pd.Series(ev_counts2, index=unit2_ids) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=True) agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index d6d40c8d8c..02e74b7053 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -28,7 +28,7 @@ def __init__( delta_time=0.4, match_score=0.5, chance_score=0.1, - symmetric=False, + ensure_symmetry=False, n_jobs=1, verbose=False, ): @@ -56,7 +56,7 @@ def __init__( self.unit1_ids = self.sorting1.get_unit_ids() self.unit2_ids = self.sorting2.get_unit_ids() - self.symmetric = symmetric + self.ensure_symmetry = ensure_symmetry self._do_agreement() self._do_matching() @@ -88,7 +88,7 @@ def _do_agreement(self): # matrix of event match count for each pair self.match_event_count = make_match_count_matrix( - self.sorting1, self.sorting2, self.delta_frames, symmetric=self.symmetric + self.sorting1, self.sorting2, self.delta_frames, ensure_symmetry=self.ensure_symmetry ) # agreement matrix score for each pair @@ -156,7 +156,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, - symmetric=True, + ensure_symmetry=True, n_jobs=n_jobs, verbose=verbose, ) @@ -289,7 +289,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, - symmetric=False, + ensure_symmetry=False, n_jobs=n_jobs, verbose=verbose, ) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index b6cd3fc3b4..31adee8ca4 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -154,17 +154,17 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) # this is easy because it is sorting2 centric - result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symmetric=False) + result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, ensure_symmetry=False) expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) # this work only because we protect by clipping - result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symmetric=False) + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, ensure_symmetry=False) expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) -def test_make_match_count_matrix_symmetric(): +def test_make_match_count_matrix_ensure_symmetry(): frames_spike_train1 = [ 100, 102, @@ -179,8 +179,8 @@ def test_make_match_count_matrix_symmetric(): sorting1, sorting2 = make_sorting(frames_spike_train1, 
unit_indices1, frames_spike_train2, unit_indices2) - result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symmetric=True) - result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symmetric=True) + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, ensure_symmetry=True) + result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, ensure_symmetry=True) assert_array_equal(result.T, result_T) @@ -481,23 +481,23 @@ def test_do_count_score_and_perf(): if __name__ == "__main__": - # test_make_match_count_matrix() - # test_make_match_count_matrix_sorting_with_itself_simple() - # test_make_match_count_matrix_sorting_with_itself_longer() - # test_make_match_count_matrix_with_mismatched_sortings() - # test_make_match_count_matrix_no_double_matching() - # test_make_match_count_matrix_repeated_matching_but_no_double_counting() - # test_make_match_count_matrix_repeated_matching_but_no_double_counting_2() - # test_make_match_count_matrix_test_proper_search_in_the_second_train() - test_make_match_count_matrix_symmetric() - - # test_make_agreement_scores() - - # test_make_possible_match() - # test_make_best_match() - # test_make_hungarian_match() - - # test_do_score_labels() - # test_compare_spike_trains() - # test_do_confusion_matrix() - # test_do_count_score_and_perf() + test_make_match_count_matrix() + test_make_match_count_matrix_sorting_with_itself_simple() + test_make_match_count_matrix_sorting_with_itself_longer() + test_make_match_count_matrix_with_mismatched_sortings() + test_make_match_count_matrix_no_double_matching() + test_make_match_count_matrix_repeated_matching_but_no_double_counting() + test_make_match_count_matrix_repeated_matching_but_no_double_counting_2() + test_make_match_count_matrix_test_proper_search_in_the_second_train() + test_make_match_count_matrix_ensure_symmetry() + + test_make_agreement_scores() + + test_make_possible_match() + test_make_best_match() + test_make_hungarian_match() + + test_do_score_labels() + test_compare_spike_trains() + test_do_confusion_matrix() + test_do_count_score_and_perf() From bb38af83a4ebea4af91ae389306ae08226f8fd46 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Nov 2023 12:58:23 +0000 Subject: [PATCH 66/80] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 7bab5fe3aa..f475dbf2e1 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -257,7 +257,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=Fa There are corner cases where a spike can be counted twice in the spiketrain 2 if there are bouts of bursting activity (below delta_frames) in the spiketrain 1. To ensure that the number of matches does not exceed the number of spikes, we apply a final clip. - + For more details on the rationale behind this approach, refer to the documentation of this module and/or the metrics section in SpikeForest documentation. 
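The counting rule that patches 55 to 66 converge on (match spikes within `delta_frames`, never reuse a spike of either train for the same unit pair, accumulate in `uint64`, then clip so a pair never has more matches than spikes) can be illustrated with a small standalone NumPy sketch. This is an illustration of the idea only, not the numba-jitted `get_optimized_compute_matching_matrix()`: the helper name `match_count` and the toy spike trains are invented for the example, and both trains are assumed to be sorted by frame.

    import numpy as np

    def match_count(frames1, units1, frames2, units2, num_units1, num_units2, delta_frames):
        # count matches (|f1 - f2| <= delta_frames) per unit pair, protecting by spike index
        matching_matrix = np.zeros((num_units1, num_units2), dtype=np.uint64)
        last_match_1 = -np.ones_like(matching_matrix, dtype=np.int64)  # last matched spike index in train 1
        last_match_2 = -np.ones_like(matching_matrix, dtype=np.int64)  # last matched spike index in train 2
        search_start = 0
        for index1, frame1 in enumerate(frames1):
            # spikes of train 2 that are more than delta_frames in the past can never match again
            while search_start < len(frames2) and frames2[search_start] < frame1 - delta_frames:
                search_start += 1
            index2 = search_start
            while index2 < len(frames2) and abs(frames2[index2] - frame1) <= delta_frames:
                u1, u2 = units1[index1], units2[index2]
                if index1 != last_match_1[u1, u2] and index2 != last_match_2[u1, u2]:
                    last_match_1[u1, u2] = index1
                    last_match_2[u1, u2] = index2
                    matching_matrix[u1, u2] += 1
                index2 += 1
        # final clip: a pair can never have more matches than either unit has spikes
        counts1 = np.bincount(units1, minlength=num_units1).astype(np.uint64)
        counts2 = np.bincount(units2, minlength=num_units2).astype(np.uint64)
        return np.minimum(matching_matrix, np.minimum(counts1[:, np.newaxis], counts2[np.newaxis, :]))

    # the bursty corner case from the tests: both directions end at [[2]]
    frames1, units1 = np.array([100, 105, 110]), np.array([0, 0, 0])
    frames2, units2 = np.array([100, 105]), np.array([0, 0])
    m12 = match_count(frames1, units1, frames2, units2, 1, 1, delta_frames=20)
    m21 = match_count(frames2, units2, frames1, units1, 1, 1, delta_frames=20)
    print(m12, np.minimum(m12, m21.T))  # ensure_symmetry keeps the element-wise minimum of both directions

Taking the element-wise minimum of the two directions is what `ensure_symmetry=True` does in the real function, which is why the new `test_make_match_count_matrix_ensure_symmetry()` can compare the result against the transpose of the swapped call.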
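The rounding behaviour documented in patch 59 ("the value will first be rounded to the nearest integer") is easy to overlook. Below is a tiny standalone illustration of the difference between casting a float trace directly and rounding it first; the array values are invented, and note that NumPy rounds exact halves to the nearest even integer.

    import numpy as np

    scaled_traces = np.array([1.4, 2.6, -3.5], dtype="float32")
    print(scaled_traces.astype("int16"))           # [ 1  2 -3]  plain casting truncates toward zero
    print(scaled_traces.round().astype("int16"))   # [ 1  3 -4]  rounding first, as the scaling and phase_shift preprocessors now do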
From 8d3b830bb79fd0e2f0e4b55fffee69ba5cc59a48 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 13 Nov 2023 14:10:09 +0100 Subject: [PATCH 67/80] Improve do_count_event() --- src/spikeinterface/comparison/comparisontools.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 7bab5fe3aa..b3d76f9da5 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -63,6 +63,9 @@ def compute_agreement_score(num_matches, num1, num2): def do_count_event(sorting): """ Count event for each units in a sorting. + + Kept for backward compatibility sorting.count_num_spikes_per_unit() is doing the same. + Parameters ---------- sorting: SortingExtractor @@ -75,14 +78,8 @@ def do_count_event(sorting): """ import pandas as pd - unit_ids = sorting.get_unit_ids() - ev_counts = np.zeros(len(unit_ids), dtype="int64") - for segment_index in range(sorting.get_num_segments()): - ev_counts += np.array( - [len(sorting.get_unit_spike_train(u, segment_index=segment_index)) for u in unit_ids], dtype="int64" - ) - event_counts = pd.Series(ev_counts, index=unit_ids) - return event_counts + return pd.Series(sorting.count_num_spikes_per_unit()) + def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, From 27f4e6dbf4d0d48f640e53c530b90fbfb3434166 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Nov 2023 13:11:30 +0000 Subject: [PATCH 68/80] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/comparison/comparisontools.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 64b9b757f6..c56d02e3b3 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -63,7 +63,7 @@ def compute_agreement_score(num_matches, num1, num2): def do_count_event(sorting): """ Count event for each units in a sorting. - + Kept for backward compatibility sorting.count_num_spikes_per_unit() is doing the same. Parameters @@ -81,7 +81,6 @@ def do_count_event(sorting): return pd.Series(sorting.count_num_spikes_per_unit()) - def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, """ Computes matching spikes between one spike train and a list of others. From 0349dbb017c75cbd2fb2c54b94e24e053c19b0b3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 13 Nov 2023 14:36:24 +0100 Subject: [PATCH 69/80] Update src/spikeinterface/comparison/comparisontools.py --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index c56d02e3b3..731753287e 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -131,7 +131,7 @@ def compute_matching_matrix( """ Internal function used by `make_match_count_matrix()`. This function is for one segment only. 
- The llop over segment is done in `make_match_count_matrix()` + The loop over segment is done in `make_match_count_matrix()` Parameters ---------- From 62aa0ad36e632ecac94b9f3d37bf3d7081ea8989 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 13 Nov 2023 14:53:09 +0100 Subject: [PATCH 70/80] expose ensure_symmetry in make_agreement_scores() --- src/spikeinterface/comparison/comparisontools.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 64b9b757f6..5fd8195998 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -324,12 +324,12 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=Fa return match_event_counts_df -def make_agreement_scores(sorting1, sorting2, delta_frames): +def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True): """ Make the agreement matrix. No threshold (min_score) is applied at this step. - Note : this computation is symmetric. + Note : this computation is symmetric by default. Inverting sorting1 and sorting2 give the transposed matrix. Parameters @@ -340,7 +340,9 @@ def make_agreement_scores(sorting1, sorting2, delta_frames): The second sorting extractor delta_frames: int Number of frames to consider spikes coincident - + ensure_symmetry: bool, default: True + If ensure_symmetry is True, then the algo is run two times by switching sorting1 and sorting2. + And the minimum of the two results is taken. Returns ------- agreement_scores: array (float) @@ -356,7 +358,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames): event_counts1 = pd.Series(ev_counts1, index=unit1_ids) event_counts2 = pd.Series(ev_counts2, index=unit2_ids) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=True) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=ensure_symmetry) agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2) From 93cd8fff205a71112bfb5e365dbfae022e0c5d12 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 13 Nov 2023 14:53:57 +0100 Subject: [PATCH 71/80] Update src/spikeinterface/comparison/comparisontools.py Co-authored-by: Heberto Mayorquin --- src/spikeinterface/comparison/comparisontools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 8e3b984420..19ba6afd27 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -155,7 +155,8 @@ def compute_matching_matrix( Returns ------- matching_matrix : ndarray - A 2D numpy array of shape `(num_units_train1, num_units_train2)` + A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. 
""" From 4ddefe390f7a6a9a241d0387fe46042b01710503 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 13 Nov 2023 14:59:03 +0100 Subject: [PATCH 72/80] Add last PR to release notes --- doc/releases/0.99.1.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/doc/releases/0.99.1.rst b/doc/releases/0.99.1.rst index 00acbbbd75..688f9f6a41 100644 --- a/doc/releases/0.99.1.rst +++ b/doc/releases/0.99.1.rst @@ -8,7 +8,6 @@ SpikeInterface 0.99.1 release notes Minor release with some bug fixes. * Fix crash when default start / end frame arguments on motion interpolation are used (#2176) -* Fix bug in `make_match_count_matrix()` when computing matching events (#2182) -* Fix corner case in `make_match_count_matrix()` and make it symmetric (#2191) +* Fix bug in `make_match_count_matrix()` when computing matching events (#2182, #2191, #2196) * Fix maxwell tests by setting HDF5_PLUGIN_PATH env in action (#2161) * Add read_npz_sorting to extractors module (#2183) From 3df7a2cbe80712dd8f53a68a3f082cb5330c2bf6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Nov 2023 17:40:11 +0000 Subject: [PATCH 73/80] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/psf/black: 23.10.1 → 23.11.0](https://github.com/psf/black/compare/23.10.1...23.11.0) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9770856dfa..9cc1129ed2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.10.1 + rev: 23.11.0 hooks: - id: black files: ^src/ From 673f82a577d72a6f36a543e662fbc4764983cca7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 14 Nov 2023 18:46:07 +0100 Subject: [PATCH 74/80] After release --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 954a6dbc8d..658703b25c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.99.1" +version = "0.100.0.dev0" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, @@ -146,8 +146,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ @@ -164,8 +164,8 @@ docs = [ "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions # for release we need pypi, so this needs to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] 
From 97f46c3ed9a15cbb961ff4a38ea981f5c4e4cb0f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 15 Nov 2023 10:21:22 +0100 Subject: [PATCH 75/80] add rename_units method and test --- src/spikeinterface/core/basesorting.py | 52 ++++++++++++++----- .../core/tests/test_basesorting.py | 13 +++++ 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 94b08d8cc3..9f94e38bb4 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -57,13 +57,13 @@ def add_sorting_segment(self, sorting_segment): self._sorting_segments.append(sorting_segment) sorting_segment.set_parent_extractor(self) - def get_sampling_frequency(self): + def get_sampling_frequency(self) -> float: return self._sampling_frequency - def get_num_segments(self): + def get_num_segments(self) -> int: return len(self._sorting_segments) - def get_num_samples(self, segment_index=None): + def get_num_samples(self, segment_index=None) -> int: """Returns the number of samples of the associated recording for a segment. Parameters @@ -82,7 +82,7 @@ def get_num_samples(self, segment_index=None): ), "This methods requires an associated recording. Call self.register_recording() first." return self._recording.get_num_samples(segment_index=segment_index) - def get_total_samples(self): + def get_total_samples(self) -> int: """Returns the total number of samples of the associated recording. Returns @@ -299,9 +299,11 @@ def count_num_spikes_per_unit(self) -> dict: return num_spikes - def count_total_num_spikes(self): + def count_total_num_spikes(self) -> int: """ - Get total number of spikes summed across segment and units. + Get total number of spikes in the sorting. + + This is the sum of all spikes in all segments across all units. Returns ------- @@ -310,9 +312,10 @@ def count_total_num_spikes(self): """ return self.to_spike_vector().size - def select_units(self, unit_ids, renamed_unit_ids=None): + def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting: """ - Selects a subset of units + Returns a new sorting object which contains only a selected subset of units. + Parameters ---------- @@ -331,9 +334,30 @@ def select_units(self, unit_ids, renamed_unit_ids=None): sub_sorting = UnitsSelectionSorting(self, unit_ids, renamed_unit_ids=renamed_unit_ids) return sub_sorting - def remove_units(self, remove_unit_ids): + def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting: """ - Removes a subset of units + Returns a new sorting object with renamed units. + + + Parameters + ---------- + new_unit_ids : numpy.array or list + List of new names for unit ids. + They should map positionally to the existing unit ids. + + Returns + ------- + BaseSorting + Sorting object with renamed units + """ + from spikeinterface import UnitsSelectionSorting + + sub_sorting = UnitsSelectionSorting(self, renamed_unit_ids=new_unit_ids) + return sub_sorting + + def remove_units(self, remove_unit_ids) -> BaseSorting: + """ + Returns a new sorting object with contains only a selected subset of units. 
Parameters ---------- @@ -343,7 +367,7 @@ def remove_units(self, remove_unit_ids): Returns ------- BaseSorting - Sorting object without removed units + Sorting without the removed units """ from spikeinterface import UnitsSelectionSorting @@ -353,7 +377,8 @@ def remove_units(self, remove_unit_ids): def remove_empty_units(self): """ - Removes units with empty spike trains + Returns a new sorting object which contains only units with at least one spike. + Returns ------- @@ -389,7 +414,7 @@ def get_all_spike_trains(self, outputs="unit_id"): """ Return all spike trains concatenated. - This is deprecated use sorting.to_spike_vector() instead + This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead """ warnings.warn( @@ -429,7 +454,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac Construct a unique structured numpy vector concatenating all spikes with several fields: sample_index, unit_index, segment_index. - See also `get_all_spike_trains()` Parameters ---------- diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 0bdd9aecdd..a35898b420 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -22,6 +22,7 @@ ) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal +from spikeinterface.core.generate import generate_sorting if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" @@ -169,6 +170,18 @@ def test_npy_sorting(): assert_raises(Exception, sorting.register_recording, rec) +def test_rename_units_method(): + num_units = 2 + durations = [1.0, 1.0] + + sorting = generate_sorting(num_units=num_units, durations=durations) + + new_unit_ids = ["a", "b"] + new_sorting = sorting.rename_units(new_unit_ids=new_unit_ids) + + assert np.array_equal(new_sorting.get_unit_ids(), new_unit_ids) + + def test_empty_sorting(): sorting = NumpySorting.from_unit_dict({}, 30000) From bf20cb5cf909d5178e09fbc2b7bde30170be4f09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 15 Nov 2023 11:23:49 +0100 Subject: [PATCH 76/80] Alessio suggestion Co-authored-by: Alessio Buccino --- src/spikeinterface/core/waveform_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index dd7c3043c3..9d38a46bb2 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -558,7 +558,7 @@ def has_extension(self, extension_name: str) -> bool: def is_extension(self, extension_name) -> bool: warn( - "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead." + "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! 
Use `has_extension` instead.", DeprecationWarning, stacklevel=2 ) return self.has_extension(extension_name) From 2b0a432070f4298a1acdbb0d2147e5b4fd7ee2ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 10:24:30 +0000 Subject: [PATCH 77/80] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/waveform_extractor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 9d38a46bb2..a81d36139d 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -558,7 +558,9 @@ def has_extension(self, extension_name: str) -> bool: def is_extension(self, extension_name) -> bool: warn( - "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.", DeprecationWarning, stacklevel=2 + "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.", + DeprecationWarning, + stacklevel=2, ) return self.has_extension(extension_name) From 874db2e1932479bbc3ba3522418912d8402c07ad Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 15 Nov 2023 12:06:53 +0100 Subject: [PATCH 78/80] handling segments --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 ++ src/spikeinterface/sorters/internal/tridesclous2.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a16b642dd5..fd283a8224 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -33,6 +33,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "job_kwargs": {"n_jobs": -1}, } + handle_multi_segment = True + @classmethod def get_sorter_version(cls): return "2.0" diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index e256915fa6..eb2ddc922d 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -50,6 +50,8 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "save_array": True, } + handle_multi_segment = True + @classmethod def get_sorter_version(cls): return "2.0" From 8a70f80c1f14d123c215b03fdf90089c6f1b3a0e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 Nov 2023 14:36:48 +0100 Subject: [PATCH 79/80] Avoid duplicated template and quality metric names --- src/spikeinterface/postprocessing/template_metrics.py | 2 +- src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 858af3ee08..f68081dbda 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -60,7 +60,7 @@ def _set_params( metric_names += get_multi_channel_template_metric_names() metrics_kwargs = metrics_kwargs or dict() params = dict( - metric_names=[str(name) for name in metric_names], + metric_names=[str(name) for name in np.unique(metric_names)], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), diff --git 
a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 53309db282..428139b3b9 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -61,7 +61,7 @@ def _set_params( qm_params_[k]["peak_sign"] = peak_sign params = dict( - metric_names=[str(name) for name in metric_names], + metric_names=[str(name) for name in np.unique(metric_names)], sparsity=sparsity, peak_sign=peak_sign, seed=seed, From d692439159e08cc5de33dd6ccbbbe3e19513f550 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:44:39 +0000 Subject: [PATCH 80/80] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 21 ++++++++++++++----- .../core/tests/test_generate.py | 21 ++++++++++++------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 49a5650622..1c8661d12d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1336,8 +1336,17 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): return channel_locations -def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, - minimum_distance=20., max_iteration=100, distance_strict=False, seed=None): +def generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=20.0, + max_iteration=100, + distance_strict=False, + seed=None, +): rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") @@ -1364,7 +1373,7 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu else: # random only bad ones in the previous set renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] - + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) @@ -1374,8 +1383,10 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu if not solution_found: if distance_strict: - raise ValueError(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " - "You can use distance_strict=False or reduce minimum distance") + raise ValueError( + f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " + "You can use distance_strict=False or reduce minimum distance" + ) else: warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 582120ac51..7b51abcccb 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -292,24 +292,29 @@ def test_generate_single_fake_waveform(): def test_generate_unit_locations(): - seed = 0 probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) channel_locations = probe.contact_positions num_units = 100 - minimum_distance = 20. 
- unit_locations = generate_unit_locations(num_units, channel_locations, - margin_um=20.0, minimum_z=5.0, maximum_z=40.0, - minimum_distance=minimum_distance, max_iteration=500, - distance_strict=False, seed=seed) + minimum_distance = 20.0 + unit_locations = generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=minimum_distance, + max_iteration=500, + distance_strict=False, + seed=seed, + ) distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) dist_flat = np.triu(distances, k=1).flatten() - dist_flat = dist_flat[dist_flat>0] + dist_flat = dist_flat[dist_flat > 0] assert np.all(dist_flat > minimum_distance) - # import matplotlib.pyplot as plt # fig, ax = plt.subplots() # ax.hist(dist_flat, bins = np.arange(0, 400, 10))
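After patches 64 to 71, `make_agreement_scores()` is symmetric by default (`ensure_symmetry=True`) and `do_count_event()` simply wraps `sorting.count_num_spikes_per_unit()` in a pandas Series. A minimal usage sketch, assuming a spikeinterface build that includes these patches; the unit ids and spike frames below are invented.

    import numpy as np
    from spikeinterface.core import NumpySorting
    from spikeinterface.comparison.comparisontools import make_agreement_scores

    sampling_frequency = 30000
    sorting1 = NumpySorting.from_unit_dict({"a": np.array([100, 200, 300]), "b": np.array([150, 400])}, sampling_frequency)
    sorting2 = NumpySorting.from_unit_dict({"x": np.array([101, 205, 290]), "y": np.array([400])}, sampling_frequency)

    print(sorting1.count_num_spikes_per_unit())    # {'a': 3, 'b': 2}, what do_count_event() now wraps in a Series
    scores = make_agreement_scores(sorting1, sorting2, delta_frames=10)    # pandas DataFrame, units of sorting1 x units of sorting2
    scores_T = make_agreement_scores(sorting2, sorting1, delta_frames=10)
    assert np.allclose(scores.values, scores_T.values.T)                   # symmetric by default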
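Patch 75 adds `rename_units()` alongside the existing `select_units()` and `remove_units()`. A short usage sketch follows; the unit names are arbitrary and `generate_sorting()` is the same toy generator the new test uses.

    from spikeinterface.core.generate import generate_sorting

    sorting = generate_sorting(num_units=3, durations=[1.0])

    # rename_units maps the new names positionally onto the existing unit ids
    renamed = sorting.rename_units(new_unit_ids=["a", "b", "c"])
    print(renamed.get_unit_ids())        # ['a' 'b' 'c']

    # select_units can subset and rename in one call; remove_units drops units instead
    subset = sorting.select_units(unit_ids=sorting.get_unit_ids()[:2], renamed_unit_ids=["a", "b"])
    print(subset.get_unit_ids())         # ['a' 'b']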
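Patch 79 deduplicates the metric names with `np.unique`, which also sorts them, a side effect worth keeping in mind if downstream code relies on the original order of `metric_names`. A quick check with invented names:

    import numpy as np

    metric_names = ["snr", "isi_violation", "snr"]            # "snr" accidentally passed twice
    print([str(name) for name in np.unique(metric_names)])    # ['isi_violation', 'snr']  deduplicated and sorted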