Skip to content

Commit

Permalink
Merge branch 'main' into count_spike_array
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia authored Nov 22, 2023
2 parents 712b1e0 + 38d82d4 commit b09ee96
Show file tree
Hide file tree
Showing 35 changed files with 906 additions and 452 deletions.
2 changes: 2 additions & 0 deletions doc/modules/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ dtype (unless specified otherwise):
Some scaling pre-processors, such as :code:`whiten()` or :code:`zscore()`, will force the output to :code:`float32`.

When converting from a :code:`float` to an :code:`int`, the value will first be rounded to the nearest integer.


Available preprocessing
-----------------------
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/comparison/groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


# This is to separate names when the key are tuples when saving folders
_key_separator = " ## "
_key_separator = "_##_"


class GroundTruthStudy:
Expand Down
91 changes: 61 additions & 30 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,35 +316,69 @@ def to_dict(
recursive: bool = False,
) -> dict:
"""
Make a nested serialized dictionary out of the extractor. The dictionary produced can be used to re-initialize
an extractor using load_extractor_from_dict(dump_dict)
Construct a nested dictionary representation of the extractor.
This method facilitates the serialization of the extractor instance by converting it
to a dictionary. The resulting dictionary can be used to re-initialize the extractor
through the `load_extractor_from_dict` function.
Examples
--------
>>> dump_dict = original_extractor.to_dict()
>>> reloaded_extractor = load_extractor_from_dict(dump_dict)
Parameters
----------
include_annotations: bool, default: False
If True, all annotations are added to the dict
include_properties: bool, default: False
If True, all properties are added to the dict
relative_to: str, Path, or None, default: None
If not None, files and folders are serialized relative to this path
Used in waveform extractor to maintain relative paths to binary files even if the
containing folder / diretory is moved
folder_metadata: str, Path, or None
Folder with numpy `npy` files containing additional information (e.g. probe in BaseRecording) and properties.
recursive: bool, default: False
If True, all dicitionaries in the kwargs are expanded with `to_dict` as well
include_annotations : bool, default: False
Whether to include all annotations in the dictionary
include_properties : bool, default: False
Whether to include all properties in the dictionary, by default False.
relative_to : Union[str, Path, None], default: None
If provided, file and folder paths will be made relative to this path,
enabling portability in folder formats such as the waveform extractor,
by default None.
folder_metadata : Union[str, Path, None], default: None
Path to a folder containing additional metadata files (e.g., probe information in BaseRecording)
in numpy `npy` format, by default None.
recursive : bool, default: False
If True, recursively apply `to_dict` to dictionaries within the kwargs, by default False.
Raises
------
ValueError
If `relative_to` is specified while `recursive` is False.
Returns
-------
dump_dict: dict
A dictionary representation of the extractor.
dict
A dictionary representation of the extractor, with the following structure:
{
"class": <the full import path of the class>,
"module": <module name>, (e.g. 'spikeinterface'),
"kwargs": <the values that were used to initialize the class>,
"version": <module version>,
"relative_paths": <whether paths are relative>,
"annotations": <annotations dictionary, if `include_annotations` is True>,
"properties": <properties dictionary, if `include_properties` is True>,
"folder_metadata": <relative path to folder_metadata, if specified>
}
Notes
-----
- The `relative_to` argument only has an effect if `recursive` is set to True.
- The `folder_metadata` argument will be made relative to `relative_to` if both are specified.
- The `version` field in the resulting dictionary reflects the version of the module
from which the extractor class originates.
- The full class attribute above is the full import of the class, e.g.
'spikeinterface.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor'
- The module is usually 'spikeinterface', but can be different for custom extractors such as those of
SpikeForest or any other project that inherits the Extractor class from spikeinterface.
"""

kwargs = self._kwargs

if relative_to and not recursive:
raise ValueError("`relative_to` is only possible when `recursive=True`")

kwargs = self._kwargs
if recursive:
to_dict_kwargs = dict(
include_annotations=include_annotations,
Expand All @@ -366,27 +400,24 @@ def to_dict(
new_kwargs[name] = transform_extractors_to_dict(value)

kwargs = new_kwargs
class_name = str(type(self)).replace("<class '", "").replace("'>", "")

module_import_path = self.__class__.__module__
class_name_no_path = self.__class__.__name__
class_name = f"{module_import_path}.{class_name_no_path}" # e.g. 'spikeinterface.core.generate.AClass'
module = class_name.split(".")[0]
imported_module = importlib.import_module(module)

try:
version = imported_module.__version__
except AttributeError:
version = "unknown"
imported_module = importlib.import_module(module)
module_version = getattr(imported_module, "__version__", "unknown")

dump_dict = {
"class": class_name,
"module": module,
"kwargs": kwargs,
"version": version,
"version": module_version,
"relative_paths": (relative_to is not None),
}

try:
dump_dict["version"] = imported_module.__version__
except AttributeError:
dump_dict["version"] = "unknown"
dump_dict["version"] = module_version # Can be spikeinterface, spikefores, etc.

if include_annotations:
dump_dict["annotations"] = self._annotations
Expand Down Expand Up @@ -805,7 +836,7 @@ def save_to_folder(self, name=None, folder=None, overwrite=False, verbose=True,
* explicit sub-folder, implicit base-folder : `extractor.save(name="extarctor_name")`
* generated: `extractor.save()`
The second option saves to subfolder "extarctor_name" in
The second option saves to subfolder "extractor_name" in
"get_global_tmp_folder()". You can set the global tmp folder with:
"set_global_tmp_folder("path-to-global-folder")"
Expand Down
64 changes: 42 additions & 22 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ def add_sorting_segment(self, sorting_segment):
self._sorting_segments.append(sorting_segment)
sorting_segment.set_parent_extractor(self)

def get_sampling_frequency(self):
def get_sampling_frequency(self) -> float:
return self._sampling_frequency

def get_num_segments(self):
def get_num_segments(self) -> int:
return len(self._sorting_segments)

def get_num_samples(self, segment_index=None):
def get_num_samples(self, segment_index=None) -> int:
"""Returns the number of samples of the associated recording for a segment.
Parameters
Expand All @@ -82,7 +82,7 @@ def get_num_samples(self, segment_index=None):
), "This methods requires an associated recording. Call self.register_recording() first."
return self._recording.get_num_samples(segment_index=segment_index)

def get_total_samples(self):
def get_total_samples(self) -> int:
"""Returns the total number of samples of the associated recording.
Returns
Expand Down Expand Up @@ -322,9 +322,11 @@ def count_num_spikes_per_unit(self, outputs="dict"):
else:
raise ValueError("count_num_spikes_per_unit() output must be 'dict' or 'array'")

def count_total_num_spikes(self):
def count_total_num_spikes(self) -> int:
"""
Get total number of spikes summed across segment and units.
Get total number of spikes in the sorting.
This is the sum of all spikes in all segments across all units.
Returns
-------
Expand All @@ -333,9 +335,10 @@ def count_total_num_spikes(self):
"""
return self.to_spike_vector().size

def select_units(self, unit_ids, renamed_unit_ids=None):
def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting:
"""
Selects a subset of units
Returns a new sorting object which contains only a selected subset of units.
Parameters
----------
Expand All @@ -354,9 +357,30 @@ def select_units(self, unit_ids, renamed_unit_ids=None):
sub_sorting = UnitsSelectionSorting(self, unit_ids, renamed_unit_ids=renamed_unit_ids)
return sub_sorting

def remove_units(self, remove_unit_ids):
def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting:
"""
Returns a new sorting object with renamed units.
Parameters
----------
new_unit_ids : numpy.array or list
List of new names for unit ids.
They should map positionally to the existing unit ids.
Returns
-------
BaseSorting
Sorting object with renamed units
"""
from spikeinterface import UnitsSelectionSorting

sub_sorting = UnitsSelectionSorting(self, renamed_unit_ids=new_unit_ids)
return sub_sorting

def remove_units(self, remove_unit_ids) -> BaseSorting:
"""
Removes a subset of units
Returns a new sorting object with contains only a selected subset of units.
Parameters
----------
Expand All @@ -366,7 +390,7 @@ def remove_units(self, remove_unit_ids):
Returns
-------
BaseSorting
Sorting object without removed units
Sorting without the removed units
"""
from spikeinterface import UnitsSelectionSorting

Expand All @@ -376,7 +400,8 @@ def remove_units(self, remove_unit_ids):

def remove_empty_units(self):
"""
Removes units with empty spike trains
Returns a new sorting object which contains only units with at least one spike.
For multi-segments, a unit is considered empty if it contains no spikes in all segments.
Returns
-------
Expand All @@ -387,16 +412,12 @@ def remove_empty_units(self):
return self.select_units(non_empty_units)

def get_non_empty_unit_ids(self):
non_empty_units = []
for segment_index in range(self.get_num_segments()):
for unit in self.get_unit_ids():
if len(self.get_unit_spike_train(unit, segment_index=segment_index)) > 0:
non_empty_units.append(unit)
non_empty_units = np.unique(non_empty_units)
return non_empty_units
num_spikes_per_unit = self.count_num_spikes_per_unit()

return np.array([unit_id for unit_id in self.unit_ids if num_spikes_per_unit[unit_id] != 0])

def get_empty_unit_ids(self):
unit_ids = self.get_unit_ids()
unit_ids = self.unit_ids
empty_units = unit_ids[~np.isin(unit_ids, self.get_non_empty_unit_ids())]
return empty_units

Expand All @@ -412,7 +433,7 @@ def get_all_spike_trains(self, outputs="unit_id"):
"""
Return all spike trains concatenated.
This is deprecated use sorting.to_spike_vector() instead
This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead
"""

warnings.warn(
Expand Down Expand Up @@ -452,7 +473,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
Construct a unique structured numpy vector concatenating all spikes
with several fields: sample_index, unit_index, segment_index.
See also `get_all_spike_trains()`
Parameters
----------
Expand Down
57 changes: 51 additions & 6 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,15 +1336,60 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um):
return channel_locations


def generate_unit_locations(num_units, channel_locations, margin_um=20.0, minimum_z=5.0, maximum_z=40.0, seed=None):
def generate_unit_locations(
num_units,
channel_locations,
margin_um=20.0,
minimum_z=5.0,
maximum_z=40.0,
minimum_distance=20.0,
max_iteration=100,
distance_strict=False,
seed=None,
):
rng = np.random.default_rng(seed=seed)
units_locations = np.zeros((num_units, 3), dtype="float32")
for dim in (0, 1):
lim0 = np.min(channel_locations[:, dim]) - margin_um
lim1 = np.max(channel_locations[:, dim]) + margin_um
units_locations[:, dim] = rng.uniform(lim0, lim1, size=num_units)

minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um
minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um

units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)
units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)
units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)

if minimum_distance is not None:
solution_found = False
renew_inds = None
for i in range(max_iteration):
distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)
inds0, inds1 = np.nonzero(distances < minimum_distance)
mask = inds0 != inds1
inds0 = inds0[mask]
inds1 = inds1[mask]

if inds0.size > 0:
if renew_inds is None:
renew_inds = np.unique(inds0)
else:
# random only bad ones in the previous set
renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]

units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)
units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)
units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)
else:
solution_found = True
break

if not solution_found:
if distance_strict:
raise ValueError(
f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} "
"You can use distance_strict=False or reduce minimum distance"
)
else:
warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}")

return units_locations


Expand All @@ -1369,7 +1414,7 @@ def generate_ground_truth_recording(
upsample_vector=None,
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0),
noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20),
generate_templates_kwargs=dict(),
dtype="float32",
seed=None,
Expand Down
Loading

0 comments on commit b09ee96

Please sign in to comment.