Merge branch 'main' into prepare_release
alejoe91 authored Jul 19, 2024
2 parents ad2f656 + f4505e5 commit 71af00e
Showing 43 changed files with 2,020 additions and 326 deletions.
14 changes: 11 additions & 3 deletions doc/api.rst
@@ -159,6 +159,8 @@ spikeinterface.preprocessing
.. autofunction:: common_reference
.. autofunction:: correct_lsb
.. autofunction:: correct_motion
.. autofunction:: get_motion_presets
.. autofunction:: get_motion_parameters_preset
.. autofunction:: depth_order
.. autofunction:: detect_bad_channels
.. autofunction:: directional_derivative
@@ -324,15 +326,21 @@ spikeinterface.curation
------------------------
.. automodule:: spikeinterface.curation

.. autoclass:: CurationSorting
.. autoclass:: MergeUnitsSorting
.. autoclass:: SplitUnitSorting
.. autofunction:: apply_curation
.. autofunction:: get_potential_auto_merge
.. autofunction:: find_redundant_units
.. autofunction:: remove_redundant_units
.. autofunction:: remove_duplicated_spikes
.. autofunction:: remove_excess_spikes

Deprecated
~~~~~~~~~~
.. automodule:: spikeinterface.curation

.. autofunction:: apply_sortingview_curation
.. autoclass:: CurationSorting
.. autoclass:: MergeUnitsSorting
.. autoclass:: SplitUnitSorting


spikeinterface.generation
1,474 changes: 1,353 additions & 121 deletions doc/how_to/handle_drift.rst

Large diffs are not rendered by default.

Binary file removed doc/how_to/handle_drift_files/handle_drift_13_0.png
Binary file removed doc/how_to/handle_drift_files/handle_drift_13_1.png
Binary file removed doc/how_to/handle_drift_files/handle_drift_13_2.png
Binary file removed doc/how_to/handle_drift_files/handle_drift_15_0.png
Binary file removed doc/how_to/handle_drift_files/handle_drift_15_1.png
Binary file removed doc/how_to/handle_drift_files/handle_drift_15_2.png
Binary file modified doc/how_to/handle_drift_files/handle_drift_17_1.png
22 changes: 17 additions & 5 deletions doc/modules/curation.rst
@@ -261,11 +261,23 @@ format is the definition; the second part of the format is manual action):
}
.. note::
The curation format was recently introduced (v0.101.0), and we are still working on
properly integrating it into the SpikeInterface ecosystem.
Soon there will be functions available, in the curation module, to apply this
standardized curation format to ``SortingAnalyzer`` and ``BaseSorting`` objects.
The curation format can be loaded into a dictionary and directly applied to
a ``BaseSorting`` or ``SortingAnalyzer`` object using the :py:func:`~spikeinterface.curation.apply_curation` function.

.. code-block:: python

    import json

    from spikeinterface.curation import apply_curation

    # load the curation JSON file
    curation_json = "path/to/curation.json"
    with open(curation_json, 'r') as f:
        curation_dict = json.load(f)

    # apply the curation to the sorting output
    clean_sorting = apply_curation(sorting, curation_dict=curation_dict)

    # apply the curation to the sorting analyzer
    clean_sorting_analyzer = apply_curation(sorting_analyzer, curation_dict=curation_dict)
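
The curated objects can then be persisted like any other sorting or analyzer. A minimal
sketch (the output folder names below are illustrative assumptions):

.. code-block:: python

    # save the curated sorting and analyzer to disk for later use
    clean_sorting = clean_sorting.save(folder="clean_sorting")
    clean_sorting_analyzer = clean_sorting_analyzer.save_as(format="binary_folder", folder="clean_analyzer")
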
Using the ``SpikeInterface GUI``
85 changes: 52 additions & 33 deletions examples/how_to/handle_drift.py
@@ -2,12 +2,12 @@
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# formats: py,ipynb
# formats: py:light,ipynb
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.14.6
# jupytext_version: 1.16.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
@@ -55,6 +55,8 @@

import spikeinterface.full as si

from spikeinterface.preprocessing import get_motion_parameters_preset, get_motion_presets

# -

base_folder = Path("/mnt/data/sam/DataSpikeSorting/imposed_motion_nick")
@@ -70,6 +72,7 @@


def preprocess_chain(rec):
rec = rec.astype('float32')
rec = si.bandpass_filter(rec, freq_min=300.0, freq_max=6000.0)
rec = si.common_reference(rec, reference="global", operator="median")
return rec
@@ -79,33 +82,46 @@ def preprocess_chain(rec):

job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True)

# ### Run motion correction with one function!
#
#
# Correcting for drift is easy! You just need to run a single function.
# We will try this function with 3 presets.
# We will try this function with some presets.
#
# Internally a preset is a dictionary of dictionaries containing all parameters for every step.
#
# Here we also save the motion correction results into a folder to be able to load them later.

# internally, we can explore a preset like this
# every parameter can be overwritten at runtime
from spikeinterface.preprocessing.motion import motion_options_preset
# ### Presets and parameters
#
# Motion correction has several steps, and every step can be controlled by a method and its related parameters.
#
# A preset is a nested dict that contains these methods/parameters.

preset_keys = get_motion_presets()
preset_keys

one_preset_params = get_motion_parameters_preset("kilosort_like")
one_preset_params
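
# Every entry of a preset can also be overridden at call time, through the
# step-specific keyword arguments of `correct_motion()`. A minimal sketch, assuming
# the ``estimate_motion_kwargs`` argument (the value below is illustrative, not a
# recommendation for this dataset):
#
# recording_corrected = si.correct_motion(
#     rec, preset="kilosort_like", estimate_motion_kwargs=dict(win_step_um=100.0)
# )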

# ### Run motion correction with one function!
#
# Correcting for drift is easy! You just need to run a single function.
# We will try this function with some presets.
#
# Here we also save the motion correction results into a folder to be able to load them later.

motion_options_preset["kilosort_like"]
# let's try these presets
some_presets = ("rigid_fast", "kilosort_like", "nonrigid_accurate", "nonrigid_fast_and_accurate", "dredge", "dredge_fast")

# let's try these 3 presets
some_presets = ("rigid_fast", "kilosort_like", "nonrigid_accurate")
# some_presets = ('kilosort_like', )

# compute motion with 3 presets
# compute motion with these presets
for preset in some_presets:
print("Computing with", preset)
folder = base_folder / "motion_folder_dataset1" / preset
if folder.exists():
shutil.rmtree(folder)
recording_corrected, motion_info = si.correct_motion(
rec, preset=preset, folder=folder, output_motion_info=True, **job_kwargs
recording_corrected, motion, motion_info = si.correct_motion(
rec, preset=preset, folder=folder, output_motion=True, output_motion_info=True, **job_kwargs
)
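
# The corrected recording returned by `si.correct_motion()` is a lazy preprocessing
# step, so it can be fed straight into a spike sorter. A minimal sketch (assuming a
# sorter such as kilosort4 is installed; the output folder name is illustrative):
#
# sorting = si.run_sorter("kilosort4", recording_corrected, folder=base_folder / "sorting_ks4")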

# ### Plot the results
@@ -127,11 +143,20 @@ def preprocess_chain(rec):
# The motion vector is computed for different depths.
# The corrected peak locations are flatter than the rigid case.
# The motion vector map is still a bit noisy at some depths (e.g. around 1000um).
# * The preset **nonrigid_accurate** seems to give the best results on this recording.
# The motion vector seems less noisy globally, but it is not "perfect" (see the top of the probe, 3200um to 3800um).
# Also note that in the first part of the recording, before the imposed motion (0-600s), we clearly have a non-rigid motion:
# * The preset **dredge** is the official DREDge re-implementation in spikeinterface.
# It gives the best results: very fast and smooth motion estimation, with very little noise.
# This method also captures very well the non-rigid motion gradient along the probe.
# The best method on the market at the moment.
# An enormous thanks to the dream team: Charlie Windolf, Julien Boussard, Erdem Varol, Liam Paninski.
# Note that in the first part of the recording, before the imposed motion (0-600s), we clearly have a non-rigid motion:
# the upper part of the probe (2000-3000um) experiences some drift, but the lower part (0-1000um) is relatively stable.
# The method defined by this preset is able to capture this.
# * The preset **nonrigid_accurate** is the ancestor of "dredge", before it was published.
# It seems to give good results on this recording, but with a bit more noise.
# * The preset **dredge_fast** is similar to dredge but faster (using grid_convolution).
# * The preset **nonrigid_fast_and_accurate** is a variant of nonrigid_accurate but faster (using grid_convolution).
#
#

for preset in some_presets:
# load
@@ -140,8 +165,8 @@ def preprocess_chain(rec):

# and plot
fig = plt.figure(figsize=(14, 8))
si.plot_motion(
motion_info,
si.plot_motion_info(
motion_info, rec,
figure=fig,
depth_lim=(400, 600),
color_amplitude=True,
@@ -173,6 +198,8 @@ def preprocess_chain(rec):
folder = base_folder / "motion_folder_dataset1" / preset
motion_info = si.load_motion_info(folder)

motion = motion_info["motion"]

fig, axs = plt.subplots(ncols=2, figsize=(12, 8), sharey=True)

ax = axs[0]
@@ -190,24 +217,16 @@ def preprocess_chain(rec):

color_kargs = dict(alpha=0.2, s=2, c=c)

loc = motion_info["peak_locations"]
peak_locations = motion_info["peak_locations"]
# color='black',
ax.scatter(loc["x"][mask][sl], loc["y"][mask][sl], **color_kargs)

loc2 = correct_motion_on_peaks(
motion_info["peaks"],
motion_info["peak_locations"],
rec.sampling_frequency,
motion_info["motion"],
motion_info["temporal_bins"],
motion_info["spatial_bins"],
direction="y",
)
ax.scatter(peak_locations["x"][mask][sl], peak_locations["y"][mask][sl], **color_kargs)

peak_locations2 = correct_motion_on_peaks(peaks, peak_locations, motion, rec)

ax = axs[1]
si.plot_probe_map(rec, ax=ax)
# color='black',
ax.scatter(loc2["x"][mask][sl], loc2["y"][mask][sl], **color_kargs)
ax.scatter(peak_locations2["x"][mask][sl], peak_locations2["y"][mask][sl], **color_kargs)

ax.set_ylim(400, 600)
fig.suptitle(f"{preset=}")
@@ -228,7 +247,7 @@ def preprocess_chain(rec):
keys = run_times[0].keys()

bottom = np.zeros(len(run_times))
fig, ax = plt.subplots()
fig, ax = plt.subplots(figsize=(14, 6))
for k in keys:
rtimes = np.array([rt[k] for rt in run_times])
if np.any(rtimes > 0.0):
1 change: 1 addition & 0 deletions pyproject.toml
@@ -139,6 +139,7 @@ test_extractors = [

test_preprocessing = [
"ibllib>=2.36.0", # for IBL
"torch",
]


20 changes: 14 additions & 6 deletions src/spikeinterface/core/sorting_tools.py
@@ -227,7 +227,7 @@ def random_spikes_selection(


def apply_merges_to_sorting(
sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, new_id_strategy="append"
):
"""
Apply a resolved representation of the merges to a sorting object.
@@ -250,8 +250,8 @@
merged units will have the first unit_id of every list of merges.
censor_ms: float | None, default: None
When applying the merges, consecutive spikes violating the given refractory period (in ms) are discarded.
return_kept : bool, default: False
If True, also return a boolean mask of kept spikes.
return_extra : bool, default: False
If True, also return a boolean mask of kept spikes and the new_unit_ids.
new_id_strategy : "append" | "take_first", default: "append"
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
@@ -316,8 +316,8 @@
spikes = spikes[keep_mask]
sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids)

if return_kept:
return sorting, keep_mask
if return_extra:
return sorting, keep_mask, new_unit_ids
else:
return sorting

@@ -384,11 +384,13 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_
new_unit_ids : list | None, default: None
Optional new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`.
If None, new ids will be generated.
new_id_strategy : "append" | "take_first", default: "append"
new_id_strategy : "append" | "take_first" | "join", default: "append"
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
* "append" : new_units_ids will be added at the end of max(sorging.unit_ids)
* "take_first" : new_unit_ids will be the first unit_id of every list of merges
* "join" : new_unit_ids will join unit_ids of groups with a "-".
Only works if unit_ids are str otherwise switch to "append"
Returns
-------
@@ -423,6 +425,12 @@
else:
# dtype int
new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
elif new_id_strategy == "join":
if np.issubdtype(dtype, np.character):
new_unit_ids = ["-".join(group) for group in merge_unit_groups]
else:
# dtype int
new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
else:
raise ValueError("wrong new_id_strategy")

14 changes: 10 additions & 4 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -732,12 +732,12 @@ def _save_or_select_or_merge(
else:
from spikeinterface.core.sorting_tools import apply_merges_to_sorting

sorting_provenance, keep_mask = apply_merges_to_sorting(
sorting_provenance, keep_mask, _ = apply_merges_to_sorting(
sorting=sorting_provenance,
merge_unit_groups=merge_unit_groups,
new_unit_ids=new_unit_ids,
censor_ms=censor_ms,
return_kept=True,
return_extra=True,
)
if censor_ms is None:
# in this case having keep_mask None is faster instead of having a vector of ones
@@ -885,6 +885,7 @@ def merge_units(
merging_mode="soft",
sparsity_overlap=0.75,
new_id_strategy="append",
return_new_unit_ids=False,
format="memory",
folder=None,
verbose=False,
@@ -917,14 +918,15 @@
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
* "append" : new_units_ids will be added at the end of max(sorting.unit_ids)
* "take_first" : new_unit_ids will be the first unit_id of every list of merges
return_new_unit_ids : bool, default: False
Also return new_unit_ids, which are the ids of the new units.
folder : Path | None, default: None
The new folder where the analyzer with merged units is copied for `format` "binary_folder" or "zarr"
format : "memory" | "binary_folder" | "zarr", default: "memory"
The format of SortingAnalyzer
verbose : bool, default: False
Whether to display calculations (such as sparsity estimation)
Returns
-------
analyzer : SortingAnalyzer
@@ -952,7 +954,7 @@
)
all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids)

return self._save_or_select_or_merge(
new_analyzer = self._save_or_select_or_merge(
format=format,
folder=folder,
merge_unit_groups=merge_unit_groups,
Expand All @@ -964,6 +966,10 @@ def merge_units(
new_unit_ids=new_unit_ids,
**job_kwargs,
)
if return_new_unit_ids:
return new_analyzer, new_unit_ids
else:
return new_analyzer

def copy(self):
"""
7 changes: 6 additions & 1 deletion src/spikeinterface/core/tests/test_sorting_tools.py
@@ -96,7 +96,7 @@ def test_apply_merges_to_sorting():
spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"]
)

sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True)
sorting3, keep_mask, _ = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_extra=True)
spikes3 = sorting3.to_spike_vector()
assert spikes3.size < spikes1.size
assert not keep_mask[1]
@@ -153,6 +153,11 @@ def test_generate_unit_ids_for_merge_group():
)
assert np.array_equal(new_unit_ids, ["0", "9"])

new_unit_ids = generate_unit_ids_for_merge_group(
["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="join"
)
assert np.array_equal(new_unit_ids, ["0-5", "9-15"])


if __name__ == "__main__":
# test_spike_vector_to_spike_trains()