
Commit 669aff8

Merge branch 'main' of github.com:SpikeInterface/spikeinterface into reset-times

alejoe91 committed Sep 10, 2024
2 parents 8c0ff56 + c240155

Showing 5 changed files with 33 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/core_tools.py
@@ -201,7 +201,7 @@ def is_dict_extractor(d: dict) -> bool:
 extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"])
 
 
-def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]:
+def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element, None, None]:
     """
     Iterator for recursive traversal of a dictionary.
     This function explores the dictionary recursively and yields the path to each value along with the value itself.
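
For background on this change: typing.Generator takes three type parameters, Generator[YieldType, SendType, ReturnType], so a generator that only yields values should pin the send and return types to None. A minimal sketch, where the hypothetical iter_elements stands in for the real recursive extractor_dict_iterator:

from collections import namedtuple
from typing import Generator

extractor_dict_element = namedtuple("extractor_dict_element", ["value", "name", "access_path"])


def iter_elements(d: dict) -> Generator[extractor_dict_element, None, None]:
    # Yields one element per top-level key; the real function recurses into nested dicts.
    for key, value in d.items():
        yield extractor_dict_element(value=value, name=key, access_path=(key,))


for elem in iter_elements({"a": 1, "b": 2}):
    print(elem.access_path, elem.value)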
2 changes: 1 addition & 1 deletion src/spikeinterface/core/recording_tools.py
@@ -69,7 +69,7 @@ def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsi
 def write_binary_recording(
     recording: "BaseRecording",
     file_paths: list[Path | str] | Path | str,
-    dtype: np.ndtype = None,
+    dtype: np.typing.DTypeLike = None,
     add_file_extension: bool = True,
     byte_offset: int = 0,
     auto_cast_uint: bool = True,
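
This annotation fix matters because np.ndtype does not exist in NumPy's public API (the class is np.dtype), while numpy.typing.DTypeLike is the standard alias for "anything convertible to a dtype": strings, Python/NumPy scalar types, np.dtype instances, or None. A small illustration with a hypothetical cast_array helper:

import numpy as np
from numpy.typing import DTypeLike


def cast_array(values, dtype: DTypeLike = None) -> np.ndarray:
    # DTypeLike covers "float32", np.int16, np.dtype("float32"), and None alike.
    return np.asarray(values, dtype=dtype)


print(cast_array([1, 2, 3], dtype="float32").dtype)  # float32
print(cast_array([1, 2, 3], dtype=np.int16).dtype)   # int16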
8 changes: 6 additions & 2 deletions src/spikeinterface/generation/drift_tools.py
@@ -40,9 +40,13 @@ def interpolate_templates(templates_array, source_locations, dest_locations, int
     source_locations = np.asarray(source_locations)
     dest_locations = np.asarray(dest_locations)
     if dest_locations.ndim == 2:
-        new_shape = templates_array.shape
+        new_shape = (*templates_array.shape[:2], len(dest_locations))
     elif dest_locations.ndim == 3:
-        new_shape = (dest_locations.shape[0],) + templates_array.shape
+        new_shape = (
+            dest_locations.shape[0],
+            *templates_array.shape[:2],
+            dest_locations.shape[1],
+        )
     else:
         raise ValueError(f"Incorrect dimensions for dest_locations: {dest_locations.ndim}. Dimensions can be 2 or 3. ")
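
To see what the shape fix buys: assuming templates_array is laid out as (num_templates, num_samples, num_channels), the old 2-D branch reused the source channel count even when dest_locations had a different number of channels, and the old 3-D branch kept the wrong last axis too. A toy check of the corrected shapes under that assumed layout, with illustrative sizes:

import numpy as np

# Hypothetical sizes: 4 templates, 90 samples, 32 source channels.
templates_array = np.zeros((4, 90, 32))

# 2-D case: one set of 24 destination channel positions (x, y).
dest_locations_2d = np.zeros((24, 2))
new_shape = (*templates_array.shape[:2], len(dest_locations_2d))
assert new_shape == (4, 90, 24)

# 3-D case: 10 motion steps, each with 24 destination positions.
dest_locations_3d = np.zeros((10, 24, 2))
new_shape = (
    dest_locations_3d.shape[0],
    *templates_array.shape[:2],
    dest_locations_3d.shape[1],
)
assert new_shape == (10, 4, 90, 24)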
36 changes: 21 additions & 15 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -8,11 +8,9 @@
 
 import numpy as np
 import warnings
-from typing import Optional
 from copy import deepcopy
 
 from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension
-from ..core import ChannelSparsity
 from ..core.template_tools import get_template_extremum_channel
 from ..core.template_tools import get_dense_templates_array
 
@@ -238,13 +236,17 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
 
         for metric_name in metrics_single_channel:
             func = _metric_name_to_func[metric_name]
-            value = func(
-                template_upsampled,
-                sampling_frequency=sampling_frequency_up,
-                trough_idx=trough_idx,
-                peak_idx=peak_idx,
-                **self.params["metrics_kwargs"],
-            )
+            try:
+                value = func(
+                    template_upsampled,
+                    sampling_frequency=sampling_frequency_up,
+                    trough_idx=trough_idx,
+                    peak_idx=peak_idx,
+                    **self.params["metrics_kwargs"],
+                )
+            except Exception as e:
+                warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
+                value = np.nan
             template_metrics.at[index, metric_name] = value
 
         # compute metrics multi_channel
@@ -274,12 +276,16 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
                 sampling_frequency_up = sampling_frequency
 
             func = _metric_name_to_func[metric_name]
-            value = func(
-                template_upsampled,
-                channel_locations=channel_locations_sparse,
-                sampling_frequency=sampling_frequency_up,
-                **self.params["metrics_kwargs"],
-            )
+            try:
+                value = func(
+                    template_upsampled,
+                    channel_locations=channel_locations_sparse,
+                    sampling_frequency=sampling_frequency_up,
+                    **self.params["metrics_kwargs"],
+                )
+            except Exception as e:
+                warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
+                value = np.nan
             template_metrics.at[index, metric_name] = value
         return template_metrics
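
Both hunks above apply the same guard: if a single metric fails for a single unit, the failure is downgraded to a warning and the value is recorded as NaN, so one bad template no longer aborts the whole extension. A minimal sketch of that pattern, with a hypothetical fragile_metric standing in for the real metric functions:

import warnings

import numpy as np


def fragile_metric(template):
    # Hypothetical metric that fails on flat templates.
    if np.ptp(template) == 0:
        raise ValueError("flat template")
    return float(np.max(template) / np.ptp(template))


templates = {"u0": np.array([0.0, 1.0, -2.0]), "u1": np.zeros(3)}
results = {}
for unit_id, template in templates.items():
    try:
        value = fragile_metric(template)
    except Exception as e:
        warnings.warn(f"Error computing metric fragile_metric for unit {unit_id}: {e}")
        value = np.nan
    results[unit_id] = value

print(results)  # u0 gets a finite value (1/3); u1 is recorded as nan with a warning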
@@ -164,7 +164,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
     pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names]
     if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]:
         if not sorting_analyzer.has_extension("principal_components"):
-            raise ValueError("waveform_principal_component must be provied")
+            raise ValueError(
+                "To compute principal components base metrics, the principal components "
+                "extension must be computed first."
+            )
         pc_metrics = compute_pc_metrics(
             sorting_analyzer,
             unit_ids=non_empty_unit_ids,
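
The improved message points at the actual fix: PCA-based quality metrics require the principal_components extension on the analyzer first. A hedged sketch of the expected call order, assuming the SortingAnalyzer compute API; the toy-data helpers and parameter values here are illustrative, not defaults:

import spikeinterface.full as si

# Build a toy analyzer from generated ground-truth data.
recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5)
sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)

# principal_components needs waveforms, which in turn need sampled spikes.
sorting_analyzer.compute(["random_spikes", "waveforms", "templates"])
sorting_analyzer.compute("principal_components", n_components=3, mode="by_channel_local")

# With the extension present, PCA-based metrics no longer raise.
if sorting_analyzer.has_extension("principal_components"):
    qm = sorting_analyzer.compute("quality_metrics", metric_names=["isolation_distance", "d_prime"])
    print(qm.get_data())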
