From 49c7a92a57af5a65f7367b375567afaed6abda56 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 15 Oct 2024 13:46:36 +0200 Subject: [PATCH 1/7] Use existing sparsity for unit location + add location with max channel --- .../postprocessing/localization_tools.py | 67 +++++++++++++++++-- .../tests/test_unit_locations.py | 1 + 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index e6278fc59f..59ca8cf7db 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -76,8 +76,12 @@ def compute_monopolar_triangulation( assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" contact_locations = sorting_analyzer_or_templates.get_channel_locations() + + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) + else: + sparsity = sorting_analyzer_or_templates.sparsity - sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -157,9 +161,13 @@ def compute_center_of_mass( assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" - sparsity = compute_sparsity( - sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um - ) + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity( + sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um + ) + else: + sparsity = sorting_analyzer_or_templates.sparsity + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -650,8 +658,59 @@ def get_convolution_weights( enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) + +def compute_location_max_channel( + templates_or_sorting_analyzer: SortingAnalyzer | Templates, + unit_ids=None, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", +) -> np.ndarray: + """ + Localize a unit using max channel. + + This use inetrnally get_template_extremum_channel() + + + Parameters + ---------- + templates_or_sorting_analyzer : SortingAnalyzer | Templates + A SortingAnalyzer or Templates object + unit_ids: str | int | None + A list of unit_id to restrict the computation + peak_sign : "neg" | "pos" | "both" + Sign of the template to find extremum channels + mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude + + Returns + ------- + unit_location: np.ndarray + 2d + """ + extremum_channels_index = get_template_extremum_channel( + templates_or_sorting_analyzer, + peak_sign=peak_sign, + mode=mode, + outputs="index" + ) + contact_locations = templates_or_sorting_analyzer.get_channel_locations() + if unit_ids is None: + unit_ids = templates_or_sorting_analyzer.unit_ids + else: + unit_ids = np.asarray(unit_ids) + unit_location = np.zeros((unit_ids.size, 2), dtype="float32") + for i, unit_id in enumerate(unit_ids): + unit_location[i, :] = contact_locations[extremum_channels_index[unit_id]] + + return unit_location + + _unit_location_methods = { "center_of_mass": compute_center_of_mass, "grid_convolution": compute_grid_convolution, "monopolar_triangulation": compute_monopolar_triangulation, + "max_channel": compute_location_max_channel, } diff --git a/src/spikeinterface/postprocessing/tests/test_unit_locations.py b/src/spikeinterface/postprocessing/tests/test_unit_locations.py index c40a917a2b..545edb3497 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_locations.py @@ -13,6 +13,7 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), dict(method="monopolar_triangulation", radius_um=150), dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), + dict(method="max_channel"), ], ) def test_extension(self, params): From 9cf9377a30b1733223037c58bb05709f0e76d5c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:50:21 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/localization_tools.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 59ca8cf7db..4bf39e00e8 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -76,7 +76,7 @@ def compute_monopolar_triangulation( assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" contact_locations = sorting_analyzer_or_templates.get_channel_locations() - + if sorting_analyzer_or_templates.sparsity is None: sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) else: @@ -167,7 +167,7 @@ def compute_center_of_mass( ) else: sparsity = sorting_analyzer_or_templates.sparsity - + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -658,7 +658,6 @@ def get_convolution_weights( enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) - def compute_location_max_channel( templates_or_sorting_analyzer: SortingAnalyzer | Templates, unit_ids=None, @@ -691,10 +690,7 @@ def compute_location_max_channel( 2d """ extremum_channels_index = get_template_extremum_channel( - templates_or_sorting_analyzer, - peak_sign=peak_sign, - mode=mode, - outputs="index" + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" ) contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: From 0e5f50fdfb6b6ae35a9b06d814e811ac9bf833ee Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 13:54:42 +0200 Subject: [PATCH 3/7] merci zach Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/postprocessing/localization_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 4bf39e00e8..a17abea1eb 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -667,14 +667,14 @@ def compute_location_max_channel( """ Localize a unit using max channel. - This use inetrnally get_template_extremum_channel() + This uses interrnally `get_template_extremum_channel()` Parameters ---------- templates_or_sorting_analyzer : SortingAnalyzer | Templates A SortingAnalyzer or Templates object - unit_ids: str | int | None + unit_ids: list[str] | list[int] | None A list of unit_id to restrict the computation peak_sign : "neg" | "pos" | "both" Sign of the template to find extremum channels From 0be00cf32fad46b1a55d7018ea051014644568ab Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:14 +0200 Subject: [PATCH 4/7] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index a17abea1eb..3372a34c98 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -667,7 +667,7 @@ def compute_location_max_channel( """ Localize a unit using max channel. - This uses interrnally `get_template_extremum_channel()` + This uses internally `get_template_extremum_channel()` Parameters From 0002edcbe99764fcf65edae405a2be59647691f1 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:23 +0200 Subject: [PATCH 5/7] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 3372a34c98..67d469f85c 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -686,7 +686,7 @@ def compute_location_max_channel( Returns ------- - unit_location: np.ndarray + unit_locations: np.ndarray 2d """ extremum_channels_index = get_template_extremum_channel( From b4e681d8524d119a86ba21fd9a19c5e6716c1ca0 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:33 +0200 Subject: [PATCH 6/7] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 67d469f85c..a073b6c518 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -697,11 +697,11 @@ def compute_location_max_channel( unit_ids = templates_or_sorting_analyzer.unit_ids else: unit_ids = np.asarray(unit_ids) - unit_location = np.zeros((unit_ids.size, 2), dtype="float32") + unit_locations = np.zeros((unit_ids.size, 2), dtype="float32") for i, unit_id in enumerate(unit_ids): - unit_location[i, :] = contact_locations[extremum_channels_index[unit_id]] + unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]] - return unit_location + return unit_locations _unit_location_methods = { From 55b50abdebf5f1e0aa0c00307d706588b1d9038d Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 21 Oct 2024 18:47:54 +0200 Subject: [PATCH 7/7] Update src/spikeinterface/postprocessing/localization_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index a073b6c518..837b983059 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -684,7 +684,7 @@ def compute_location_max_channel( * "at_index" : take value at `nbefore` index * "peak_to_peak" : take the peak-to-peak amplitude - Returns + Returns ------- unit_locations: np.ndarray 2d