From 324752aff01b0a38d027419c0f2ac92bf8bcd10d Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 16 May 2023 08:28:02 +0000 Subject: [PATCH 01/52] some change to test --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index a0c61a916f..6377eb695c 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -26,7 +26,7 @@ class BaseExtractor: """ - default_missing_property_values = {"f": np.nan, "O": None, "S": "", "U": ""} + default_missing_property_values = {'f': np.nan, "O": None, "S": "", "U": ""} # This replaces the old key_properties # These are annotations/properties that always need to be From 9c5d409daf4a6f10eb593068b6562fcba3387a0a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 16 May 2023 08:34:28 +0000 Subject: [PATCH 02/52] another change --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 6377eb695c..a68ed36e6f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -26,7 +26,7 @@ class BaseExtractor: """ - default_missing_property_values = {'f': np.nan, "O": None, "S": "", "U": ""} + default_missing_property_values = {'f': np.nan, "O": None, 'S': "", "U": ""} # This replaces the old key_properties # These are annotations/properties that always need to be From 5b305517c40bbd6c45787f05d65e2c0dbe66bbb8 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 16 May 2023 08:44:36 +0000 Subject: [PATCH 03/52] another attempt --- .github/workflows/pre-commit-post-merge.yml | 1 + src/spikeinterface/core/base.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pre-commit-post-merge.yml b/.github/workflows/pre-commit-post-merge.yml index e260b1d317..eec3d411f7 100644 --- a/.github/workflows/pre-commit-post-merge.yml +++ b/.github/workflows/pre-commit-post-merge.yml @@ -1,6 +1,7 @@ name: Use pre-commit post merge on: + pull_request: push: branches: [main] diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index a68ed36e6f..f76cda9143 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -35,7 +35,7 @@ class BaseExtractor: _main_properties = [] installed = True - installation_mesg = "" + installation_mesg = '' is_writable = False def __init__(self, main_ids): From 298d3076c89b8a5aed4b6e4728bf160913140210 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 16 May 2023 08:56:36 +0000 Subject: [PATCH 04/52] attempt merge --- .github/workflows/pre-commit-post-merge.yml | 24 +++++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pre-commit-post-merge.yml b/.github/workflows/pre-commit-post-merge.yml index e260b1d317..29dcb80e91 100644 --- a/.github/workflows/pre-commit-post-merge.yml +++ b/.github/workflows/pre-commit-post-merge.yml @@ -8,10 +8,20 @@ jobs: pre-commit-post-merge: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - uses: pre-commit/action@v3.0.0 - - uses: pre-commit-ci/lite-action@v1.0.1 - if: always() + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install pre-commit + run: | + python -m pip install pre-commit + - name: Run pre-commit hooks + run: | + pre-commit run --all-files 
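      # `git diff --quiet` exits non-zero when the working tree has unstaged
      # changes, and `git diff --staged --quiet` does the same for staged ones,
      # so the `||` branch in the step below commits and pushes only when the
      # pre-commit hooks actually rewrote files.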
+ - name: Commit changes + env: + GITHUB_TOKEN: ${{ secrets.PRE_COMMIT_TOKEN }} + run: | + git config --local user.email "action@github.com" + git config --local user.name "GitHub Action" + git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git push) \ No newline at end of file From daf7feaeeddd8b1c37f04b854a6f7cd02c89f24f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 16 May 2023 09:02:20 +0000 Subject: [PATCH 05/52] add condition --- .github/workflows/pre-commit-post-merge.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit-post-merge.yml b/.github/workflows/pre-commit-post-merge.yml index 131cc6584a..19d39b6ad2 100644 --- a/.github/workflows/pre-commit-post-merge.yml +++ b/.github/workflows/pre-commit-post-merge.yml @@ -18,11 +18,12 @@ jobs: python -m pip install pre-commit - name: Run pre-commit hooks run: | - pre-commit run --all-files + pre-commit run --all-files || true - name: Commit changes env: GITHUB_TOKEN: ${{ secrets.PRE_COMMIT_TOKEN }} run: | git config --local user.email "action@github.com" git config --local user.name "GitHub Action" - git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git push) \ No newline at end of file + git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git push) + if: always() \ No newline at end of file From 798b73282946d2b8e372fce79770c7919d295864 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 16 May 2023 09:17:57 +0000 Subject: [PATCH 06/52] add auth --- .github/workflows/pre-commit-post-merge.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pre-commit-post-merge.yml b/.github/workflows/pre-commit-post-merge.yml index 19d39b6ad2..d24d8a3205 100644 --- a/.github/workflows/pre-commit-post-merge.yml +++ b/.github/workflows/pre-commit-post-merge.yml @@ -1,7 +1,6 @@ name: Use pre-commit post merge on: - pull_request: push: branches: [main] @@ -25,5 +24,5 @@ jobs: run: | git config --local user.email "action@github.com" git config --local user.name "GitHub Action" - git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git push) - if: always() \ No newline at end of file + git diff --quiet && git diff --staged --quiet || (git commit -am "Format code"; git remote set-url origin https://x-access-token:${{ secrets.PRE_COMMIT_TOKEN }}@github.com/h-mayorquin/spikeinterface.git; git push) + if: always() From 513a344f1fb8ef942467fd05056f1e6e4a8fa541 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 12:20:52 +0200 Subject: [PATCH 07/52] add basic instance and numpy behavior --- .../core/tests/test_template_class.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/spikeinterface/core/tests/test_template_class.py diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py new file mode 100644 index 0000000000..defad77e00 --- /dev/null +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -0,0 +1,60 @@ +from pathlib import Path + +import numpy as np +import pytest + +from spikeinterface.core.template import Templates + + +def test_dense_template_instance(): + num_units = 2 + num_samples = 4 + num_channels = 3 + templates_shape = (num_units, num_samples, num_channels) + templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + + templates = 
Templates(templates_array=templates_array) + + assert np.array_equal(templates.templates_array, templates_array) + assert templates.sparsity is None + assert templates.num_units == num_units + assert templates.num_samples == num_samples + assert templates.num_channels == num_channels + + +def test_numpy_like_behavior(): + num_units = 2 + num_samples = 4 + num_channels = 3 + templates_shape = (num_units, num_samples, num_channels) + templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + + templates = Templates(templates_array=templates_array) + + # Test that slicing works as in numpy + assert np.array_equal(templates[:], templates_array[:]) + assert np.array_equal(templates[0], templates_array[0]) + assert np.array_equal(templates[0, :], templates_array[0, :]) + assert np.array_equal(templates[0, :, :], templates_array[0, :, :]) + assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2]) + + # Test unary ufuncs + assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array)) + assert np.array_equal(np.abs(templates), np.abs(templates_array)) + assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0)) + + # Test binary ufuncs + other_array = np.random.rand(*templates_shape) + other_template = Templates(templates_array=other_array) + + assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array)) + assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array)) + + # Test chaining of operations + chained_result = np.mean(np.multiply(templates, other_template), axis=0) + chained_expected = np.mean(np.multiply(templates_array, other_array), axis=0) + assert np.array_equal(chained_result, chained_expected) + + # Test ufuncs that return non-ndarray results + assert np.all(np.greater(templates, -1)) + assert not np.any(np.less(templates, 0)) From c242446086bca416766912fab079e17db100bd03 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 12:24:27 +0200 Subject: [PATCH 08/52] add pickability --- .../core/tests/test_template_class.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index defad77e00..b864421824 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -1,5 +1,5 @@ from pathlib import Path - +import pickle import numpy as np import pytest @@ -58,3 +58,23 @@ def test_numpy_like_behavior(): # Test ufuncs that return non-ndarray results assert np.all(np.greater(templates, -1)) assert not np.any(np.less(templates, 0)) + + +def test_pickle(): + num_units = 2 + num_samples = 4 + num_channels = 3 + templates_shape = (num_units, num_samples, num_channels) + templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + + templates = Templates(templates_array=templates_array) + + # Serialize and deserialize the object + serialized = pickle.dumps(templates) + deserialized = pickle.loads(serialized) + + assert np.array_equal(templates.templates_array, deserialized.templates_array) + assert templates.sparsity == deserialized.sparsity + assert templates.num_units == deserialized.num_units + assert templates.num_samples == deserialized.num_samples + assert templates.num_channels == deserialized.num_channels From 4ca9ec6d504adb83561ce8bb57189ff3e9ec7a8e Mon Sep 17 00:00:00 2001 
From: Heberto Mayorquin Date: Tue, 12 Sep 2023 12:40:45 +0200 Subject: [PATCH 09/52] add json test --- src/spikeinterface/core/template.py | 59 +++++++++++++++++++ .../core/tests/test_template_class.py | 54 ++++++++++------- 2 files changed, 93 insertions(+), 20 deletions(-) create mode 100644 src/spikeinterface/core/template.py diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py new file mode 100644 index 0000000000..3542879df4 --- /dev/null +++ b/src/spikeinterface/core/template.py @@ -0,0 +1,59 @@ +import json +from dataclasses import dataclass, field + +import numpy as np + +from spikeinterface.core.sparsity import ChannelSparsity + + +@dataclass +class Templates: + templates_array: np.ndarray + sparsity: ChannelSparsity = None + num_units: int = field(init=False) + num_samples: int = field(init=False) + num_channels: int = field(init=False) + + def __post_init__(self): + self.num_units, self.num_samples, self.num_channels = self.templates_array.shape + + # Implementing the slicing/indexing behavior as numpy + def __getitem__(self, index): + return self.templates_array[index] + + def __array__(self): + return self.templates_array + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray: + # Replace any Templates instances with their ndarray representation + inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs) + + # Apply the ufunc on the transformed inputs + result = getattr(ufunc, method)(*inputs, **kwargs) + + return result + + def to_dict(self): + sparsity = self.sparsity.to_dict() if self.sparsity is not None else None + return { + "templates_array": self.templates_array.tolist(), + "sparsity": sparsity, + "num_units": self.num_units, + "num_samples": self.num_samples, + "num_channels": self.num_channels, + } + + # Construct the object from a dictionary + @classmethod + def from_dict(cls, data): + return cls( + templates_array=np.array(data["templates_array"]), + sparsity=ChannelSparsity(data["sparsity"]), # Assuming you can reconstruct a ChannelSparsity from a string + ) + + def to_json(self): + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str): + return cls.from_dict(json.loads(json_str)) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index b864421824..e9e3f60730 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -1,19 +1,28 @@ from pathlib import Path import pickle +import json + import numpy as np import pytest from spikeinterface.core.template import Templates -def test_dense_template_instance(): +@pytest.fixture +def dense_templates(): num_units = 2 num_samples = 4 num_channels = 3 templates_shape = (num_units, num_samples, num_channels) templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - templates = Templates(templates_array=templates_array) + return Templates(templates_array=templates_array) + + +def test_dense_template_instance(dense_templates): + templates = dense_templates + templates_array = templates.templates_array + num_units, num_samples, num_channels = templates_array.shape assert np.array_equal(templates.templates_array, templates_array) assert templates.sparsity is None @@ -22,14 +31,9 @@ def test_dense_template_instance(): assert templates.num_channels == num_channels -def test_numpy_like_behavior(): - num_units = 2 - num_samples = 4 - 
num_channels = 3 - templates_shape = (num_units, num_samples, num_channels) - templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - - templates = Templates(templates_array=templates_array) +def test_numpy_like_behavior(dense_templates): + templates = dense_templates + templates_array = templates.templates_array # Test that slicing works as in numpy assert np.array_equal(templates[:], templates_array[:]) @@ -44,7 +48,7 @@ def test_numpy_like_behavior(): assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0)) # Test binary ufuncs - other_array = np.random.rand(*templates_shape) + other_array = np.random.rand(*templates_array.shape) other_template = Templates(templates_array=other_array) assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array)) @@ -60,19 +64,29 @@ def test_numpy_like_behavior(): assert not np.any(np.less(templates, 0)) -def test_pickle(): - num_units = 2 - num_samples = 4 - num_channels = 3 - templates_shape = (num_units, num_samples, num_channels) - templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - - templates = Templates(templates_array=templates_array) +def test_pickle(dense_templates): + templates = dense_templates # Serialize and deserialize the object serialized = pickle.dumps(templates) - deserialized = pickle.loads(serialized) + deserialized_templates = pickle.loads(serialized) + + assert np.array_equal(templates.templates_array, deserialized_templates.templates_array) + assert templates.sparsity == deserialized_templates.sparsity + assert templates.num_units == deserialized_templates.num_units + assert templates.num_samples == deserialized_templates.num_samples + assert templates.num_channels == deserialized_templates.num_channels + + +def test_jsonification(dense_templates): + templates = dense_templates + # Serialize to JSON string + serialized = templates.to_json() + + # Deserialize back to object + deserialized = Templates.from_json(serialized) + # Check if deserialized object matches original assert np.array_equal(templates.templates_array, deserialized.templates_array) assert templates.sparsity == deserialized.sparsity assert templates.num_units == deserialized.num_units From 107bdf96ec26368dfa0666c2e7ac9bd5dd596563 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 17:20:31 +0200 Subject: [PATCH 10/52] test fancy indices --- src/spikeinterface/core/template.py | 3 ++- src/spikeinterface/core/tests/test_template_class.py | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 3542879df4..2906692902 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -46,9 +46,10 @@ def to_dict(self): # Construct the object from a dictionary @classmethod def from_dict(cls, data): + sparsity = ChannelSparsity.from_dict(data["sparsity"]) if data["sparsity"] is not None else None return cls( templates_array=np.array(data["templates_array"]), - sparsity=ChannelSparsity(data["sparsity"]), # Assuming you can reconstruct a ChannelSparsity from a string + sparsity=sparsity, ) def to_json(self): diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index e9e3f60730..62673906ab 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -6,6 
+6,7 @@ import pytest from spikeinterface.core.template import Templates +from spikeinterface.core.sparsity import ChannelSparsity @pytest.fixture @@ -41,6 +42,14 @@ def test_numpy_like_behavior(dense_templates): assert np.array_equal(templates[0, :], templates_array[0, :]) assert np.array_equal(templates[0, :, :], templates_array[0, :, :]) assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2]) + # Test fancy indexing + indices = np.array([0, 1]) + assert np.array_equal(templates[indices], templates_array[indices]) + row_indices = np.array([0, 1]) + col_indices = np.array([2, 3]) + assert np.array_equal(templates[row_indices, col_indices], templates_array[row_indices, col_indices]) + mask = templates_array > 0.5 + assert np.array_equal(templates[mask], templates_array[mask]) # Test unary ufuncs assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array)) @@ -50,7 +59,6 @@ def test_numpy_like_behavior(dense_templates): # Test binary ufuncs other_array = np.random.rand(*templates_array.shape) other_template = Templates(templates_array=other_array) - assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array)) assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array)) From d2c6ec727c01c47ec9635a5279eb3d432ce370ce Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 22:39:01 +0200 Subject: [PATCH 11/52] alessio and samuel requests --- src/spikeinterface/core/template.py | 90 ++++++--- .../core/tests/test_template_class.py | 178 +++++++++--------- 2 files changed, 144 insertions(+), 124 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 2906692902..4b923db2df 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -1,55 +1,67 @@ +import numpy as np import json from dataclasses import dataclass, field +from .sparsity import ChannelSparsity -import numpy as np - -from spikeinterface.core.sparsity import ChannelSparsity - -@dataclass +@dataclass(kw_only=True) class Templates: templates_array: np.ndarray - sparsity: ChannelSparsity = None + sampling_frequency: float + nbefore: int + + sparsity_mask: np.ndarray = None + channel_ids: np.ndarray = None + unit_ids: np.ndarray = None + num_units: int = field(init=False) num_samples: int = field(init=False) num_channels: int = field(init=False) - def __post_init__(self): - self.num_units, self.num_samples, self.num_channels = self.templates_array.shape - - # Implementing the slicing/indexing behavior as numpy - def __getitem__(self, index): - return self.templates_array[index] - - def __array__(self): - return self.templates_array - - def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray: - # Replace any Templates instances with their ndarray representation - inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs) + nafter: int = field(init=False) + ms_before: float = field(init=False) + ms_after: float = field(init=False) + sparsity: ChannelSparsity = field(init=False) - # Apply the ufunc on the transformed inputs - result = getattr(ufunc, method)(*inputs, **kwargs) - - return result + def __post_init__(self): + self.num_units, self.num_samples = self.templates_array.shape[:2] + if self.sparsity_mask is None: + self.num_channels = self.templates_array.shape[2] + else: + self.num_channels = self.sparsity_mask.shape[1] + self.nafter = self.num_samples - self.nbefore - 1 + self.ms_before 
= self.nbefore / self.sampling_frequency * 1000 + self.ms_after = self.nafter / self.sampling_frequency * 1000 + if self.channel_ids is None: + self.channel_ids = np.arange(self.num_channels) + if self.unit_ids is None: + self.unit_ids = np.arange(self.num_units) + if self.sparsity_mask is not None: + self.sparsity = ChannelSparsity( + mask=self.sparsity_mask, + unit_ids=self.unit_ids, + channel_ids=self.channel_ids, + ) def to_dict(self): - sparsity = self.sparsity.to_dict() if self.sparsity is not None else None return { "templates_array": self.templates_array.tolist(), - "sparsity": sparsity, - "num_units": self.num_units, - "num_samples": self.num_samples, - "num_channels": self.num_channels, + "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask.tolist(), + "channel_ids": self.channel_ids.tolist(), + "unit_ids": self.unit_ids.tolist(), + "sampling_frequency": self.sampling_frequency, + "nbefore": self.nbefore, } - # Construct the object from a dictionary @classmethod def from_dict(cls, data): - sparsity = ChannelSparsity.from_dict(data["sparsity"]) if data["sparsity"] is not None else None return cls( templates_array=np.array(data["templates_array"]), - sparsity=sparsity, + sparsity_mask=None if data["sparsity_mask"] is None else np.array(data["sparsity_mask"]), + channel_ids=np.array(data["channel_ids"]), + unit_ids=np.array(data["unit_ids"]), + sampling_frequency=data["sampling_frequency"], + nbefore=data["nbefore"], ) def to_json(self): @@ -58,3 +70,19 @@ def to_json(self): @classmethod def from_json(cls, json_str): return cls.from_dict(json.loads(json_str)) + + # Implementing the slicing/indexing behavior as numpy + def __getitem__(self, index): + return self.templates_array[index] + + def __array__(self): + return self.templates_array + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray: + # Replace any Templates instances with their ndarray representation + inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs) + + # Apply the ufunc on the transformed inputs + result = getattr(ufunc, method)(*inputs, **kwargs) + + return result diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 62673906ab..5fc997c6bf 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -1,102 +1,94 @@ -from pathlib import Path -import pickle -import json - -import numpy as np import pytest - +import numpy as np +import pickle from spikeinterface.core.template import Templates -from spikeinterface.core.sparsity import ChannelSparsity -@pytest.fixture -def dense_templates(): +@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) +def get_template_object(template_obj): num_units = 2 - num_samples = 4 + num_samples = 5 num_channels = 3 templates_shape = (num_units, num_samples, num_channels) templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - return Templates(templates_array=templates_array) - - -def test_dense_template_instance(dense_templates): - templates = dense_templates - templates_array = templates.templates_array - num_units, num_samples, num_channels = templates_array.shape - - assert np.array_equal(templates.templates_array, templates_array) - assert templates.sparsity is None - assert templates.num_units == num_units - assert templates.num_samples == num_samples - assert templates.num_channels == num_channels - - -def 
test_numpy_like_behavior(dense_templates): - templates = dense_templates - templates_array = templates.templates_array - - # Test that slicing works as in numpy - assert np.array_equal(templates[:], templates_array[:]) - assert np.array_equal(templates[0], templates_array[0]) - assert np.array_equal(templates[0, :], templates_array[0, :]) - assert np.array_equal(templates[0, :, :], templates_array[0, :, :]) - assert np.array_equal(templates[3:5, :, 2], templates_array[3:5, :, 2]) - # Test fancy indexing - indices = np.array([0, 1]) - assert np.array_equal(templates[indices], templates_array[indices]) - row_indices = np.array([0, 1]) - col_indices = np.array([2, 3]) - assert np.array_equal(templates[row_indices, col_indices], templates_array[row_indices, col_indices]) - mask = templates_array > 0.5 - assert np.array_equal(templates[mask], templates_array[mask]) - - # Test unary ufuncs - assert np.array_equal(np.sqrt(templates), np.sqrt(templates_array)) - assert np.array_equal(np.abs(templates), np.abs(templates_array)) - assert np.array_equal(np.mean(templates, axis=0), np.mean(templates_array, axis=0)) - - # Test binary ufuncs - other_array = np.random.rand(*templates_array.shape) - other_template = Templates(templates_array=other_array) - assert np.array_equal(np.add(templates, other_template), np.add(templates_array, other_array)) - assert np.array_equal(np.multiply(templates, other_template), np.multiply(templates_array, other_array)) - - # Test chaining of operations - chained_result = np.mean(np.multiply(templates, other_template), axis=0) - chained_expected = np.mean(np.multiply(templates_array, other_array), axis=0) - assert np.array_equal(chained_result, chained_expected) - - # Test ufuncs that return non-ndarray results - assert np.all(np.greater(templates, -1)) - assert not np.any(np.less(templates, 0)) - - -def test_pickle(dense_templates): - templates = dense_templates - - # Serialize and deserialize the object - serialized = pickle.dumps(templates) - deserialized_templates = pickle.loads(serialized) - - assert np.array_equal(templates.templates_array, deserialized_templates.templates_array) - assert templates.sparsity == deserialized_templates.sparsity - assert templates.num_units == deserialized_templates.num_units - assert templates.num_samples == deserialized_templates.num_samples - assert templates.num_channels == deserialized_templates.num_channels - - -def test_jsonification(dense_templates): - templates = dense_templates - # Serialize to JSON string - serialized = templates.to_json() - - # Deserialize back to object - deserialized = Templates.from_json(serialized) - - # Check if deserialized object matches original - assert np.array_equal(templates.templates_array, deserialized.templates_array) - assert templates.sparsity == deserialized.sparsity - assert templates.num_units == deserialized.num_units - assert templates.num_samples == deserialized.num_samples - assert templates.num_channels == deserialized.num_channels + sampling_frequency = 30_000 + nbefore = 2 + + if template_obj == "dense": + return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore) + else: # sparse + sparsity_mask = np.array([[True, False, True], [False, True, False]]) + return Templates( + templates_array=templates_array, + sparsity_mask=sparsity_mask, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + ) + + +@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) +def test_pickle_serialization(template_obj, tmp_path): + obj = 
get_template_object(template_obj) + + # Dump to pickle + pkl_path = tmp_path / "templates.pkl" + with open(pkl_path, "wb") as f: + pickle.dump(obj, f) + + # Load from pickle + with open(pkl_path, "rb") as f: + loaded_obj = pickle.load(f) + + assert np.array_equal(obj.templates_array, loaded_obj.templates_array) + + +@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) +def test_json_serialization(template_obj): + obj = get_template_object(template_obj) + + json_str = obj.to_json() + loaded_obj_from_json = Templates.from_json(json_str) + + assert np.array_equal(obj.templates_array, loaded_obj_from_json.templates_array) + + +# @pytest.fixture +# def dense_templates(): +# num_units = 2 +# num_samples = 4 +# num_channels = 3 +# templates_shape = (num_units, num_samples, num_channels) +# templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) + +# return Templates(templates_array=templates_array) + + +# def test_pickle(dense_templates): +# templates = dense_templates + +# # Serialize and deserialize the object +# serialized = pickle.dumps(templates) +# deserialized_templates = pickle.loads(serialized) + +# assert np.array_equal(templates.templates_array, deserialized_templates.templates_array) +# assert templates.sparsity == deserialized_templates.sparsity +# assert templates.num_units == deserialized_templates.num_units +# assert templates.num_samples == deserialized_templates.num_samples +# assert templates.num_channels == deserialized_templates.num_channels + + +# def test_jsonification(dense_templates): +# templates = dense_templates +# # Serialize to JSON string +# serialized = templates.to_json() + +# # Deserialize back to object +# deserialized = Templates.from_json(serialized) + +# # Check if deserialized object matches original +# assert np.array_equal(templates.templates_array, deserialized.templates_array) +# assert templates.sparsity == deserialized.sparsity +# assert templates.num_units == deserialized.num_units +# assert templates.num_samples == deserialized.num_samples +# assert templates.num_channels == deserialized.num_channels From 961d26979cf339b8799d9f045f20f477589171c4 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 22:39:38 +0200 Subject: [PATCH 12/52] remove slicing --- src/spikeinterface/core/template.py | 20 --------- .../core/tests/test_template_class.py | 41 ------------------- 2 files changed, 61 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 4b923db2df..6dbfb881f6 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -66,23 +66,3 @@ def from_dict(cls, data): def to_json(self): return json.dumps(self.to_dict()) - - @classmethod - def from_json(cls, json_str): - return cls.from_dict(json.loads(json_str)) - - # Implementing the slicing/indexing behavior as numpy - def __getitem__(self, index): - return self.templates_array[index] - - def __array__(self): - return self.templates_array - - def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> np.ndarray: - # Replace any Templates instances with their ndarray representation - inputs = tuple(inp.templates_array if isinstance(inp, Templates) else inp for inp in inputs) - - # Apply the ufunc on the transformed inputs - result = getattr(ufunc, method)(*inputs, **kwargs) - - return result diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 5fc997c6bf..2b6b4c9744 100644 --- 
a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -51,44 +51,3 @@ def test_json_serialization(template_obj): loaded_obj_from_json = Templates.from_json(json_str) assert np.array_equal(obj.templates_array, loaded_obj_from_json.templates_array) - - -# @pytest.fixture -# def dense_templates(): -# num_units = 2 -# num_samples = 4 -# num_channels = 3 -# templates_shape = (num_units, num_samples, num_channels) -# templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - -# return Templates(templates_array=templates_array) - - -# def test_pickle(dense_templates): -# templates = dense_templates - -# # Serialize and deserialize the object -# serialized = pickle.dumps(templates) -# deserialized_templates = pickle.loads(serialized) - -# assert np.array_equal(templates.templates_array, deserialized_templates.templates_array) -# assert templates.sparsity == deserialized_templates.sparsity -# assert templates.num_units == deserialized_templates.num_units -# assert templates.num_samples == deserialized_templates.num_samples -# assert templates.num_channels == deserialized_templates.num_channels - - -# def test_jsonification(dense_templates): -# templates = dense_templates -# # Serialize to JSON string -# serialized = templates.to_json() - -# # Deserialize back to object -# deserialized = Templates.from_json(serialized) - -# # Check if deserialized object matches original -# assert np.array_equal(templates.templates_array, deserialized.templates_array) -# assert templates.sparsity == deserialized.sparsity -# assert templates.num_units == deserialized.num_units -# assert templates.num_samples == deserialized.num_samples -# assert templates.num_channels == deserialized.num_channels From 9ee3a1de6741b156f2a56ca5f7455dd2ebf3b768 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 23:15:08 +0200 Subject: [PATCH 13/52] passing tests --- src/spikeinterface/core/template.py | 41 +++++++++++++- .../core/tests/test_template_class.py | 55 ++++++++++++++----- 2 files changed, 79 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 6dbfb881f6..8926281dfe 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -1,6 +1,6 @@ import numpy as np import json -from dataclasses import dataclass, field +from dataclasses import dataclass, field, astuple from .sparsity import ChannelSparsity @@ -21,7 +21,7 @@ class Templates: nafter: int = field(init=False) ms_before: float = field(init=False) ms_after: float = field(init=False) - sparsity: ChannelSparsity = field(init=False) + sparsity: ChannelSparsity = field(init=False, default=None) def __post_init__(self): self.num_units, self.num_samples = self.templates_array.shape[:2] @@ -66,3 +66,40 @@ def from_dict(cls, data): def to_json(self): return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str): + return cls.from_dict(json.loads(json_str)) + + def __eq__(self, other): + """Necessary to compare arrays""" + if not isinstance(other, Templates): + return False + + # Convert the instances to tuples + self_tuple = astuple(self) + other_tuple = astuple(other) + + # Compare each field + for s_field, o_field in zip(self_tuple, other_tuple): + if isinstance(s_field, np.ndarray): + if not np.array_equal(s_field, o_field): + return False + + elif isinstance(s_field, ChannelSparsity): + if not isinstance(o_field, ChannelSparsity): + 
return False + + # (maybe ChannelSparsity should have its own __eq__ method) + # Compare ChannelSparsity by its mask, unit_ids and channel_ids + if not np.array_equal(s_field.mask, o_field.mask): + return False + if not np.array_equal(s_field.unit_ids, o_field.unit_ids): + return False + if not np.array_equal(s_field.channel_ids, o_field.channel_ids): + return False + else: + if s_field != o_field: + return False + + return True diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 2b6b4c9744..b395f82d49 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -4,8 +4,33 @@ from spikeinterface.core.template import Templates -@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) -def get_template_object(template_obj): +def compare_instances(obj1, obj2): + if not isinstance(obj1, Templates) or not isinstance(obj2, Templates): + raise ValueError("Both objects must be instances of the Templates class") + + for attr, value1 in obj1.__dict__.items(): + value2 = getattr(obj2, attr, None) + + # Comparing numpy arrays + if isinstance(value1, np.ndarray): + if not np.array_equal(value1, value2): + print(f"Attribute '{attr}' is not equal!") + print(f"Value from obj1:\n{value1}") + print(f"Value from obj2:\n{value2}") + return False + # Comparing other types + elif value1 != value2: + print(f"Attribute '{attr}' is not equal!") + print(f"Value from obj1: {value1}") + print(f"Value from obj2: {value2}") + return False + + print("All attributes are equal!") + return True + + +@pytest.mark.parametrize("template_object", ["dense", "sparse"]) +def generate_template_fixture(template_object): num_units = 2 num_samples = 5 num_channels = 3 @@ -15,7 +40,7 @@ def get_template_object(template_obj): sampling_frequency = 30_000 nbefore = 2 - if template_obj == "dense": + if template_object == "dense": return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore) else: # sparse sparsity_mask = np.array([[True, False, True], [False, True, False]]) @@ -27,27 +52,27 @@ def get_template_object(template_obj): ) -@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) -def test_pickle_serialization(template_obj, tmp_path): - obj = get_template_object(template_obj) +@pytest.mark.parametrize("template_object", ["dense", "sparse"]) +def test_pickle_serialization(template_object, tmp_path): + template = generate_template_fixture(template_object) # Dump to pickle pkl_path = tmp_path / "templates.pkl" with open(pkl_path, "wb") as f: - pickle.dump(obj, f) + pickle.dump(template, f) # Load from pickle with open(pkl_path, "rb") as f: - loaded_obj = pickle.load(f) + template_reloaded = pickle.load(f) - assert np.array_equal(obj.templates_array, loaded_obj.templates_array) + assert template == template_reloaded -@pytest.mark.parametrize("template_obj", ["dense", "sparse"]) -def test_json_serialization(template_obj): - obj = get_template_object(template_obj) +@pytest.mark.parametrize("template_object", ["dense", "sparse"]) +def test_json_serialization(template_object): + template = generate_template_fixture(template_object) - json_str = obj.to_json() - loaded_obj_from_json = Templates.from_json(json_str) + json_str = template.to_json() + template_reloaded_from_json = Templates.from_json(json_str) - assert np.array_equal(obj.templates_array, loaded_obj_from_json.templates_array) + assert template == template_reloaded_from_json From 
9d7c9ac134151cb688ec4316c9bc2ff4af9f62ae Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 23:21:13 +0200 Subject: [PATCH 14/52] add densification and sparsification methods --- src/spikeinterface/core/template.py | 12 +++++++++ .../core/tests/test_template_class.py | 25 ------------------- 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 8926281dfe..70c7d90527 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -64,6 +64,18 @@ def from_dict(cls, data): nbefore=data["nbefore"], ) + def get_dense_templates(self) -> np.ndarray: + if self.sparsity is None: + return self.templates_array + else: + self.sparsity.to_dense(self.templates_array) + + def get_sparse_templates(self) -> np.ndarray: + if self.sparsity is None: + raise ValueError("Can't return sparse templates without passing a sparsity mask") + else: + self.sparsity.to_sparse(self.templates_array) + def to_json(self): return json.dumps(self.to_dict()) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index b395f82d49..f92e636d93 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -4,31 +4,6 @@ from spikeinterface.core.template import Templates -def compare_instances(obj1, obj2): - if not isinstance(obj1, Templates) or not isinstance(obj2, Templates): - raise ValueError("Both objects must be instances of the Templates class") - - for attr, value1 in obj1.__dict__.items(): - value2 = getattr(obj2, attr, None) - - # Comparing numpy arrays - if isinstance(value1, np.ndarray): - if not np.array_equal(value1, value2): - print(f"Attribute '{attr}' is not equal!") - print(f"Value from obj1:\n{value1}") - print(f"Value from obj2:\n{value2}") - return False - # Comparing other types - elif value1 != value2: - print(f"Attribute '{attr}' is not equal!") - print(f"Value from obj1: {value1}") - print(f"Value from obj2: {value2}") - return False - - print("All attributes are equal!") - return True - - @pytest.mark.parametrize("template_object", ["dense", "sparse"]) def generate_template_fixture(template_object): num_units = 2 From 6e1027bb75abc9ae1977ec0a6b87c7f951588576 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 25 Sep 2023 10:51:11 +0200 Subject: [PATCH 15/52] adding tests for sparsity and density --- src/spikeinterface/core/sparsity.py | 3 +- src/spikeinterface/core/template.py | 53 ++++++++++++-- .../core/tests/test_template_class.py | 71 ++++++++++++++++--- 3 files changed, 109 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 455edcfc80..70b412d487 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -191,7 +191,7 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str) -> np.ndarray: assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg densified_shape = waveforms.shape[:-1] + (self.num_channels,) - densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype) + densified_waveforms = np.zeros(shape=densified_shape, dtype=waveforms.dtype) densified_waveforms[..., non_zero_indices] = waveforms return densified_waveforms @@ -202,6 +202,7 @@ def are_waveforms_dense(self, waveforms: np.ndarray) -> bool: def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str) 
-> bool: non_zero_indices = self.unit_id_to_channel_indices[unit_id] num_active_channels = len(non_zero_indices) + return waveforms.shape[-1] == num_active_channels @classmethod diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 70c7d90527..54070053eb 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -65,16 +65,52 @@ def from_dict(cls, data): ) def get_dense_templates(self) -> np.ndarray: + # Assumes and object without a sparsity mask already has dense templates if self.sparsity is None: return self.templates_array - else: - self.sparsity.to_dense(self.templates_array) + + dense_waveforms = np.zeros(shape=(self.num_units, self.num_samples, self.num_channels)) + for unit_index, unit_id in enumerate(self.unit_ids): + num_active_channels = self.sparsity.mask[unit_index].sum() + waveforms = self.templates_array[unit_index, :, :num_active_channels] + dense_waveforms[unit_index, ...] = self.sparsity.densify_waveforms(waveforms=waveforms, unit_id=unit_id) + + return dense_waveforms def get_sparse_templates(self) -> np.ndarray: + # Objects without sparsity mask don't have sparsity and therefore can't return sparse templates if self.sparsity is None: raise ValueError("Can't return sparse templates without passing a sparsity mask") - else: - self.sparsity.to_sparse(self.templates_array) + + # Waveforms are already sparse + if not self.sparsity.are_waveforms_dense(self.templates_array): + return self.templates_array + + max_num_active_channels = self.sparsity.max_num_active_channels + sparse_waveforms = np.zeros(shape=(self.num_units, self.num_samples, max_num_active_channels)) + for unit_index, unit_id in enumerate(self.unit_ids): + waveforms = self.templates_array[unit_index, ...] + sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) + + return sparse_waveforms + + def are_templates_sparse(self) -> bool: + if self.sparsity is None: + return False + + if self.templates_array.shape[-1] == self.num_channels: + return False + + unit_is_sparse = True + for unit_index, unit_id in enumerate(self.unit_ids): + non_zero_indices = self.sparsity.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + waveforms = self.templates_array[unit_index, :, :num_active_channels] + unit_is_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id) + if not unit_is_sparse: + return False + + return unit_is_sparse def to_json(self): return json.dumps(self.to_dict()) @@ -84,7 +120,11 @@ def from_json(cls, json_str): return cls.from_dict(json.loads(json_str)) def __eq__(self, other): - """Necessary to compare arrays""" + """ + Necessary to compare templates because they naturally compare objects by equality of their fields + which is not possible for numpy arrays so we override the __eq__ method to compare each numpy arrays + with np.array_equal + """ if not isinstance(other, Templates): return False @@ -97,12 +137,11 @@ def __eq__(self, other): if isinstance(s_field, np.ndarray): if not np.array_equal(s_field, o_field): return False - + # Compare ChannelSparsity by its mask, unit_ids and channel_ids. 
Maybe ChannelSparsity should have its own __eq__ method elif isinstance(s_field, ChannelSparsity): if not isinstance(o_field, ChannelSparsity): return False - # (maybe ChannelSparsity should have its own __eq__ method) # Compare ChannelSparsity by its mask, unit_ids and channel_ids if not np.array_equal(s_field.mask, o_field.mask): return False diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index f92e636d93..cf0dffe532 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -2,10 +2,10 @@ import numpy as np import pickle from spikeinterface.core.template import Templates +from spikeinterface.core.sparsity import ChannelSparsity -@pytest.mark.parametrize("template_object", ["dense", "sparse"]) -def generate_template_fixture(template_object): +def generate_test_template(template_type): num_units = 2 num_samples = 5 num_channels = 3 @@ -15,10 +15,29 @@ def generate_template_fixture(template_object): sampling_frequency = 30_000 nbefore = 2 - if template_object == "dense": + if template_type == "dense": return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore) - else: # sparse + elif template_type == "sparse": # sparse with sparse templates sparsity_mask = np.array([[True, False, True], [False, True, False]]) + sparsity = ChannelSparsity( + mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels) + ) + + sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels)) + for unit_index in range(num_units): + template = templates_array[unit_index, ...] + sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_index) + sparse_templates_array[unit_index, :, : sparse_template.shape[1]] = sparse_template + + return Templates( + templates_array=sparse_templates_array, + sparsity_mask=sparsity_mask, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + ) + elif template_type == "sparse_with_dense_templates": # sparse with dense templates + sparsity_mask = np.array([[True, False, True], [False, True, False]]) + return Templates( templates_array=templates_array, sparsity_mask=sparsity_mask, @@ -27,9 +46,9 @@ def generate_template_fixture(template_object): ) -@pytest.mark.parametrize("template_object", ["dense", "sparse"]) -def test_pickle_serialization(template_object, tmp_path): - template = generate_template_fixture(template_object) +@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +def test_pickle_serialization(template_type, tmp_path): + template = generate_test_template(template_type) # Dump to pickle pkl_path = tmp_path / "templates.pkl" @@ -43,11 +62,43 @@ def test_pickle_serialization(template_object, tmp_path): assert template == template_reloaded -@pytest.mark.parametrize("template_object", ["dense", "sparse"]) -def test_json_serialization(template_object): - template = generate_template_fixture(template_object) +@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +def test_json_serialization(template_type): + template = generate_test_template(template_type) json_str = template.to_json() template_reloaded_from_json = Templates.from_json(json_str) assert template == template_reloaded_from_json + + +@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +def 
test_get_dense_templates(template_type): + template = generate_test_template(template_type) + dense_templates = template.get_dense_templates() + assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) + + +@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +def test_get_sparse_templates(template_type): + template = generate_test_template(template_type) + + if template_type == "dense": + with pytest.raises(ValueError): + sparse_templates = template.get_sparse_templates() + elif template_type == "sparse": + sparse_templates = template.get_sparse_templates() + assert sparse_templates.shape == ( + template.num_units, + template.num_samples, + template.sparsity.max_num_active_channels, + ) + assert template.are_templates_sparse() + elif template_type == "sparse_with_dense_templates": + sparse_templates = template.get_sparse_templates() + assert sparse_templates.shape == ( + template.num_units, + template.num_samples, + template.sparsity.max_num_active_channels, + ) + assert not template.are_templates_sparse() From d05e67da479716c7aac619dc0e4d3080a0a6d4a5 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:24:45 +0200 Subject: [PATCH 16/52] prohibit dense templates when passing sparsity mask --- src/spikeinterface/core/sparsity.py | 29 ++++++----- src/spikeinterface/core/template.py | 50 ++++++++++--------- .../core/tests/test_template_class.py | 23 ++++----- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 5bc2e51e8a..f2da16b757 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -150,11 +150,8 @@ def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nd or a single sparsified waveform (template) with shape (num_samples, num_active_channels). """ - assert_msg = ( - "Waveforms must be dense to sparsify them. " - f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}" - ) - assert self.are_waveforms_dense(waveforms=waveforms), assert_msg + if self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): + return waveforms non_zero_indices = self.unit_id_to_channel_indices[unit_id] sparsified_waveforms = waveforms[..., non_zero_indices] @@ -185,16 +182,20 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nda """ non_zero_indices = self.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) - assert_msg = ( - "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " - f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels." - ) - assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg + if not self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): + error_message = ( + "Waveforms do not seem to be be in the sparsity shape of this unit_id. 
The number of active channels is " + f"{num_active_channels} but the waveform has non zero values outsies of those active channels: \n" + f"{waveforms[..., num_active_channels:]}" + ) + raise ValueError(error_message) densified_shape = waveforms.shape[:-1] + (self.num_channels,) densified_waveforms = np.zeros(shape=densified_shape, dtype=waveforms.dtype) - densified_waveforms[..., non_zero_indices] = waveforms + # Maps the active channels to their original indices + densified_waveforms[..., non_zero_indices] = waveforms[..., :num_active_channels] return densified_waveforms @@ -205,7 +206,11 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo non_zero_indices = self.unit_id_to_channel_indices[unit_id] num_active_channels = len(non_zero_indices) - return waveforms.shape[-1] == num_active_channels + # If any channel is non-zero outside of the active channels, then the waveforms are not sparse + excess_zeros = waveforms[..., num_active_channels:].sum() + are_sparse = excess_zeros == 0 + + return are_sparse @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 54070053eb..c0d4869d5e 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -32,6 +32,8 @@ def __post_init__(self): self.nafter = self.num_samples - self.nbefore - 1 self.ms_before = self.nbefore / self.sampling_frequency * 1000 self.ms_after = self.nafter / self.sampling_frequency * 1000 + + # Initialize sparsity object if self.channel_ids is None: self.channel_ids = np.arange(self.num_channels) if self.unit_ids is None: @@ -43,6 +45,10 @@ def __post_init__(self): channel_ids=self.channel_ids, ) + # Test that the templates are sparse if a sparsity mask is passed + if not self._are_passed_templates_sparse(): + raise ValueError("Sparsity mask passed but the templates are not sparse") + def to_dict(self): return { "templates_array": self.templates_array.tolist(), @@ -69,10 +75,11 @@ def get_dense_templates(self) -> np.ndarray: if self.sparsity is None: return self.templates_array - dense_waveforms = np.zeros(shape=(self.num_units, self.num_samples, self.num_channels)) + dense_shape = (self.num_units, self.num_samples, self.num_channels) + dense_waveforms = np.zeros(dense=dense_shape, dtype=self.templates_array.dtype) + for unit_index, unit_id in enumerate(self.unit_ids): - num_active_channels = self.sparsity.mask[unit_index].sum() - waveforms = self.templates_array[unit_index, :, :num_active_channels] + waveforms = self.templates_array[unit_index, ...] dense_waveforms[unit_index, ...] = self.sparsity.densify_waveforms(waveforms=waveforms, unit_id=unit_id) return dense_waveforms @@ -82,12 +89,9 @@ def get_sparse_templates(self) -> np.ndarray: if self.sparsity is None: raise ValueError("Can't return sparse templates without passing a sparsity mask") - # Waveforms are already sparse - if not self.sparsity.are_waveforms_dense(self.templates_array): - return self.templates_array - max_num_active_channels = self.sparsity.max_num_active_channels - sparse_waveforms = np.zeros(shape=(self.num_units, self.num_samples, max_num_active_channels)) + sparse_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_waveforms = np.zeros(shape=sparse_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] sparse_waveforms[unit_index, ...] 
= self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) @@ -95,22 +99,20 @@ def get_sparse_templates(self) -> np.ndarray: return sparse_waveforms def are_templates_sparse(self) -> bool: - if self.sparsity is None: - return False - - if self.templates_array.shape[-1] == self.num_channels: - return False + return self.sparsity is not None - unit_is_sparse = True + def _are_passed_templates_sparse(self) -> bool: + """ + Tests if the templates passed to the init constructor are sparse + """ + are_templates_sparse = True for unit_index, unit_id in enumerate(self.unit_ids): - non_zero_indices = self.sparsity.unit_id_to_channel_indices[unit_id] - num_active_channels = len(non_zero_indices) - waveforms = self.templates_array[unit_index, :, :num_active_channels] - unit_is_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id) - if not unit_is_sparse: + waveforms = self.templates_array[unit_index, ...] + are_templates_sparse = self.sparsity.are_waveforms_sparse(waveforms, unit_id=unit_id) + if not are_templates_sparse: return False - return unit_is_sparse + return are_templates_sparse def to_json(self): return json.dumps(self.to_dict()) @@ -122,8 +124,8 @@ def from_json(cls, json_str): def __eq__(self, other): """ Necessary to compare templates because they naturally compare objects by equality of their fields - which is not possible for numpy arrays so we override the __eq__ method to compare each numpy arrays - with np.array_equal + which is not possible for numpy arrays. Therefore, we override the __eq__ method to compare each numpy arrays + using np.array_equal instead """ if not isinstance(other, Templates): return False @@ -137,7 +139,9 @@ def __eq__(self, other): if isinstance(s_field, np.ndarray): if not np.array_equal(s_field, o_field): return False - # Compare ChannelSparsity by its mask, unit_ids and channel_ids. Maybe ChannelSparsity should have its own __eq__ method + + # Compare ChannelSparsity by its mask, unit_ids and channel_ids. + # Maybe ChannelSparsity should have its own __eq__ method elif isinstance(s_field, ChannelSparsity): if not isinstance(o_field, ChannelSparsity): return False diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index cf0dffe532..b1244ab0d1 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -23,6 +23,7 @@ def generate_test_template(template_type): mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels) ) + # Create sparse templates sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels)) for unit_index in range(num_units): template = templates_array[unit_index, ...] 
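# For reference, the sparsify/densify round trip this fixture relies on,
# sketched with the ChannelSparsity methods patched earlier in the series:
#
#     dense = templates_array[0]                        # (num_samples, num_channels)
#     sparse = sparsity.sparsify_waveforms(waveforms=dense, unit_id=0)
#     # -> keeps only the columns where sparsity.mask[0] is True
#     restored = sparsity.densify_waveforms(waveforms=sparse, unit_id=0)
#     # -> zero-pads the inactive channels back out to num_channels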
@@ -35,6 +36,7 @@ def generate_test_template(template_type): sampling_frequency=sampling_frequency, nbefore=nbefore, ) + elif template_type == "sparse_with_dense_templates": # sparse with dense templates sparsity_mask = np.array([[True, False, True], [False, True, False]]) @@ -46,7 +48,7 @@ def generate_test_template(template_type): ) -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_pickle_serialization(template_type, tmp_path): template = generate_test_template(template_type) @@ -62,7 +64,7 @@ def test_pickle_serialization(template_type, tmp_path): assert template == template_reloaded -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_json_serialization(template_type): template = generate_test_template(template_type) @@ -72,14 +74,14 @@ def test_json_serialization(template_type): assert template == template_reloaded_from_json -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_get_dense_templates(template_type): template = generate_test_template(template_type) dense_templates = template.get_dense_templates() assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) -@pytest.mark.parametrize("template_type", ["dense", "sparse", "sparse_with_dense_templates"]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) def test_get_sparse_templates(template_type): template = generate_test_template(template_type) @@ -94,11 +96,8 @@ def test_get_sparse_templates(template_type): template.sparsity.max_num_active_channels, ) assert template.are_templates_sparse() - elif template_type == "sparse_with_dense_templates": - sparse_templates = template.get_sparse_templates() - assert sparse_templates.shape == ( - template.num_units, - template.num_samples, - template.sparsity.max_num_active_channels, - ) - assert not template.are_templates_sparse() + + +def test_initialization_fail_with_dense_templates(): + with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"): + template = generate_test_template(template_type="sparse_with_dense_templates") From cc8a5236e00072985cac399e078877c597b87ecd Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:27:42 +0200 Subject: [PATCH 17/52] add docstring --- src/spikeinterface/core/template.py | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index c0d4869d5e..dc6e0a5070 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -6,6 +6,41 @@ @dataclass(kw_only=True) class Templates: + """ + A class to represent spike templates, which can be either dense or sparse. + + Attributes + ---------- + templates_array : np.ndarray + Array containing the templates data. + sampling_frequency : float + Sampling frequency of the templates. + nbefore : int + Number of samples before the spike peak. + sparsity_mask : np.ndarray, optional + Binary array indicating the sparsity pattern of the templates. + If `None`, the templates are considered dense. + channel_ids : np.ndarray, optional + Array of channel IDs. If `None`, defaults to an array of increasing integers. 
+ unit_ids : np.ndarray, optional + Array of unit IDs. If `None`, defaults to an array of increasing integers. + num_units : int + Number of units in the templates. Automatically determined from `templates_array`. + num_samples : int + Number of samples per template. Automatically determined from `templates_array`. + num_channels : int + Number of channels in the templates. Automatically determined from `templates_array` or `sparsity_mask`. + nafter : int + Number of samples after the spike peak. Calculated as `num_samples - nbefore - 1`. + ms_before : float + Milliseconds before the spike peak. Calculated from `nbefore` and `sampling_frequency`. + ms_after : float + Milliseconds after the spike peak. Calculated from `nafter` and `sampling_frequency`. + sparsity : ChannelSparsity, optional + Object representing the sparsity pattern of the templates. Calculated from `sparsity_mask`. + If `None`, the templates are considered dense. + """ + templates_array: np.ndarray sampling_frequency: float nbefore: int From 73e95627b9015d20addd78e7b59f697ce5acb335 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:34:13 +0200 Subject: [PATCH 18/52] alessio remark about nafter definition --- src/spikeinterface/core/sparsity.py | 3 +-- src/spikeinterface/core/template.py | 12 +++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index f2da16b757..1593b6c9e4 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -208,9 +208,8 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo # If any channel is non-zero outside of the active channels, then the waveforms are not sparse excess_zeros = waveforms[..., num_active_channels:].sum() - are_sparse = excess_zeros == 0 - return are_sparse + return int(excess_zeros) == 0 @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index dc6e0a5070..bc4f7bae80 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -64,7 +64,9 @@ def __post_init__(self): self.num_channels = self.templates_array.shape[2] else: self.num_channels = self.sparsity_mask.shape[1] - self.nafter = self.num_samples - self.nbefore - 1 + + # Time and frames domain information + self.nafter = self.num_samples - self.nbefore self.ms_before = self.nbefore / self.sampling_frequency * 1000 self.ms_after = self.nafter / self.sampling_frequency * 1000 @@ -110,8 +112,8 @@ def get_dense_templates(self) -> np.ndarray: if self.sparsity is None: return self.templates_array - dense_shape = (self.num_units, self.num_samples, self.num_channels) - dense_waveforms = np.zeros(dense=dense_shape, dtype=self.templates_array.dtype) + densified_shape = (self.num_units, self.num_samples, self.num_channels) + dense_waveforms = np.zeros(dense=densified_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] 
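Alessio's remark concerns the off-by-one in the old definition. A quick numerical sketch of the new bookkeeping, using assumed values rather than anything from the codebase:

num_samples, nbefore, sampling_frequency = 90, 30, 30000.0

nafter = num_samples - nbefore  # new definition, without the extra "- 1"
ms_before = nbefore / sampling_frequency * 1000  # 1.0 ms
ms_after = nafter / sampling_frequency * 1000    # 2.0 ms

# The peak-anchored window now tiles the template exactly
assert nbefore + nafter == num_samples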
@@ -125,8 +127,8 @@ def get_sparse_templates(self) -> np.ndarray: raise ValueError("Can't return sparse templates without passing a sparsity mask") max_num_active_channels = self.sparsity.max_num_active_channels - sparse_shape = (self.num_units, self.num_samples, max_num_active_channels) - sparse_waveforms = np.zeros(shape=sparse_shape, dtype=self.templates_array.dtype) + sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_waveforms = np.zeros(shape=sparisfied_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] sparse_waveforms[unit_index, ...] = self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) From 52c333b1a5bdcc5975b2c67c2341f02e80ca74de Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 28 Sep 2023 11:50:25 +0200 Subject: [PATCH 19/52] fix mistake --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index bc4f7bae80..e8c0f83f50 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -113,7 +113,7 @@ def get_dense_templates(self) -> np.ndarray: return self.templates_array densified_shape = (self.num_units, self.num_samples, self.num_channels) - dense_waveforms = np.zeros(dense=densified_shape, dtype=self.templates_array.dtype) + dense_waveforms = np.zeros(shape=densified_shape, dtype=self.templates_array.dtype) for unit_index, unit_id in enumerate(self.unit_ids): waveforms = self.templates_array[unit_index, ...] From 2fb79ccf7fe3ad6a1efe5675171a82e54894aedd Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 09:31:01 +0200 Subject: [PATCH 20/52] Update src/spikeinterface/core/sparsity.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/sparsity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 1593b6c9e4..0a8c165ba5 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -186,8 +186,8 @@ def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.nda if not self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id): error_message = ( - "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " - f"{num_active_channels} but the waveform has non zero values outsies of those active channels: \n" + "Waveforms do not seem to be in the sparsity shape for this unit_id. 
The number of active channels is " + f"{num_active_channels}, but the waveform has non-zero values outsies of those active channels: \n" f"{waveforms[..., num_active_channels:]}" ) raise ValueError(error_message) From 437695c8b5c2f04331b0eac3cc7c876697b0709a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 10:01:47 +0200 Subject: [PATCH 21/52] changes --- src/spikeinterface/core/sparsity.py | 10 ++++ src/spikeinterface/core/template.py | 71 +++++++++++++---------------- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 1593b6c9e4..990687ca04 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -211,6 +211,16 @@ def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> boo return int(excess_zeros) == 0 + def sparisfy_templates(self, templates_array: np.ndarray) -> np.ndarray: + max_num_active_channels = self.max_num_active_channels + sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) + sparse_templates = np.zeros(shape=sparisfied_shape, dtype=templates_array.dtype) + for unit_index, unit_id in enumerate(self.unit_ids): + template = templates_array[unit_index, ...] + sparse_templates[unit_index, ...] = self.sparsify_waveforms(waveforms=template, unit_id=unit_id) + + return sparse_templates + @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): """ diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index e8c0f83f50..ed71b6d2ea 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -4,7 +4,7 @@ from .sparsity import ChannelSparsity -@dataclass(kw_only=True) +@dataclass class Templates: """ A class to represent spike templates, which can be either dense or sparse. @@ -18,7 +18,7 @@ class Templates: nbefore : int Number of samples before the spike peak. sparsity_mask : np.ndarray, optional - Binary array indicating the sparsity pattern of the templates. + Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. channel_ids : np.ndarray, optional Array of channel IDs. If `None`, defaults to an array of increasing integers. 
@@ -49,6 +49,8 @@ class Templates: channel_ids: np.ndarray = None unit_ids: np.ndarray = None + check_template_array_and_sparsity_mask_are_consistentency: bool = True + num_units: int = field(init=False) num_samples: int = field(init=False) num_channels: int = field(init=False) @@ -83,29 +85,9 @@ def __post_init__(self): ) # Test that the templates are sparse if a sparsity mask is passed - if not self._are_passed_templates_sparse(): - raise ValueError("Sparsity mask passed but the templates are not sparse") - - def to_dict(self): - return { - "templates_array": self.templates_array.tolist(), - "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask.tolist(), - "channel_ids": self.channel_ids.tolist(), - "unit_ids": self.unit_ids.tolist(), - "sampling_frequency": self.sampling_frequency, - "nbefore": self.nbefore, - } - - @classmethod - def from_dict(cls, data): - return cls( - templates_array=np.array(data["templates_array"]), - sparsity_mask=None if data["sparsity_mask"] is None else np.array(data["sparsity_mask"]), - channel_ids=np.array(data["channel_ids"]), - unit_ids=np.array(data["unit_ids"]), - sampling_frequency=data["sampling_frequency"], - nbefore=data["nbefore"], - ) + if self.check_template_array_and_sparsity_mask_are_consistentency: + if not self._are_passed_templates_sparse(): + raise ValueError("Sparsity mask passed but the templates are not sparse") def get_dense_templates(self) -> np.ndarray: # Assumes and object without a sparsity mask already has dense templates @@ -121,20 +103,6 @@ def get_dense_templates(self) -> np.ndarray: return dense_waveforms - def get_sparse_templates(self) -> np.ndarray: - # Objects without sparsity mask don't have sparsity and therefore can't return sparse templates - if self.sparsity is None: - raise ValueError("Can't return sparse templates without passing a sparsity mask") - - max_num_active_channels = self.sparsity.max_num_active_channels - sparisfied_shape = (self.num_units, self.num_samples, max_num_active_channels) - sparse_waveforms = np.zeros(shape=sparisfied_shape, dtype=self.templates_array.dtype) - for unit_index, unit_id in enumerate(self.unit_ids): - waveforms = self.templates_array[unit_index, ...] - sparse_waveforms[unit_index, ...] 
= self.sparsity.sparsify_waveforms(waveforms=waveforms, unit_id=unit_id) - - return sparse_waveforms - def are_templates_sparse(self) -> bool: return self.sparsity is not None @@ -151,8 +119,31 @@ def _are_passed_templates_sparse(self) -> bool: return are_templates_sparse + def to_dict(self): + return { + "templates_array": self.templates_array, + "sparsity_mask": None if self.sparsity_mask is None else self.sparsity_mask, + "channel_ids": self.channel_ids, + "unit_ids": self.unit_ids, + "sampling_frequency": self.sampling_frequency, + "nbefore": self.nbefore, + } + + @classmethod + def from_dict(cls, data): + return cls( + templates_array=np.asarray(data["templates_array"]), + sparsity_mask=None if data["sparsity_mask"] is None else np.asarray(data["sparsity_mask"]), + channel_ids=np.asarray(data["channel_ids"]), + unit_ids=np.asarray(data["unit_ids"]), + sampling_frequency=data["sampling_frequency"], + nbefore=data["nbefore"], + ) + def to_json(self): - return json.dumps(self.to_dict()) + from spikeinterface.core.core_tools import SIJsonEncoder + + return json.dumps(self.to_dict(), cls=SIJsonEncoder) @classmethod def from_json(cls, json_str): From 600f20f9b7465f7f398097207d036e8ee4ff8d92 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 10:29:30 +0200 Subject: [PATCH 22/52] modify docstring --- src/spikeinterface/core/template.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index ed71b6d2ea..909d47acfc 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -9,7 +9,7 @@ class Templates: """ A class to represent spike templates, which can be either dense or sparse. - Attributes + It is constructed with the following parameters: ---------- templates_array : np.ndarray Array containing the templates data. @@ -17,13 +17,21 @@ class Templates: Sampling frequency of the templates. nbefore : int Number of samples before the spike peak. - sparsity_mask : np.ndarray, optional + sparsity_mask : np.ndarray, optional (default=None) Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. - channel_ids : np.ndarray, optional + channel_ids : np.ndarray, optional (default=None) Array of channel IDs. If `None`, defaults to an array of increasing integers. - unit_ids : np.ndarray, optional + unit_ids : np.ndarray, optional (default=None) Array of unit IDs. If `None`, defaults to an array of increasing integers. + check_for_consistent_sparsity : bool, optional (default=True) + When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the + structure fo the sparsity_masl. + + The following attributes are avaialble after construction: + + Attributes + ---------- num_units : int Number of units in the templates. Automatically determined from `templates_array`. 
num_samples : int @@ -49,7 +57,7 @@ class Templates: channel_ids: np.ndarray = None unit_ids: np.ndarray = None - check_template_array_and_sparsity_mask_are_consistentency: bool = True + check_for_consistent_sparsity: bool = True num_units: int = field(init=False) num_samples: int = field(init=False) @@ -85,7 +93,7 @@ def __post_init__(self): ) # Test that the templates are sparse if a sparsity mask is passed - if self.check_template_array_and_sparsity_mask_are_consistentency: + if self.check_for_consistent_sparsity: if not self._are_passed_templates_sparse(): raise ValueError("Sparsity mask passed but the templates are not sparse") From a1e6eaec457a55c6043bedcc1349176aa4e57f0c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 10:34:21 +0200 Subject: [PATCH 23/52] remove tests for get_sparse_templates --- .../core/tests/test_template_class.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index b1244ab0d1..40bb3f2b34 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -81,23 +81,6 @@ def test_get_dense_templates(template_type): assert dense_templates.shape == (template.num_units, template.num_samples, template.num_channels) -@pytest.mark.parametrize("template_type", ["dense", "sparse"]) -def test_get_sparse_templates(template_type): - template = generate_test_template(template_type) - - if template_type == "dense": - with pytest.raises(ValueError): - sparse_templates = template.get_sparse_templates() - elif template_type == "sparse": - sparse_templates = template.get_sparse_templates() - assert sparse_templates.shape == ( - template.num_units, - template.num_samples, - template.sparsity.max_num_active_channels, - ) - assert template.are_templates_sparse() - - def test_initialization_fail_with_dense_templates(): with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"): template = generate_test_template(template_type="sparse_with_dense_templates") From aa08f1ba85aea5eec5835de04618061aa9df244b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 12:18:09 +0200 Subject: [PATCH 24/52] Update src/spikeinterface/core/template.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 909d47acfc..8ebd0a75f5 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -28,7 +28,7 @@ class Templates: When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the structure fo the sparsity_masl. 
- The following attributes are avaialble after construction: + The following attributes are available after construction: Attributes ---------- From ea2a8a03c43b3ca444e02e50b837b1d6b7a51bd9 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 24 Oct 2023 12:18:32 +0200 Subject: [PATCH 25/52] Update src/spikeinterface/core/template.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 8ebd0a75f5..e6556a68f7 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -9,7 +9,7 @@ class Templates: """ A class to represent spike templates, which can be either dense or sparse. - It is constructed with the following parameters: + Parameters ---------- templates_array : np.ndarray Array containing the templates data. From 1be2ce144e6f7f574a278dcf7775f3a11090a2a3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 30 Oct 2023 18:57:51 +0100 Subject: [PATCH 26/52] Add a minimum distance in generate_unit_locations. --- src/spikeinterface/core/generate.py | 46 ++++++++++++++++--- .../core/tests/test_generate.py | 36 ++++++++++++++- 2 files changed, 74 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 44ea02d32c..003b9cb5b5 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1333,15 +1333,49 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): return channel_locations -def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, seed=None): +def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, + minimum_distance=20., max_iteration=100, distance_strict=False, seed=None): rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") - for dim in (0, 1): - lim0 = np.min(channel_locations[:, dim]) - margin_um - lim1 = np.max(channel_locations[:, dim]) + margin_um - units_locations[:, dim] = rng.uniform(lim0, lim1, size=num_units) + + minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um + minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um + + units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units) + units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units) units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units) + if minimum_distance is not None: + solution_found = False + renew_inds = None + for i in range(max_iteration): + distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2) + inds0, inds1 = np.nonzero(distances < minimum_distance) + mask = inds0 != inds1 + inds0 = inds0[mask] + inds1 = inds1[mask] + + if inds0.size > 0: + if renew_inds is None: + renew_inds = np.unique(inds0) + else: + # random only bad ones in the previous set + renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] + + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) + units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) + units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) 
+ else: + solution_found = True + break + + if not solution_found: + if distance_strict: + raise ValueError(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " + "You can use distance_strict=False or reduce minimum distance") + else: + warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") + return units_locations @@ -1366,7 +1400,7 @@ def generate_ground_truth_recording( upsample_vector=None, generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"), - generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0), + generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), generate_templates_kwargs=dict(), dtype="float32", seed=None, diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 9a9c61766f..582120ac51 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -4,6 +4,8 @@ import numpy as np from spikeinterface.core import load_extractor, extract_waveforms + +from probeinterface import generate_multi_columns_probe from spikeinterface.core.generate import ( generate_recording, generate_sorting, @@ -289,6 +291,35 @@ def test_generate_single_fake_waveform(): # plt.show() +def test_generate_unit_locations(): + + seed = 0 + + probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) + channel_locations = probe.contact_positions + + num_units = 100 + minimum_distance = 20. + unit_locations = generate_unit_locations(num_units, channel_locations, + margin_um=20.0, minimum_z=5.0, maximum_z=40.0, + minimum_distance=minimum_distance, max_iteration=500, + distance_strict=False, seed=seed) + distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) + dist_flat = np.triu(distances, k=1).flatten() + dist_flat = dist_flat[dist_flat>0] + assert np.all(dist_flat > minimum_distance) + + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.hist(dist_flat, bins = np.arange(0, 400, 10)) + # fig, ax = plt.subplots() + # from probeinterface.plotting import plot_probe + # plot_probe(probe, ax=ax) + # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20) + # plt.show() + + def test_generate_templates(): seed = 0 @@ -297,7 +328,7 @@ def test_generate_templates(): num_units = 10 margin_um = 15.0 channel_locations = generate_channel_locations(num_chans, num_columns, 20.0) - unit_locations = generate_unit_locations(num_units, channel_locations, margin_um, seed) + unit_locations = generate_unit_locations(num_units, channel_locations, margin_um=margin_um, seed=seed) sampling_frequency = 30000.0 ms_before = 1.0 @@ -436,7 +467,8 @@ def test_generate_ground_truth_recording(): # test_noise_generator_consistency_after_dump(strategy, None) # test_generate_recording() # test_generate_single_fake_waveform() + test_generate_unit_locations() # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() - test_generate_sorting_with_spikes_on_borders() + # test_generate_sorting_with_spikes_on_borders() From 4f8bd73fd08798671fd252ad9a23ea40e12ef7e3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 2 Nov 2023 11:08:11 +0100 Subject: [PATCH 27/52] Update src/spikeinterface/core/template.py Co-authored-by: Zach McKenzie 
<92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index e6556a68f7..8beb6b46b1 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -17,7 +17,7 @@ class Templates: Sampling frequency of the templates. nbefore : int Number of samples before the spike peak. - sparsity_mask : np.ndarray, optional (default=None) + sparsity_mask : np.ndarray or None, default: None Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. From e52dd28d7fbc512e98bdbb28f4db32930b1cd73f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 2 Nov 2023 11:09:00 +0100 Subject: [PATCH 28/52] docstring compliance --- src/spikeinterface/core/template.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 8beb6b46b1..e6372c7082 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -20,11 +20,11 @@ class Templates: sparsity_mask : np.ndarray or None, default: None Boolean array indicating the sparsity pattern of the templates. If `None`, the templates are considered dense. - channel_ids : np.ndarray, optional (default=None) + channel_ids : np.ndarray, optional, default: None Array of channel IDs. If `None`, defaults to an array of increasing integers. - unit_ids : np.ndarray, optional (default=None) + unit_ids : np.ndarray, optional, default: None Array of unit IDs. If `None`, defaults to an array of increasing integers. - check_for_consistent_sparsity : bool, optional (default=True) + check_for_consistent_sparsity : bool, default: True When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the structure fo the sparsity_masl. From 5c2cb9b45f7c6a3cd4a308214e289f5673067101 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 2 Nov 2023 15:51:22 +0100 Subject: [PATCH 29/52] use python methods instead of parsing --- src/spikeinterface/core/base.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index b51bace55f..b737358eef 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -366,27 +366,24 @@ def to_dict( new_kwargs[name] = transform_extractors_to_dict(value) kwargs = new_kwargs - class_name = str(type(self)).replace("<class '", "").replace("'>", "") + + module_import_path = self.__class__.__module__ + class_name_no_path = self.__class__.__name__ + class_name = f"{module_import_path}.{class_name_no_path}" # e.g.
'spikeinterface.core.generate.AClass' module = class_name.split(".")[0] - imported_module = importlib.import_module(module) - try: - version = imported_module.__version__ - except AttributeError: - version = "unknown" + imported_module = importlib.import_module(module) + spike_interface_version = getattr(imported_module, "__version__", "unknown") dump_dict = { "class": class_name, "module": module, "kwargs": kwargs, - "version": version, + "version": spike_interface_version, "relative_paths": (relative_to is not None), } - try: - dump_dict["version"] = imported_module.__version__ - except AttributeError: - dump_dict["version"] = "unknown" + dump_dict["version"] = spike_interface_version if include_annotations: dump_dict["annotations"] = self._annotations @@ -805,7 +802,7 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): * explicit sub-folder, implicit base-folder : `extractor.save(name="extarctor_name")` * generated: `extractor.save()` - The second option saves to subfolder "extarctor_name" in + The second option saves to subfolder "extractor_name" in "get_global_tmp_folder()". You can set the global tmp folder with: "set_global_tmp_folder("path-to-global-folder")" From 3015725a24dfba872c2448409ce9585dbcc392ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Thu, 2 Nov 2023 16:56:33 +0100 Subject: [PATCH 30/52] `WaveformExtractor.is_extension` --> `has_extension` --- src/spikeinterface/core/waveform_extractor.py | 17 +++++++++++------ src/spikeinterface/exporters/report.py | 8 ++++---- src/spikeinterface/exporters/to_phy.py | 8 ++++---- .../postprocessing/principal_component.py | 2 +- .../tests/common_extension_tests.py | 2 +- .../tests/test_principal_component.py | 2 +- .../qualitymetrics/misc_metrics.py | 10 +++++----- .../qualitymetrics/quality_metric_calculator.py | 8 ++++---- .../tests/test_quality_metric_calculator.py | 2 +- src/spikeinterface/widgets/base.py | 2 +- src/spikeinterface/widgets/unit_summary.py | 10 +++++----- 11 files changed, 38 insertions(+), 33 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index c97a727340..491fdd5ebe 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -506,7 +506,7 @@ def get_recording_property(self, key) -> np.ndarray: def get_sorting_property(self, key) -> np.ndarray: return self.sorting.get_property(key) - def get_extension_class(self, extension_name): + def get_extension_class(self, extension_name: str): """ Get extension class from name and check if registered. @@ -525,7 +525,7 @@ def get_extension_class(self, extension_name): ext_class = extensions_dict[extension_name] return ext_class - def is_extension(self, extension_name) -> bool: + def has_extension(self, extension_name: str) -> bool: """ Check if the extension exists in memory or in the folder. @@ -556,7 +556,12 @@ def is_extension(self, extension_name) -> bool: and "params" in self._waveforms_root[extension_name].attrs.keys() ) - def load_extension(self, extension_name): + def is_extension(self, extension_name) -> bool: + warn("WaveformExtractor.is_extension is deprecated! Use `has_extension` instead.") + assert False + return self.has_extension(extension_name) + + def load_extension(self, extension_name: str): """ Load an extension from its name. The module of the extension must be loaded and registered. 
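Stepping back to the class-path construction in PATCH 29: a self-contained sketch of what the parsing-free style yields, next to the string parsing it replaces. The Dummy class is hypothetical; inside the library the module path would be something like 'spikeinterface.core.generate':

import importlib

class Dummy:
    pass

# New style: compose the import path from the class itself
module_import_path = Dummy.__module__  # "__main__" when run as a script
class_name = f"{module_import_path}.{Dummy.__name__}"

# The removed string-parsing version recovered the same path from repr(type(...))
assert str(type(Dummy())) == f"<class '{class_name}'>"

module = class_name.split(".")[0]
imported_module = importlib.import_module(module)
version = getattr(imported_module, "__version__", "unknown")  # same fallback as the patch

The getattr fallback also replaces the two try/except AttributeError blocks that the hunk deletes, which is why the diff removes more lines than it adds.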
@@ -572,7 +577,7 @@ def load_extension(self, extension_name): The loaded instance of the extension """ if self.folder is not None and extension_name not in self._loaded_extensions: - if self.is_extension(extension_name): + if self.has_extension(extension_name): ext_class = self.get_extension_class(extension_name) ext = ext_class.load(self.folder, self) if extension_name not in self._loaded_extensions: @@ -588,7 +593,7 @@ def delete_extension(self, extension_name) -> None: extension_name: str The extension name. """ - assert self.is_extension(extension_name), f"The extension {extension_name} is not available" + assert self.has_extension(extension_name), f"The extension {extension_name} is not available" del self._loaded_extensions[extension_name] if self.folder is not None and (self.folder / extension_name).is_dir(): shutil.rmtree(self.folder / extension_name) @@ -610,7 +615,7 @@ def get_available_extension_names(self): """ extension_names_in_folder = [] for extension_class in self.extensions: - if self.is_extension(extension_class.extension_name): + if self.has_extension(extension_class.extension_name): extension_names_in_folder.append(extension_class.extension_name) return extension_names_in_folder diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 57a5ab0166..8b14930859 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -51,7 +51,7 @@ def export_report( unit_ids = sorting.unit_ids # load or compute spike_amplitudes - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): spike_amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit") elif force_computation: spike_amplitudes = compute_spike_amplitudes(we, peak_sign=peak_sign, outputs="by_unit", **job_kwargs) @@ -62,7 +62,7 @@ def export_report( ) # load or compute quality_metrics - if we.is_extension("quality_metrics"): + if we.has_extension("quality_metrics"): metrics = we.load_extension("quality_metrics").get_data() elif force_computation: metrics = compute_quality_metrics(we) @@ -73,7 +73,7 @@ def export_report( ) # load or compute correlograms - if we.is_extension("correlograms"): + if we.has_extension("correlograms"): correlograms, bins = we.load_extension("correlograms").get_data() elif force_computation: correlograms, bins = compute_correlograms(we, window_ms=100.0, bin_ms=1.0) @@ -84,7 +84,7 @@ def export_report( ) # pre-compute unit locations if not done - if not we.is_extension("unit_locations"): + if not we.has_extension("unit_locations"): unit_locations = compute_unit_locations(we) output_folder = Path(output_folder).absolute() diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ecc5b316ec..607aa3e846 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -196,7 +196,7 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - if waveform_extractor.is_extension("similarity"): + if waveform_extractor.has_extension("similarity"): tmc = waveform_extractor.load_extension("similarity") template_similarity = tmc.get_data() else: @@ -219,7 +219,7 @@ def export_to_phy( np.save(str(output_folder / "channel_groups.npy"), channel_groups) if compute_amplitudes: - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): sac = 
waveform_extractor.load_extension("spike_amplitudes") amplitudes = sac.get_data(outputs="concatenated") else: @@ -231,7 +231,7 @@ def export_to_phy( np.save(str(output_folder / "amplitudes.npy"), amplitudes) if compute_pc_features: - if waveform_extractor.is_extension("principal_components"): + if waveform_extractor.has_extension("principal_components"): pc = waveform_extractor.load_extension("principal_components") else: pc = compute_principal_components( @@ -264,7 +264,7 @@ def export_to_phy( channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups}) channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False) - if waveform_extractor.is_extension("quality_metrics"): + if waveform_extractor.has_extension("quality_metrics"): qm = waveform_extractor.load_extension("quality_metrics") qm_data = qm.get_data() for column_name in qm_data.columns: diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index cf32e79b25..effd87007f 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -750,7 +750,7 @@ def compute_principal_components( >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ - if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name): + if load_if_exists and waveform_extractor.has_extension(WaveformPrincipalComponent.extension_name): pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name) else: pc = WaveformPrincipalComponent.create(waveform_extractor) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b539bbd5d4..2bef246bc2 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -143,7 +143,7 @@ def _test_extension_folder(self, we, in_memory=False): # reload as an extension from we assert self.extension_class.extension_name in we.get_available_extension_names() - assert we.is_extension(self.extension_class.extension_name) + assert we.has_extension(self.extension_class.extension_name) ext = we.load_extension(self.extension_class.extension_name) assert isinstance(ext, self.extension_class) for ext_name in self.extension_data_names: diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 49591d9b89..f5e315b18f 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -135,7 +135,7 @@ def test_project_new(self): from sklearn.decomposition import IncrementalPCA we = self.we1 - if we.is_extension("principal_components"): + if we.has_extension("principal_components"): we.delete_extension("principal_components") we_cp = we.select_units(we.unit_ids, self.cache_folder / "toy_waveforms_1seg_cp") diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 5c734b9100..9dab06124b 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -201,7 +201,7 @@ def compute_snrs( snrs : dict Computed signal to noise ratio for each unit. 
""" - if waveform_extractor.is_extension("noise_levels"): + if waveform_extractor.has_extension("noise_levels"): noise_levels = waveform_extractor.load_extension("noise_levels").get_data() else: if random_chunk_kwargs_dict is None: @@ -687,7 +687,7 @@ def compute_amplitude_cv_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.is_extension(amplitude_extension): + if waveform_extractor.has_extension(amplitude_extension): sac = waveform_extractor.load_extension(amplitude_extension) amps = sac.get_data(outputs="concatenated") if amplitude_extension == "spike_amplitudes": @@ -803,7 +803,7 @@ def compute_amplitude_cutoffs( spike_amplitudes = None invert_amplitudes = False - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): amp_calculator = waveform_extractor.load_extension("spike_amplitudes") spike_amplitudes = amp_calculator.get_data(outputs="by_unit") if amp_calculator._params["peak_sign"] == "pos": @@ -881,7 +881,7 @@ def compute_amplitude_medians(waveform_extractor, peak_sign="neg", unit_ids=None extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign) spike_amplitudes = None - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): amp_calculator = waveform_extractor.load_extension("spike_amplitudes") spike_amplitudes = amp_calculator.get_data(outputs="by_unit") @@ -974,7 +974,7 @@ def compute_drift_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.is_extension("spike_locations"): + if waveform_extractor.has_extension("spike_locations"): locs_calculator = waveform_extractor.load_extension("spike_locations") spike_locations = locs_calculator.get_data(outputs="concatenated") spike_locations_by_unit = locs_calculator.get_data(outputs="by_unit") diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 53309db282..90a1f7206e 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -42,14 +42,14 @@ def _set_params( if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list - if self.waveform_extractor.is_extension("principal_components"): + if self.waveform_extractor.has_extension("principal_components"): # by default 'nearest_neightbor' is removed because too slow pc_metrics = _possible_pc_metric_names.copy() pc_metrics.remove("nn_isolation") pc_metrics.remove("nn_noise_overlap") metric_names += pc_metrics # if spike_locations are not available, drift is removed from the list - if not self.waveform_extractor.is_extension("spike_locations"): + if not self.waveform_extractor.has_extension("spike_locations"): if "drift" in metric_names: metric_names.remove("drift") @@ -130,7 +130,7 @@ def _run(self, verbose, **job_kwargs): # metrics based on PCs pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self._params["skip_pc_metrics"]: - if not self.waveform_extractor.is_extension("principal_components"): + if not self.waveform_extractor.has_extension("principal_components"): raise ValueError("waveform_principal_component must be provied") pc_extension = self.waveform_extractor.load_extension("principal_components") pc_metrics = calculate_pc_metrics( @@ 
-216,7 +216,7 @@ def compute_quality_metrics( metrics: pandas.DataFrame Data frame with the computed metrics """ - if load_if_exists and waveform_extractor.is_extension(QualityMetricCalculator.extension_name): + if load_if_exists and waveform_extractor.has_extension(QualityMetricCalculator.extension_name): qmc = waveform_extractor.load_extension(QualityMetricCalculator.extension_name) else: qmc = QualityMetricCalculator(waveform_extractor) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index eb8317e4df..b601e5d6d8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -88,7 +88,7 @@ def test_metrics(self): we = self.we_long # avoid NaNs - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): we.delete_extension("spike_amplitudes") # without PC diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index a5d3cb2429..6ff837065b 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -107,7 +107,7 @@ def check_extensions(waveform_extractor, extensions): error_msg = "" raise_error = False for extension in extensions: - if not waveform_extractor.is_extension(extension): + if not waveform_extractor.has_extension(extension): raise_error = True error_msg += ( f"The {extension} waveform extension is required for this widget. " diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 35fde07326..aa280ad658 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -80,13 +80,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 3 - if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"): + if we.has_extension("correlograms") or we.has_extension("spike_amplitudes"): ncols += 1 - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) - if we.is_extension("unit_locations"): + if we.has_extension("unit_locations"): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( @@ -129,7 +129,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) ax3.set_ylabel(None) - if we.is_extension("correlograms"): + if we.has_extension("correlograms"): ax4 = fig.add_subplot(gs[:2, 3]) AutoCorrelogramsWidget( we, @@ -142,7 +142,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_title(None) ax4.set_yticks([]) - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): ax5 = fig.add_subplot(gs[2, :3]) ax6 = fig.add_subplot(gs[2, 3]) axes = np.array([ax5, ax6]) From 89387f006812487f3b12f8bbb1c6d42e7cab07a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Thu, 2 Nov 2023 17:34:24 +0100 Subject: [PATCH 31/52] Removing assert --- src/spikeinterface/core/waveform_extractor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 491fdd5ebe..1453b4156f 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -558,7 +558,6 @@ def has_extension(self, extension_name: str) -> bool: def 
is_extension(self, extension_name) -> bool: warn("WaveformExtractor.is_extension is deprecated! Use `has_extension` instead.") - assert False return self.has_extension(extension_name) From 634537e58ea3bb742984f2ab8f2a552cfa88334e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 12:02:20 +0100 Subject: [PATCH 32/52] improve docstring --- src/spikeinterface/core/base.py | 73 +++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index b737358eef..4a05c8fb1a 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -316,32 +316,63 @@ def to_dict( recursive: bool = False, ) -> dict: """ - Make a nested serialized dictionary out of the extractor. The dictionary produced can be used to re-initialize - an extractor using load_extractor_from_dict(dump_dict) + Construct a nested dictionary representation of the extractor. + + This method facilitates the serialization of the extractor instance by converting it + to a dictionary. The resulting dictionary can be used to re-initialize the extractor + through the `load_extractor_from_dict` function. + + Examples + -------- + >>> dump_dict = original_extractor.to_dict() + >>> reloaded_extractor = load_extractor_from_dict(dump_dict) Parameters ---------- - include_annotations: bool, default: False - If True, all annotations are added to the dict - include_properties: bool, default: False - If True, all properties are added to the dict - relative_to: str, Path, or None, default: None - If not None, files and folders are serialized relative to this path - Used in waveform extractor to maintain relative paths to binary files even if the - containing folder / diretory is moved - folder_metadata: str, Path, or None - Folder with numpy `npy` files containing additional information (e.g. probe in BaseRecording) and properties. - recursive: bool, default: False - If True, all dicitionaries in the kwargs are expanded with `to_dict` as well + include_annotations : bool, optional + Whether to include all annotations in the dictionary, by default False. + include_properties : bool, optional + Whether to include all properties in the dictionary, by default False. + relative_to : Union[str, Path, None], optional + If provided, file and folder paths will be made relative to this path, + enabling portability in folder formats such as the waveform extractor, + by default None. + folder_metadata : Union[str, Path, None], optional + Path to a folder containing additional metadata files (e.g., probe information in BaseRecording) + in numpy `npy` format, by default None. + recursive : bool, optional + If True, recursively apply `to_dict` to dictionaries within the kwargs, by default False. + + Raises + ------ + ValueError + If `relative_to` is specified while `recursive` is False. Returns ------- - dump_dict: dict - A dictionary representation of the extractor. + dict + A dictionary representation of the extractor, with the following structure: + { + "class": <the full import path of the class>, + "module": <the top-level module name>, (e.g. 'spikeinterface'), + "kwargs": <the kwargs needed to re-instantiate the extractor>, + "version": <the version of the module>, + "relative_paths": <whether paths are serialized relative to `relative_to`>, + "annotations": <the annotations dictionary, if included>, + "properties": <the properties dictionary, if included>, + "folder_metadata": <the folder metadata path, if provided> + } + + Notes + ----- + - The `relative_to` argument only has an effect if `recursive` is set to True. + - The `folder_metadata` argument will be made relative to `relative_to` if both are specified.
+ - The `version` field in the resulting dictionary reflects the version of the module + from which the extractor class originates. + - The full class attribute above is the full import of the class, e.g. + 'spikeinterface.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor' """ - kwargs = self._kwargs if relative_to and not recursive: raise ValueError("`relative_to` is only possible when `recursive=True`") @@ -373,17 +404,17 @@ def to_dict( module = class_name.split(".")[0] imported_module = importlib.import_module(module) - spike_interface_version = getattr(imported_module, "__version__", "unknown") + module_version = getattr(imported_module, "__version__", "unknown") dump_dict = { "class": class_name, "module": module, "kwargs": kwargs, - "version": spike_interface_version, + "version": module_version, "relative_paths": (relative_to is not None), } - dump_dict["version"] = spike_interface_version + dump_dict["version"] = module_version # Can be spikeinterface, spikeforest, etc. if include_annotations: dump_dict["annotations"] = self._annotations From 0544c0ac0888d5a60c6f43ac5030eb4d3d8cbbfa Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 12:03:35 +0100 Subject: [PATCH 33/52] improve docstring --- src/spikeinterface/core/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 4a05c8fb1a..867184e63b 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -371,6 +371,8 @@ def to_dict( from which the extractor class originates. - The full class attribute above is the full import of the class, e.g. 'spikeinterface.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor' + - The module is usually 'spikeinterface', but can be different for custom extractors such as those of + SpikeForest or any other project that inherits the Extractor class from spikeinterface. From 8b68729bf403700825efc7c842252d9cd5d210ec Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 12:05:41 +0100 Subject: [PATCH 34/52] alessio default arguments style --- src/spikeinterface/core/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 867184e63b..a229b31b14 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -329,18 +329,18 @@ def to_dict( Parameters ---------- - include_annotations : bool, optional + include_annotations : bool, default: False Whether to include all annotations in the dictionary, by default False. - include_properties : bool, optional + include_properties : bool, default: False Whether to include all properties in the dictionary, by default False. - relative_to : Union[str, Path, None], optional + relative_to : Union[str, Path, None], default: None If provided, file and folder paths will be made relative to this path, enabling portability in folder formats such as the waveform extractor, by default None. - folder_metadata : Union[str, Path, None], optional + folder_metadata : Union[str, Path, None], default: None Path to a folder containing additional metadata files (e.g., probe information in BaseRecording) in numpy `npy` format, by default None. - recursive : bool, optional + recursive : bool, default: False If True, recursively apply `to_dict` to dictionaries within the kwargs, by default False.
Raises From 242b1a0ce458acd371a4e8b285f4f88ed994f214 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 12:11:02 +0100 Subject: [PATCH 35/52] clean_refractory_period --- src/spikeinterface/core/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index a229b31b14..d2d947d299 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -378,6 +378,7 @@ def to_dict( if relative_to and not recursive: raise ValueError("`relative_to` is only possible when `recursive=True`") + kwargs = self._kwargs if recursive: to_dict_kwargs = dict( include_annotations=include_annotations, From fc6f03c19c774e6941b9010a28f80b5efcf25c5a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 13:03:31 +0100 Subject: [PATCH 36/52] Update src/spikeinterface/core/base.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index d2d947d299..2347594119 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -330,7 +330,7 @@ def to_dict( Parameters ---------- include_annotations : bool, default: False - Whether to include all annotations in the dictionary, by default False. + Whether to include all annotations in the dictionary include_properties : bool, default: False Whether to include all properties in the dictionary, by default False. relative_to : Union[str, Path, None], default: None From e96b5a0d1cc39c6a2357657495dad323a6732ad1 Mon Sep 17 00:00:00 2001 From: h-mayorquin Date: Sun, 5 Nov 2023 10:24:52 +0100 Subject: [PATCH 37/52] add tests to stream --- .../extractors/tests/test_nwb_s3_extractor.py | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py index 71a19f30d3..e7d148fd94 100644 --- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py +++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py @@ -1,9 +1,11 @@ from pathlib import Path +import pickle import pytest import numpy as np import h5py +from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor if hasattr(pytest, "global_test_folder"): @@ -15,7 +17,7 @@ @pytest.mark.ros3_test @pytest.mark.streaming_extractors @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_recording_s3_nwb_ros3(): +def test_recording_s3_nwb_ros3(tmp_path): file_path = ( "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) @@ -40,9 +42,18 @@ def test_recording_s3_nwb_ros3(): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" + tmp_file = tmp_path / "test_ros3_recording.pkl" + with open(tmp_file, 'wb') as f: + pickle.dump(rec, f) + + with open(tmp_file, 'rb') as f: + reloaded_recording = pickle.load(f) + + check_recordings_equal(rec, reloaded_recording) + @pytest.mark.streaming_extractors -def test_recording_s3_nwb_fsspec(): +def test_recording_s3_nwb_fsspec(tmp_path): file_path = ( "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) @@ -68,10 +79,21 @@ def 
test_recording_s3_nwb_fsspec():
assert trace_scaled.dtype == "float32"

+ tmp_file = tmp_path / "test_fsspec_recording.pkl"
+ with open(tmp_file, 'wb') as f:
+ pickle.dump(rec, f)
+
+ with open(tmp_file, 'rb') as f:
+ reloaded_recording = pickle.load(f)
+
+ check_recordings_equal(rec, reloaded_recording)
+
+
+
@pytest.mark.ros3_test
@pytest.mark.streaming_extractors
@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
-def test_sorting_s3_nwb_ros3():
+def test_sorting_s3_nwb_ros3(tmp_path):
file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
# we provide the 'sampling_frequency' because the NWB file does not have an electrical series
sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3")
@@ -90,9 +112,17 @@ def test_sorting_s3_nwb_ros3():
assert spike_train.dtype == "int64"
assert np.all(spike_train >= 0)

+ tmp_file = tmp_path / "test_ros3_sorting.pkl"
+ with open(tmp_file, 'wb') as f:
+ pickle.dump(sort, f)
+
+ with open(tmp_file, 'rb') as f:
+ reloaded_sorting = pickle.load(f)
+
+ check_sortings_equal(reloaded_sorting, sort)

@pytest.mark.streaming_extractors
-def test_sorting_s3_nwb_fsspec():
+def test_sorting_s3_nwb_fsspec(tmp_path):
file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
# we provide the 'sampling_frequency' because the NWB file does not have an electrical series
sort = NwbSortingExtractor(
@@ -113,6 +143,15 @@ def test_sorting_s3_nwb_fsspec():
assert spike_train.dtype == "int64"
assert np.all(spike_train >= 0)

+ tmp_file = tmp_path / "test_fsspec_sorting.pkl"
+ with open(tmp_file, 'wb') as f:
+ pickle.dump(sort, f)
+
+ with open(tmp_file, 'rb') as f:
+ reloaded_sorting = pickle.load(f)
+
+ check_sortings_equal(reloaded_sorting, sort)
+

if __name__ == "__main__":
test_recording_s3_nwb_ros3()

From 6c863c3dc81920b12b2ffadab76913ef6c0888d5 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 5 Nov 2023 10:26:11 +0000
Subject: [PATCH 38/52] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
src/spikeinterface/core/base.py | 2 +-
.../extractors/tests/test_nwb_s3_extractor.py | 35 +++++++++----------
2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 88b7f12783..1a8674697a 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -27,7 +27,7 @@ class BaseExtractor:
"""

- default_missing_property_values = {'f': np.nan, "O": None, 'S': "", "U": ""}
+ default_missing_property_values = {"f": np.nan, "O": None, "S": "", "U": ""}

# This replaces the old key_properties
# These are annotations/properties that always need to be
diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
index e7d148fd94..253ca2e4ce 100644
--- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
+++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
@@ -43,12 +43,12 @@ def test_recording_s3_nwb_ros3(tmp_path):
assert trace_scaled.dtype == "float32"

tmp_file = tmp_path / "test_ros3_recording.pkl"
- with open(tmp_file, 'wb') as f:
+ with open(tmp_file, "wb") as f:
pickle.dump(rec, f)
-
+
- with open(tmp_file, 'rb') as f:
+ with open(tmp_file, "rb") as f:
reloaded_recording = pickle.load(f)
-
+
check_recordings_equal(rec, reloaded_recording) @@ -78,16 +78,14 @@ def test_recording_s3_nwb_fsspec(tmp_path): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" - tmp_file = tmp_path / "test_fsspec_recording.pkl" - with open(tmp_file, 'wb') as f: + with open(tmp_file, "wb") as f: pickle.dump(rec, f) - - with open(tmp_file, 'rb') as f: + + with open(tmp_file, "rb") as f: reloaded_recording = pickle.load(f) - - check_recordings_equal(rec, reloaded_recording) + check_recordings_equal(rec, reloaded_recording) @pytest.mark.ros3_test @@ -113,14 +111,15 @@ def test_sorting_s3_nwb_ros3(tmp_path): assert np.all(spike_train >= 0) tmp_file = tmp_path / "test_ros3_sorting.pkl" - with open(tmp_file, 'wb') as f: + with open(tmp_file, "wb") as f: pickle.dump(sort, f) - - with open(tmp_file, 'rb') as f: + + with open(tmp_file, "rb") as f: reloaded_sorting = pickle.load(f) - + check_sortings_equal(reloaded_sorting, sort) + @pytest.mark.streaming_extractors def test_sorting_s3_nwb_fsspec(tmp_path): file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" @@ -144,12 +143,12 @@ def test_sorting_s3_nwb_fsspec(tmp_path): assert np.all(spike_train >= 0) tmp_file = tmp_path / "test_fsspec_sorting.pkl" - with open(tmp_file, 'wb') as f: + with open(tmp_file, "wb") as f: pickle.dump(sort, f) - - with open(tmp_file, 'rb') as f: + + with open(tmp_file, "rb") as f: reloaded_sorting = pickle.load(f) - + check_sortings_equal(reloaded_sorting, sort) From 85f9b050bdeec8405d05798ff8605a7ebbd5c5d0 Mon Sep 17 00:00:00 2001 From: h-mayorquin Date: Sun, 5 Nov 2023 10:46:04 +0100 Subject: [PATCH 39/52] trigger tests --- src/spikeinterface/extractors/nwbextractors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index f7b445cdb9..010b22975c 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -566,12 +566,12 @@ def get_unit_spike_train( start_frame = 0 if end_frame is None: end_frame = np.inf - times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] + spike_times = self._nwbfile.units["spike_times"][list(self._nwbfile.units.id[:]).index(unit_id)][:] if self._timestamps is not None: - frames = np.searchsorted(times, self.timestamps).astype("int64") + frames = np.searchsorted(spike_times, self.timestamps).astype("int64") else: - frames = np.round(times * self._sampling_frequency).astype("int64") + frames = np.round(spike_times * self._sampling_frequency).astype("int64") return frames[(frames >= start_frame) & (frames < end_frame)] From 2b820bf7244f5d8d4b43c87df03d23dfe5828922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 6 Nov 2023 12:30:00 +0100 Subject: [PATCH 40/52] Improvement to `get_empty_units()` --- src/spikeinterface/core/basesorting.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 94b08d8cc3..319fc7cb12 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -364,16 +364,12 @@ def remove_empty_units(self): return self.select_units(non_empty_units) def get_non_empty_unit_ids(self): - non_empty_units = [] - for segment_index in range(self.get_num_segments()): - for unit in 
self.get_unit_ids(): - if len(self.get_unit_spike_train(unit, segment_index=segment_index)) > 0: - non_empty_units.append(unit) - non_empty_units = np.unique(non_empty_units) - return non_empty_units + num_spikes_per_unit = self.count_num_spikes_per_unit() + + return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] == 0]) def get_empty_unit_ids(self): - unit_ids = self.get_unit_ids() + unit_ids = self.unit_ids empty_units = unit_ids[~np.isin(unit_ids, self.get_non_empty_unit_ids())] return empty_units From 23d2d89858b70cd046f8a3a043b93a11ff2234cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 6 Nov 2023 12:37:43 +0100 Subject: [PATCH 41/52] oops --- src/spikeinterface/core/basesorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 319fc7cb12..87507f799e 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -366,7 +366,7 @@ def remove_empty_units(self): def get_non_empty_unit_ids(self): num_spikes_per_unit = self.count_num_spikes_per_unit() - return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] == 0]) + return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] != 0]) def get_empty_unit_ids(self): unit_ids = self.unit_ids From 2253a45a9dff06cb157600a453bbb235fadc489c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 6 Nov 2023 13:27:45 +0100 Subject: [PATCH 42/52] Update src/spikeinterface/core/waveform_extractor.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/waveform_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 1453b4156f..cc17aea62f 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -557,7 +557,7 @@ def has_extension(self, extension_name: str) -> bool: ) def is_extension(self, extension_name) -> bool: - warn("WaveformExtractor.is_extension is deprecated! Use `has_extension` instead.") + warn("WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.") return self.has_extension(extension_name) def load_extension(self, extension_name: str): From 7be79220b12a9ba475adbe36c3341dd9cee32cc1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 12:29:30 +0000 Subject: [PATCH 43/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/waveform_extractor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index cc17aea62f..dd7c3043c3 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -557,7 +557,9 @@ def has_extension(self, extension_name: str) -> bool: ) def is_extension(self, extension_name) -> bool: - warn("WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.") + warn( + "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead." 
+ )
return self.has_extension(extension_name)

def load_extension(self, extension_name: str):

From a4ce54a212b7d5a946d125328b0f0760cfd8bd85 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?=
Date: Wed, 8 Nov 2023 13:33:01 +0100
Subject: [PATCH 44/52] Better docstring for empty units

---
src/spikeinterface/core/basesorting.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 87507f799e..a43eaf6090 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -353,7 +353,8 @@ def remove_units(self, remove_unit_ids):

def remove_empty_units(self):
"""
- Removes units with empty spike trains
+ Removes units with empty spike trains.
+ For multi-segment sortings, a unit is considered empty only if it contains no spikes in any segment.

Returns
-------

From 57cfbd2461fb865836590aed59222f9c4fdbc7ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?=
Date: Fri, 10 Nov 2023 10:52:31 +0100
Subject: [PATCH 45/52] Fix filtering rounding error

---
src/spikeinterface/preprocessing/filter.py | 4 ++++
src/spikeinterface/preprocessing/filter_gaussian.py | 3 +++
2 files changed, 7 insertions(+)

diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py
index 1d6947be79..172c666d62 100644
--- a/src/spikeinterface/preprocessing/filter.py
+++ b/src/spikeinterface/preprocessing/filter.py
@@ -153,6 +153,10 @@ def get_traces(self, start_frame, end_frame, channel_indices):
filtered_traces = filtered_traces[left_margin:-right_margin, :]
else:
filtered_traces = filtered_traces[left_margin:, :]
+
+ if np.issubdtype(self.dtype, np.integer):
+ filtered_traces = filtered_traces.round()
+
return filtered_traces.astype(self.dtype)


diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py
index 79b5ba5bc3..325ce82074 100644
--- a/src/spikeinterface/preprocessing/filter_gaussian.py
+++ b/src/spikeinterface/preprocessing/filter_gaussian.py
@@ -74,6 +74,9 @@ def get_traces(
filtered_fft = traces_fft * (gauss_high - gauss_low)[:, None]
filtered_traces = np.real(np.fft.ifft(filtered_fft, axis=0))

+ if np.issubdtype(dtype, np.integer):
+ filtered_traces = filtered_traces.round()
+
if right_margin > 0:
return filtered_traces[left_margin:-right_margin, :].astype(dtype)
else:

From 5591593d569f23fb9b71e6f67b2a17ad1eb70dae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?=
Date: Fri, 10 Nov 2023 17:37:23 +0100
Subject: [PATCH 46/52] Documented rounding in pre-processing

---
doc/modules/preprocessing.rst | 2 ++
src/spikeinterface/preprocessing/normalize_scale.py | 4 ++++
src/spikeinterface/preprocessing/phase_shift.py | 2 ++
3 files changed, 8 insertions(+)

diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst
index 67f1e52011..e95edb968c 100644
--- a/doc/modules/preprocessing.rst
+++ b/doc/modules/preprocessing.rst
@@ -74,6 +74,8 @@ dtype (unless specified otherwise):

Some scaling pre-processors, such as :code:`whiten()` or :code:`zscore()`, will force the output to :code:`float32`.

+When converting from a :code:`float` to an :code:`int`, the value will first be rounded to the nearest integer.
+ Available preprocessing ----------------------- diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 03afada380..f24aff6e79 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -20,6 +20,10 @@ def __init__(self, parent_recording_segment, gain, offset, dtype): def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) scaled_traces = traces * self.gain[:, channel_indices] + self.offset[:, channel_indices] + + if np.issubdtype(self._dtype, np.integer): + scaled_traces = scaled_traces.round() + return scaled_traces.astype(self._dtype) diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 570ce48a5d..0734dad784 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -103,6 +103,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces_shift = traces_shift[left_margin:-right_margin, :] if self.tmp_dtype is not None: + if np.issubdtype(self.dtype, np.integer): + traces_shift = traces_shift.round() traces_shift = traces_shift.astype(self.dtype) return traces_shift From 97f46c3ed9a15cbb961ff4a38ea981f5c4e4cb0f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 15 Nov 2023 10:21:22 +0100 Subject: [PATCH 47/52] add rename_units method and test --- src/spikeinterface/core/basesorting.py | 52 ++++++++++++++----- .../core/tests/test_basesorting.py | 13 +++++ 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 94b08d8cc3..9f94e38bb4 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -57,13 +57,13 @@ def add_sorting_segment(self, sorting_segment): self._sorting_segments.append(sorting_segment) sorting_segment.set_parent_extractor(self) - def get_sampling_frequency(self): + def get_sampling_frequency(self) -> float: return self._sampling_frequency - def get_num_segments(self): + def get_num_segments(self) -> int: return len(self._sorting_segments) - def get_num_samples(self, segment_index=None): + def get_num_samples(self, segment_index=None) -> int: """Returns the number of samples of the associated recording for a segment. Parameters @@ -82,7 +82,7 @@ def get_num_samples(self, segment_index=None): ), "This methods requires an associated recording. Call self.register_recording() first." return self._recording.get_num_samples(segment_index=segment_index) - def get_total_samples(self): + def get_total_samples(self) -> int: """Returns the total number of samples of the associated recording. Returns @@ -299,9 +299,11 @@ def count_num_spikes_per_unit(self) -> dict: return num_spikes - def count_total_num_spikes(self): + def count_total_num_spikes(self) -> int: """ - Get total number of spikes summed across segment and units. + Get total number of spikes in the sorting. + + This is the sum of all spikes in all segments across all units. 
Returns
-------
total_num_spikes : int
"""
return self.to_spike_vector().size

- def select_units(self, unit_ids, renamed_unit_ids=None):
+ def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting:
"""
- Selects a subset of units
+ Returns a new sorting object which contains only a selected subset of units.
+

Parameters
----------
@@ -331,9 +334,30 @@ def select_units(self, unit_ids, renamed_unit_ids=None):
sub_sorting = UnitsSelectionSorting(self, unit_ids, renamed_unit_ids=renamed_unit_ids)
return sub_sorting

- def remove_units(self, remove_unit_ids):
+ def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting:
"""
- Removes a subset of units
+ Returns a new sorting object with renamed units.
+
+
+ Parameters
+ ----------
+ new_unit_ids : numpy.array or list
+ List of new names for unit ids.
+ They should map positionally to the existing unit ids.
+
+ Returns
+ -------
+ BaseSorting
+ Sorting object with renamed units
+ """
+ from spikeinterface import UnitsSelectionSorting
+
+ sub_sorting = UnitsSelectionSorting(self, renamed_unit_ids=new_unit_ids)
+ return sub_sorting
+
+ def remove_units(self, remove_unit_ids) -> BaseSorting:
+ """
+ Returns a new sorting object with the specified units removed.

Parameters
----------
@@ -343,7 +367,7 @@ def remove_units(self, remove_unit_ids):
Returns
-------
BaseSorting
- Sorting object without removed units
+ Sorting without the removed units
"""
from spikeinterface import UnitsSelectionSorting

@@ -353,7 +377,8 @@ def remove_units(self, remove_unit_ids):

def remove_empty_units(self):
"""
- Removes units with empty spike trains
+ Returns a new sorting object which contains only units with at least one spike.
+

Returns
-------
@@ -389,7 +414,7 @@ def get_all_spike_trains(self, outputs="unit_id"):
"""
Return all spike trains concatenated.

- This is deprecated use sorting.to_spike_vector() instead
+ This is deprecated and will be removed in spikeinterface 0.102; use sorting.to_spike_vector() instead
"""

warnings.warn(
@@ -429,7 +454,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
Construct a unique structured numpy vector concatenating all spikes
with several fields: sample_index, unit_index, segment_index.
- See also `get_all_spike_trains()` Parameters ---------- diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 0bdd9aecdd..a35898b420 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -22,6 +22,7 @@ ) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal +from spikeinterface.core.generate import generate_sorting if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" @@ -169,6 +170,18 @@ def test_npy_sorting(): assert_raises(Exception, sorting.register_recording, rec) +def test_rename_units_method(): + num_units = 2 + durations = [1.0, 1.0] + + sorting = generate_sorting(num_units=num_units, durations=durations) + + new_unit_ids = ["a", "b"] + new_sorting = sorting.rename_units(new_unit_ids=new_unit_ids) + + assert np.array_equal(new_sorting.get_unit_ids(), new_unit_ids) + + def test_empty_sorting(): sorting = NumpySorting.from_unit_dict({}, 30000) From bf20cb5cf909d5178e09fbc2b7bde30170be4f09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 15 Nov 2023 11:23:49 +0100 Subject: [PATCH 48/52] Alessio suggestion Co-authored-by: Alessio Buccino --- src/spikeinterface/core/waveform_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index dd7c3043c3..9d38a46bb2 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -558,7 +558,7 @@ def has_extension(self, extension_name: str) -> bool: def is_extension(self, extension_name) -> bool: warn( - "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead." + "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.", DeprecationWarning, stacklevel=2 ) return self.has_extension(extension_name) From 2b0a432070f4298a1acdbb0d2147e5b4fd7ee2ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 10:24:30 +0000 Subject: [PATCH 49/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/waveform_extractor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 9d38a46bb2..a81d36139d 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -558,7 +558,9 @@ def has_extension(self, extension_name: str) -> bool: def is_extension(self, extension_name) -> bool: warn( - "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.", DeprecationWarning, stacklevel=2 + "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! 
Use `has_extension` instead.", + DeprecationWarning, + stacklevel=2, ) return self.has_extension(extension_name) From 874db2e1932479bbc3ba3522418912d8402c07ad Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 15 Nov 2023 12:06:53 +0100 Subject: [PATCH 50/52] handling segments --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 ++ src/spikeinterface/sorters/internal/tridesclous2.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a16b642dd5..fd283a8224 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -33,6 +33,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "job_kwargs": {"n_jobs": -1}, } + handle_multi_segment = True + @classmethod def get_sorter_version(cls): return "2.0" diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index e256915fa6..eb2ddc922d 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -50,6 +50,8 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "save_array": True, } + handle_multi_segment = True + @classmethod def get_sorter_version(cls): return "2.0" From 8a70f80c1f14d123c215b03fdf90089c6f1b3a0e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 Nov 2023 14:36:48 +0100 Subject: [PATCH 51/52] Avoid duplicated template and quality metric names --- src/spikeinterface/postprocessing/template_metrics.py | 2 +- src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 858af3ee08..f68081dbda 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -60,7 +60,7 @@ def _set_params( metric_names += get_multi_channel_template_metric_names() metrics_kwargs = metrics_kwargs or dict() params = dict( - metric_names=[str(name) for name in metric_names], + metric_names=[str(name) for name in np.unique(metric_names)], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 53309db282..428139b3b9 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -61,7 +61,7 @@ def _set_params( qm_params_[k]["peak_sign"] = peak_sign params = dict( - metric_names=[str(name) for name in metric_names], + metric_names=[str(name) for name in np.unique(metric_names)], sparsity=sparsity, peak_sign=peak_sign, seed=seed, From d692439159e08cc5de33dd6ccbbbe3e19513f550 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:44:39 +0000 Subject: [PATCH 52/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/generate.py | 21 ++++++++++++++----- .../core/tests/test_generate.py | 21 ++++++++++++------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 49a5650622..1c8661d12d 
100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1336,8 +1336,17 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um): return channel_locations -def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, - minimum_distance=20., max_iteration=100, distance_strict=False, seed=None): +def generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=20.0, + max_iteration=100, + distance_strict=False, + seed=None, +): rng = np.random.default_rng(seed=seed) units_locations = np.zeros((num_units, 3), dtype="float32") @@ -1364,7 +1373,7 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu else: # random only bad ones in the previous set renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))] - + units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size) units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size) units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size) @@ -1374,8 +1383,10 @@ def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimu if not solution_found: if distance_strict: - raise ValueError(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " - "You can use distance_strict=False or reduce minimum distance") + raise ValueError( + f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} " + "You can use distance_strict=False or reduce minimum distance" + ) else: warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}") diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 582120ac51..7b51abcccb 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -292,24 +292,29 @@ def test_generate_single_fake_waveform(): def test_generate_unit_locations(): - seed = 0 probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20) channel_locations = probe.contact_positions num_units = 100 - minimum_distance = 20. - unit_locations = generate_unit_locations(num_units, channel_locations, - margin_um=20.0, minimum_z=5.0, maximum_z=40.0, - minimum_distance=minimum_distance, max_iteration=500, - distance_strict=False, seed=seed) + minimum_distance = 20.0 + unit_locations = generate_unit_locations( + num_units, + channel_locations, + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=minimum_distance, + max_iteration=500, + distance_strict=False, + seed=seed, + ) distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2) dist_flat = np.triu(distances, k=1).flatten() - dist_flat = dist_flat[dist_flat>0] + dist_flat = dist_flat[dist_flat > 0] assert np.all(dist_flat > minimum_distance) - # import matplotlib.pyplot as plt # fig, ax = plt.subplots() # ax.hist(dist_flat, bins = np.arange(0, 400, 10))
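
Taken together, patches 40, 41, and 47 rework the unit-handling API of BaseSorting. The sketch below is a minimal usage example, assuming a spikeinterface build with all of the above patches applied; generate_sorting is the same synthetic-data helper imported by the new test in patch 47.

import numpy as np

from spikeinterface.core.generate import generate_sorting

# Two units across two 1-second segments, mirroring test_rename_units_method.
sorting = generate_sorting(num_units=2, durations=[1.0, 1.0])

# rename_units (patch 47) maps the new ids positionally onto the existing
# unit ids and returns a new sorting object; the original is left untouched.
renamed = sorting.rename_units(new_unit_ids=["a", "b"])
assert np.array_equal(renamed.get_unit_ids(), ["a", "b"])

# After patches 40/41, emptiness is derived from count_num_spikes_per_unit(),
# so a unit is empty only if it has no spikes in any segment.
non_empty = sorting.get_non_empty_unit_ids()
empty = sorting.get_empty_unit_ids()
assert len(non_empty) + len(empty) == len(sorting.unit_ids)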