Skip to content

Commit

Permalink
Merge pull request SpikeInterface#3476 from samuelgarcia/localization…
Browse files Browse the repository at this point in the history
…_and_sparsity

Unit localization
  • Loading branch information
alejoe91 authored Oct 22, 2024
2 parents 9b53f09 + 55b50ab commit bfe9fb6
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
63 changes: 59 additions & 4 deletions src/spikeinterface/postprocessing/localization_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ def compute_monopolar_triangulation(

contact_locations = sorting_analyzer_or_templates.get_channel_locations()

sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um)
if sorting_analyzer_or_templates.sparsity is None:
sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um)
else:
sparsity = sorting_analyzer_or_templates.sparsity

templates = get_dense_templates_array(
sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates)
)
Expand Down Expand Up @@ -157,9 +161,13 @@ def compute_center_of_mass(

assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature"

sparsity = compute_sparsity(
sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um
)
if sorting_analyzer_or_templates.sparsity is None:
sparsity = compute_sparsity(
sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um
)
else:
sparsity = sorting_analyzer_or_templates.sparsity

templates = get_dense_templates_array(
sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates)
)
Expand Down Expand Up @@ -650,8 +658,55 @@ def get_convolution_weights(
enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True)


def compute_location_max_channel(
templates_or_sorting_analyzer: SortingAnalyzer | Templates,
unit_ids=None,
peak_sign: "neg" | "pos" | "both" = "neg",
mode: "extremum" | "at_index" | "peak_to_peak" = "extremum",
) -> np.ndarray:
"""
Localize a unit using max channel.
This uses internally `get_template_extremum_channel()`
Parameters
----------
templates_or_sorting_analyzer : SortingAnalyzer | Templates
A SortingAnalyzer or Templates object
unit_ids: list[str] | list[int] | None
A list of unit_id to restrict the computation
peak_sign : "neg" | "pos" | "both"
Sign of the template to find extremum channels
mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index"
Where the amplitude is computed
* "extremum" : take the peak value (max or min depending on `peak_sign`)
* "at_index" : take value at `nbefore` index
* "peak_to_peak" : take the peak-to-peak amplitude
Returns
-------
unit_locations: np.ndarray
2d
"""
extremum_channels_index = get_template_extremum_channel(
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index"
)
contact_locations = templates_or_sorting_analyzer.get_channel_locations()
if unit_ids is None:
unit_ids = templates_or_sorting_analyzer.unit_ids
else:
unit_ids = np.asarray(unit_ids)
unit_locations = np.zeros((unit_ids.size, 2), dtype="float32")
for i, unit_id in enumerate(unit_ids):
unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]]

return unit_locations


_unit_location_methods = {
"center_of_mass": compute_center_of_mass,
"grid_convolution": compute_grid_convolution,
"monopolar_triangulation": compute_monopolar_triangulation,
"max_channel": compute_location_max_channel,
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite):
dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}),
dict(method="monopolar_triangulation", radius_um=150),
dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"),
dict(method="max_channel"),
],
)
def test_extension(self, params):
Expand Down

0 comments on commit bfe9fb6

Please sign in to comment.