Skip to content

Commit

Permalink
Merge branch 'SpikeInterface:main' into meta_merging_sc2
Browse files Browse the repository at this point in the history
  • Loading branch information
yger authored Jun 29, 2024
2 parents e5b02a5 + 0d1dda3 commit 599755c
Show file tree
Hide file tree
Showing 23 changed files with 584 additions and 140 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def register_recording(self, recording, check_spike_frames=True):
self.get_num_segments() == recording.get_num_segments()
), "The recording has a different number of segments than the sorting!"
if check_spike_frames:
if has_exceeding_spikes(recording, self):
if has_exceeding_spikes(self, recording):
warnings.warn(
"Some spikes exceed the recording's duration! "
"Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
Expand Down
125 changes: 113 additions & 12 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations
from pathlib import Path, WindowsPath
from typing import Union
from typing import Union, Generator
import os
import sys
import datetime
import json
from copy import deepcopy
import importlib
from math import prod
from collections import namedtuple

import numpy as np

Expand Down Expand Up @@ -183,6 +184,75 @@ def is_dict_extractor(d: dict) -> bool:
return is_extractor


# Element yielded by `extractor_dict_iterator`: a leaf value, the key (name)
# under which it was found, and the tuple of keys/indices leading to it.
extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"])


def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element, None, None]:
    """
    Iterator for recursive traversal of a dictionary.

    This function explores the dictionary recursively and yields the path to each value along with the value itself.

    By path here we mean the keys that lead to the value in the dictionary:
    e.g. for the dictionary {'a': {'b': 1}}, the path to the value 1 is ('a', 'b').

    See `BaseExtractor.to_dict()` for a description of `extractor_dict` structure.

    Parameters
    ----------
    extractor_dict : dict
        Input dictionary

    Yields
    ------
    extractor_dict_element
        Named tuple containing the value, the name, and the access_path to the value in the dictionary.
    """

    def _extractor_dict_iterator(dict_list_or_value, access_path=(), name=""):
        if isinstance(dict_list_or_value, dict):
            for k, v in dict_list_or_value.items():
                yield from _extractor_dict_iterator(v, access_path + (k,), name=k)
        elif isinstance(dict_list_or_value, list):
            for i, v in enumerate(dict_list_or_value):
                # Propagate the name of the list to its children; list positions
                # are recorded in the access path as integer indices.
                yield from _extractor_dict_iterator(v, access_path + (i,), name=name)
        else:
            yield extractor_dict_element(
                value=dict_list_or_value,
                name=name,
                access_path=access_path,
            )

    yield from _extractor_dict_iterator(extractor_dict)


def set_value_in_extractor_dict(extractor_dict: dict, access_path: tuple, new_value) -> None:
    """
    In-place modification of a value in a nested dictionary given its access path.

    Parameters
    ----------
    extractor_dict : dict
        The dictionary to modify
    access_path : tuple
        The path to the value in the dictionary (keys for dicts, integer indices for lists),
        e.g. as produced by `extractor_dict_iterator`.
    new_value : object
        The new value to set

    Returns
    -------
    None
        The dictionary is modified in place.
    """
    # Walk down to the container that holds the target entry, then assign.
    current = extractor_dict
    for key in access_path[:-1]:
        current = current[key]
    current[access_path[-1]] = new_value


def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
"""
Generic function for recursive modification of paths in an extractor dict.
Expand Down Expand Up @@ -250,15 +320,17 @@ def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
raise ValueError(f"{k} key for path must be str or list[str]")


def _get_paths_list(d):
# this explore a dict and get all paths flatten in a list
# the trick is to use a closure func called by recursive_path_modifier()
path_list = []
# This is the current definition of what counts as a "path" element in an
# extractor_dict. It is shared across a couple of functions, so it lives here
# to keep things DRY.
def element_is_path(element) -> bool:
    """Return True if `element` (an extractor_dict_element) looks like a filesystem path."""
    return "path" in element.name and isinstance(element.value, (str, Path))


def append_to_path(p):
path_list.append(p)
def _get_paths_list(d: dict) -> list[str | Path]:
    """Explore an extractor dict and return all values that look like paths, flattened into a list."""
    path_list = [e.value for e in extractor_dict_iterator(d) if element_is_path(e)]

    # if check_if_exists: TODO: Enable this once container_tools test uses proper mocks
    # path_list = [p for p in path_list if Path(p).exists()]

    return path_list


Expand Down Expand Up @@ -318,7 +390,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool:
return len(not_possible) == 0


def make_paths_relative(input_dict: dict, relative_folder: str | Path) -> dict:
    """
    Recursively transform a dict describing a BaseExtractor to make every path relative to a folder.

    Only path entries that exist on the filesystem are modified; non-existing
    paths are left untouched in the output.

    Parameters
    ----------
    input_dict : dict
        A dict describing an extractor, as obtained from `BaseExtractor.to_dict()`
    relative_folder : str | Path
        The folder that paths are made relative to

    Returns
    -------
    output_dict : dict
        A copy of the input dict with modified paths.
    """
    relative_folder = Path(relative_folder).resolve().absolute()

    path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)]
    # Only paths that exist are made relative
    path_elements_in_dict = [e for e in path_elements_in_dict if Path(e.value).exists()]

    output_dict = deepcopy(input_dict)
    for element in path_elements_in_dict:
        new_value = _relative_to(element.value, relative_folder)
        set_value_in_extractor_dict(
            extractor_dict=output_dict,
            access_path=element.access_path,
            new_value=new_value,
        )

    return output_dict


def make_paths_absolute(input_dict, base_folder):
    """
    Recursively transform a dict describing a BaseExtractor to make every path absolute,
    resolved against a base folder.

    Only entries whose resolved absolute path exists on the filesystem are
    modified; other entries are left untouched in the output.

    Parameters
    ----------
    input_dict : dict
        A dict describing an extractor, as obtained from `BaseExtractor.to_dict()`
    base_folder : str | Path
        The folder that relative paths are resolved against

    Returns
    -------
    output_dict : dict
        A copy of the input dict with modified paths.
    """
    base_folder = Path(base_folder)

    path_elements_in_dict = [e for e in extractor_dict_iterator(input_dict) if element_is_path(e)]
    output_dict = deepcopy(input_dict)

    for element in path_elements_in_dict:
        absolute_path = (base_folder / element.value).resolve()
        if absolute_path.exists():
            # use as_posix instead of str to keep the path unix-like even on Windows
            set_value_in_extractor_dict(
                extractor_dict=output_dict,
                access_path=element.access_path,
                new_value=absolute_path.as_posix(),
            )

    return output_dict


def recursive_key_finder(d, key):
# Find all values for a key on a dictionary, even if nested
# TODO refactor to use extractor_dict_iterator

for k, v in d.items():
if isinstance(v, dict):
yield from recursive_key_finder(v, key)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/frameslicesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
assert (
start_frame <= parent_n_samples
), "`start_frame` should be smaller than the sortings' total number of samples."
if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
if check_spike_frames and has_exceeding_spikes(parent_sorting, parent_sorting._recording):
raise ValueError(
"The sorting object has spikes whose times go beyond the recording duration."
"This could indicate a bug in the sorter. "
Expand Down
16 changes: 8 additions & 8 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,29 +152,29 @@ class SpikeRetriever(PeakSource):
* compute_spike_amplitudes()
* compute_principal_components()
sorting : BaseSorting
The sorting object.
recording : BaseRecording
The recording object.
sorting: BaseSorting
The sorting object.
channel_from_template: bool, default: True
channel_from_template : bool, default: True
If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided.
If False, the max channel is computed for each spike given a radius around the template max channel.
extremum_channel_inds: dict of int | None, default: None
extremum_channel_inds : dict of int | None, default: None
The extremum channel index dict given from template.
radius_um: float, default: 50
radius_um : float, default: 50
The radius to find the real max channel.
Used only when channel_from_template=False
peak_sign: "neg" | "pos", default: "neg"
peak_sign : "neg" | "pos", default: "neg"
Peak sign to find the max channel.
Used only when channel_from_template=False
include_spikes_in_margin: bool, default False
include_spikes_in_margin : bool, default False
If not None, then spikes in the margin are added and an extra field in dtype is added
"""

def __init__(
self,
recording,
sorting,
recording,
channel_from_template=True,
extremum_channel_inds=None,
radius_um=50,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def create_sorting_analyzer(
recording.channel_ids, sparsity.channel_ids
), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
elif sparse:
sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs)
sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs)
else:
sparsity = None

Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,8 @@ def compute_sparsity(


def estimate_sparsity(
recording: BaseRecording,
sorting: BaseSorting,
recording: BaseRecording,
num_spikes_for_sparsity: int = 100,
ms_before: float = 1.0,
ms_after: float = 2.5,
Expand All @@ -563,10 +563,11 @@ def estimate_sparsity(
Parameters
----------
recording : BaseRecording
The recording
sorting : BaseSorting
The sorting
recording : BaseRecording
The recording
num_spikes_for_sparsity : int, default: 100
How many spikes per units to compute the sparsity
ms_before : float, default: 1.0
Expand Down
Loading

0 comments on commit 599755c

Please sign in to comment.