diff --git a/.github/import_test.py b/.github/import_test.py index f7c3e9f858..52bdde42f9 100644 --- a/.github/import_test.py +++ b/.github/import_test.py @@ -26,18 +26,17 @@ time_taken_list = [] for _ in range(n_samples): script_to_execute = ( - f"import timeit \n" - f"import_statement = '{import_statement}' \n" - f"time_taken = timeit.timeit(import_statement, number=1) \n" - f"print(time_taken) \n" - ) + f"import timeit \n" + f"import_statement = '{import_statement}' \n" + f"time_taken = timeit.timeit(import_statement, number=1) \n" + f"print(time_taken) \n" + ) result = subprocess.run(["python", "-c", script_to_execute], capture_output=True, text=True) if result.returncode != 0: - error_message = ( - f"Error when running {import_statement} \n" - f"Error in subprocess: {result.stderr.strip()}\n" + error_message = ( + f"Error when running {import_statement} \n" f"Error in subprocess: {result.stderr.strip()}\n" ) exceptions.append(error_message) break @@ -46,15 +45,25 @@ time_taken_list.append(time_taken) for time in time_taken_list: - if time > 1.5: - exceptions.append(f"Importing {import_statement} took too long: {time:.2f} seconds") + import_time_threshold = 2.0 # Most of the times is sub-second but there outliers + if time >= import_time_threshold: + exceptions.append( + f"Importing {import_statement} took: {time:.2f} s. Should be <: {import_time_threshold} s." + ) break + if time_taken_list: - avg_time_taken = sum(time_taken_list) / len(time_taken_list) - std_dev_time_taken = math.sqrt(sum((x - avg_time_taken) ** 2 for x in time_taken_list) / len(time_taken_list)) + avg_time = sum(time_taken_list) / len(time_taken_list) + std_time = math.sqrt(sum((x - avg_time) ** 2 for x in time_taken_list) / len(time_taken_list)) times_list_str = ", ".join(f"{time:.2f}" for time in time_taken_list) - markdown_output += f"| `{import_statement}` | {avg_time_taken:.2f} | {std_dev_time_taken:.2f} | {times_list_str} |\n" + markdown_output += f"| `{import_statement}` | {avg_time:.2f} | {std_time:.2f} | {times_list_str} |\n" + + import_time_threshold = 1.0 + if avg_time > import_time_threshold: + exceptions.append( + f"Importing {import_statement} took: {avg_time:.2f} s in average. Should be <: {import_time_threshold} s." + ) if exceptions: raise Exception("\n".join(exceptions)) diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index cbe313b12e..e83399cf7c 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -28,8 +28,9 @@ jobs: with: python-version: '3.10' - name: Test Conda Environment Creation - uses: conda-incubator/setup-miniconda@v2.2.0 + uses: conda-incubator/setup-miniconda@v3 with: + miniconda-version: "latest" environment-file: ./installation_tips/full_spikeinterface_environment_${{ matrix.label }}.yml activate-environment: si_env - name: Check Installation Tips diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index 79156e2690..54fd404848 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -1,5 +1,5 @@ How to Guides -========= +============= Guides on how to solve specific, short problems in SpikeInterface. Learn how to... @@ -12,3 +12,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to. load_matlab_data combine_recordings process_by_channel_group + load_your_data_into_sorting diff --git a/doc/how_to/load_your_data_into_sorting.rst b/doc/how_to/load_your_data_into_sorting.rst new file mode 100644 index 0000000000..21c4460c5a --- /dev/null +++ b/doc/how_to/load_your_data_into_sorting.rst @@ -0,0 +1,154 @@ +Load Your Own Data into a Sorting +================================= + +Why make a :code:`Sorting`? + +SpikeInterface contains pre-build readers for the output of many common sorters. +However, what if you have sorting output that is not in a standard format (e.g. +old csv file)? If this is the case you can make your own Sorting object to load +your data into SpikeInterface. This means you can still easily apply various +downstream analyses to your results (e.g. building correlograms or for generating +a :code:`SortingAnalyzer``). + +The Sorting object is a core object within SpikeInterface that acts as a convenient +way to interface with sorting results, no matter which sorter was used to generate +them. **At a fundamental level it is a series of spike times and a series of labels +for each unit and a sampling frequency for transforming frames to time.** Below, we will show you have +to take your existing data and load it as a SpikeInterface :code:`Sorting` object. + + +Reading a standard spike sorting format into a :code:`Sorting` +------------------------------------------------------------- + +For most spike sorting output formats the :code:`Sorting` is automatically generated. For example one could do + +.. code-block:: python + + from spikeinterface.extractors import read_phy + + # For kilosort/phy output files we can use the read_phy + # most formats will have a read_xx that can used. + phy_sorting = read_phy('path/to/folder') + +And voilà you now have your :code:`Sorting` object generated and can use it for further analysis. For all the +current formats see :ref:`compatible_formats`. + + + +Loading your own data into a :code:`Sorting` +------------------------------------------- + + +This :code:`Sorting` contains important information about your spike trains including: + + * spike times: the peaks of the extracellular potentials expressed in samples/frames these can + be converted to seconds under the hood using the sampling_frequency + * spike labels: the neuron id for each spike, can also be called cluster ids or unit ids + Stored as the :code:`unit_ids` in SpikeInterface + * sampling_frequency: the rate at which the recording equipment was run at. Note this is the + frequency and not the period. This value allows for switching between samples/frames to seconds + + +There are 3 options for loading your own data into a sorting object + +With lists of spike trains and spike labels +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In this case we need a list of spike times unit labels, sampling_frequency and optional unit_ids +if you want specific labels to be used (in this case we only create the :code:`Sorting` based on +the requested unit_ids). + +.. code-block:: python + + import numpy as np + from spikeinterface.core import NumpySorting + + # in this case we are making a monosegment sorting + # we have four spikes that are spread among two neurons + my_sorting = NumpySorting.from_times_labels( + times_list=[ + np.array([1000,12000,15000,22000]) # Note these are samples/frames not times in seconds + ], + labels_list=[ + np.array(["a","b","a","b"]) + ], + sampling_frequency=30_000.0 + ) + + +With a unit dictionary +^^^^^^^^^^^^^^^^^^^^^^ + +We can also use a dictionary where each unit is a key and its spike times are values. +This is entered as either a list of dicts with each dict being a segment or as a single +dict for monosegment. We still need to separately specify the sampling_frequency + +.. code-block:: python + + from spikeinterface.core import NumpySorting + + my_sorting = NumpySorting.from_unit_dict( + units_dict_list={ + '0': [1000,15000], + '1': [12000,22000], + }, + sampling_frequency=30_000.0 + ) + + +With Neo SpikeTrains +^^^^^^^^^^^^^^^^^^^^ + +Finally since SpikeInterface is tightly integrated with the Neo project you can create +a sorting from :code:`Neo.SpikeTrain` objects. See :doc:`Neo documentation` for more information on +using :code:`Neo.SpikeTrain`'s. + +.. code-block:: python + + from spikeinterface.core import NumpySorting + + # neo_spiketrain is a Neo spiketrain object + my_sorting = NumpySorting.from_neo_spiketrain_list( + neo_spiketrain, + sampling_frequency=30_000.0, + ) + + +Loading multisegment data into a :code:`Sorting` +----------------------------------------------- + +One of the great advantages of SpikeInterface :code:`Sorting` objects is that they can also handle +multisegment recordings and sortings (e.g. you have a baseline, stimulus, post-stimulus). The +exact same machinery can be used to generate your sorting, but in this case we do a list of arrays instead of +a single list. Let's go through one example for using :code:`from_times_labels`: + +.. code-block:: python + + import numpy as np + from spikeinterface.core import NumpySorting + + # in this case we are making three-segment sorting + # we have four spikes that are spread among two neurons + # in each segment + my_sorting = NumpySorting.from_times_labels( + times_list=[ + np.array([1000,12000,15000,22000]), + np.array([30000,33000, 41000, 47000]), + np.array([50000,53000,64000,70000]), + ], + labels_list=[ + np.array([0,1,0,1]), + np.array([0,0,1,1]), + np.array([1,0,1,0]), + ], + sampling_frequency=30_000.0 + ) + + +Next steps +---------- + +Now that we've created a Sorting object you can combine it with a Recording to make a +:ref:`SortingAnalyzer` +or start visualizing using plotting functions from our widgets model such as +:py:func:`~spikeinterface.widgets.plot_crosscorrelograms`. diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 401ceea5dc..46fdcc6d65 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -41,6 +41,111 @@ The merging and splitting operations are handled by the :py:class:`~spikeinterfa # here is the final clean sorting clean_sorting = cs.sorting +Manual curation format +---------------------- + +SpikeInterface internally supports a JSON-based manual curation format. +When manual curation is necessary, modifying a dataset in place is a bad practice. +Instead, to ensure the reproducibility of the spike sorting pipelines, we have introduced a simple and JSON-based manual curation format. +This format defines at the moment : merges + deletions + manual tags. +The simple file can be kept along side the output of a sorter and applied on the result to have a "clean" result. + +This format has two part: + + * **definition** with the folowing keys: + + * "format_version" : format specification + * "unit_ids" : the list of unit_ds + * "label_definitions" : list of label categories and possible labels per category. + Every category can be *exclusive=True* onely one label or *exclusive=False* several labels possible + + * **manual output** curation with the folowing keys: + + * "manual_labels" + * "merged_unit_groups" + * "removed_units" + +Here is the description of the format with a simple example: + +.. code-block:: json + + { + # the first part of the format is the definitation + "format_version": "1", + "unit_ids": [ + "u1", + "u2", + "u3", + "u6", + "u10", + "u14", + "u20", + "u31", + "u42" + ], + "label_definitions": { + "quality": { + "label_options": [ + "good", + "noise", + "MUA", + "artifact" + ], + "exclusive": true + }, + "putative_type": { + "label_options": [ + "excitatory", + "inhibitory", + "pyramidal", + "mitral" + ], + "exclusive": false + } + }, + # the second part of the format is manual action + "manual_labels": [ + { + "unit_id": "u1", + "quality": [ + "good" + ] + }, + { + "unit_id": "u2", + "quality": [ + "noise" + ], + "putative_type": [ + "excitatory", + "pyramidal" + ] + }, + { + "unit_id": "u3", + "putative_type": [ + "inhibitory" + ] + } + ], + "merged_unit_groups": [ + [ + "u3", + "u6" + ], + [ + "u10", + "u14", + "u20" + ] + ], + "removed_units": [ + "u31", + "u42" + ] + } + + Automatic curation tools ------------------------ diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 8be2456caa..af81cb42d1 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -163,21 +163,19 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte max_distance_um=150.0, **job_kwargs) # Step 2: motion inference - motion, temporal_bins, spatial_bins = estimate_motion(recording=rec, - peaks=peaks, - peak_locations=peak_locations, - method="decentralized", - direction="y", - bin_duration_s=2.0, - bin_um=5.0, - win_step_um=50.0, - win_sigma_um=150.0) + motion = estimate_motion(recording=rec, + peaks=peaks, + peak_locations=peak_locations, + method="decentralized", + direction="y", + bin_duration_s=2.0, + bin_um=5.0, + win_step_um=50.0, + win_sigma_um=150.0) # Step 3: motion interpolation # this step is lazy rec_corrected = interpolate_motion(recording=rec, motion=motion, - temporal_bins=temporal_bins, - spatial_bins=spatial_bins, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=30.) @@ -220,8 +218,6 @@ different preprocessing chains: one for motion correction and one for spike sort rec_corrected2 = interpolate_motion( recording=rec2, motion=motion_info['motion'], - temporal_bins=motion_info['temporal_bins'], - spatial_bins=motion_info['spatial_bins'], **motion_info['parameters']['interpolate_motion_kwargs']) sorting = run_sorter(sorter_name="montainsort5", recording=rec_corrected2) diff --git a/doc/releases/0.100.5.rst b/doc/releases/0.100.5.rst new file mode 100644 index 0000000000..1f480e942b --- /dev/null +++ b/doc/releases/0.100.5.rst @@ -0,0 +1,12 @@ +.. _release0.100.5: + +SpikeInterface 0.100.5 release notes +------------------------------------ + +6th April 2024 + +Minor release with bug fixes + +* Open Ephys: Use discovered recording ids to load sync timestamps (#2655) +* Fix channel gains in NwbRecordingExtractor with backend (#2661) +* Fix depth location in spikes on traces map (#2676) diff --git a/doc/releases/0.100.6.rst b/doc/releases/0.100.6.rst new file mode 100644 index 0000000000..7f2bb5cd66 --- /dev/null +++ b/doc/releases/0.100.6.rst @@ -0,0 +1,13 @@ +.. _release0.100.6: + +SpikeInterface 0.100.6 release notes +------------------------------------ + +30th April 2024 + +Minor release with bug fixes + +* Avoid np.prod in make_shared_array (#2621) +* Improve caching of MS5 sorter (#2690) +* Allow for remove_excess_spikes to remove negative spike times (#2716) +* Update ks4 wrapper for newer version>=4.0.3 (#2701, #2774) diff --git a/doc/releases/0.100.7.rst b/doc/releases/0.100.7.rst new file mode 100644 index 0000000000..a224494da5 --- /dev/null +++ b/doc/releases/0.100.7.rst @@ -0,0 +1,14 @@ +.. _release0.100.7: + +SpikeInterface 0.100.7 release notes +------------------------------------ + +7th June 2024 + +Minor release with bug fixes + +* Fix get_traces for a local common reference (#2649) +* Update KS4 parameters (#2810) +* Zarr: extract time vector once and for all! (#2828) +* Fix waveforms save in recordingless mode (#2889) +* Fix the new way of handling cmap in matpltolib. This fix the matplotib 3.9 problem related to this (#2891) diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 3c9f2b44c7..becd91790e 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,9 @@ Release notes .. toctree:: :maxdepth: 1 + releases/0.100.7.rst + releases/0.100.6.rst + releases/0.100.5.rst releases/0.100.4.rst releases/0.100.3.rst releases/0.100.2.rst @@ -38,6 +41,29 @@ Release notes releases/0.9.1.rst +(PRE-RELEASE) Version 0.101.0rc0 +================================ + +* Major release with `SortingAnalyzer` + +Version 0.100.7 +=============== + +* Minor release with bug fixes + + +Version 0.100.6 +=============== + +* Minor release with bug fixes + + +Version 0.100.5 +=============== + +* Minor release with bug fixes + + Version 0.100.4 =============== diff --git a/pyproject.toml b/pyproject.toml index a3551d0451..a2e1b3d3a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.101.0" +version = "0.101.0rc0" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ - "numpy", + "numpy>=1.26, <2.0", # 1.20 np.ptp, 1.26 for avoiding pickling errors when numpy >2.0 "threadpoolctl>=3.0.0", "tqdm", "zarr>=2.16,<2.18", @@ -65,18 +65,16 @@ extractors = [ "pyedflib>=0.1.30", "sonpy;python_version<'3.10'", "lxml", # lxml for neuroscope - "scipy<1.13", + "scipy", "ONE-api>=2.7.0", # alf sorter and streaming IBL - "ibllib>=2.32.5", # streaming IBL + "ibllib>=2.36.0", # streaming IBL "pymatreader>=0.0.32", # For cell explorer matlab files "zugbruecke>=0.2; sys_platform!='win32'", # For plexon2 ] streaming_extractors = [ "ONE-api>=2.7.0", # alf sorter and streaming IBL - "ibllib>=2.32.5", # streaming IBL - "scipy<1.13", # ibl has a dependency on scipy but it does not have an upper bound - # Remove this once https://github.com/int-brain-lab/ibllib/issues/753 + "ibllib>=2.36.0", # streaming IBL # Following dependencies are for streaming with nwb files "pynwb>=2.6.0", "fsspec", @@ -121,8 +119,8 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test = [ @@ -154,8 +152,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ @@ -175,8 +173,8 @@ docs = [ "xarray", # For use of SortingAnalyzer zarr format "networkx", # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 6920061366..184959512b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -643,7 +643,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): from .channelslice import ChannelSliceRecording warnings.warn( - "This method will be removed in version 0.103, use `select_channels` or `rename_channels` instead.", + "Recording.channel_slice will be removed in version 0.103, use `select_channels` or `rename_channels` instead.", DeprecationWarning, stacklevel=2, ) @@ -657,12 +657,52 @@ def _remove_channels(self, remove_channel_ids): sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording - def _frame_slice(self, start_frame, end_frame): + def frame_slice(self, start_frame: int, end_frame: int) -> BaseRecording: + """ + Returns a new recording with sliced frames. Note that this operation is not in place. + + Parameters + ---------- + start_frame : int + The start frame + end_frame : int + The end frame + + Returns + ------- + BaseRecording + The object with sliced frames + """ + from .frameslicerecording import FrameSliceRecording sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame) return sub_recording + def time_slice(self, start_time: float, end_time: float) -> BaseRecording: + """ + Returns a new recording with sliced time. Note that this operation is not in place. + + Parameters + ---------- + start_time : float + The start time in seconds. + end_time : float + The end time in seconds. + + Returns + ------- + BaseRecording + The object with sliced time. + """ + + assert self.get_num_segments() == 1, "Time slicing is only supported for single segment recordings." + + start_frame = self.time_to_sample_index(start_time) + end_frame = self.time_to_sample_index(end_time) + + return self.frame_slice(start_frame=start_frame, end_frame=end_frame) + def _select_segments(self, segment_indices): from .segmentutils import SelectSegmentRecording diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 50118c78d7..2a9f075954 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -75,9 +75,6 @@ def is_filtered(self): def _channel_slice(self, channel_ids, renamed_channel_ids=None): raise NotImplementedError - def _frame_slice(self, channel_ids, renamed_channel_ids=None): - raise NotImplementedError - def set_probe(self, probe, group_mode="by_probe", in_place=False): """ Attach a list of Probe object to a recording. @@ -190,7 +187,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False if np.array_equal(new_channel_ids, self.get_channel_ids()): sub_recording = self.clone() else: - sub_recording = self.channel_slice(new_channel_ids) + sub_recording = self.select_channels(new_channel_ids) # create a vector that handle all contacts in property sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) @@ -461,6 +458,22 @@ def channel_slice(self, channel_ids, renamed_channel_ids=None): """ return self._channel_slice(channel_ids, renamed_channel_ids=renamed_channel_ids) + def select_channels(self, channel_ids): + """ + Returns a new object with sliced channels. + + Parameters + ---------- + channel_ids : np.array or list + The list of channels to keep + + Returns + ------- + BaseRecordingSnippets + The object with sliced channels + """ + raise NotImplementedError + def remove_channels(self, remove_channel_ids): """ Returns a new object with removed channels. @@ -494,7 +507,7 @@ def frame_slice(self, start_frame, end_frame): BaseRecordingSnippets The object with sliced frames """ - return self._frame_slice(start_frame, end_frame) + raise NotImplementedError def select_segments(self, segment_indices): """ @@ -545,7 +558,7 @@ def split_by(self, property="group", outputs="dict"): for value in np.unique(values): (inds,) = np.nonzero(values == value) new_channel_ids = self.get_channel_ids()[inds] - subrec = self.channel_slice(new_channel_ids) + subrec = self.select_channels(new_channel_ids) if outputs == "list": recordings.append(subrec) elif outputs == "dict": diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 5443234910..1f3fee74a8 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -135,9 +135,20 @@ def get_snippets_from_frames( def _save(self, format="binary", **save_kwargs): raise NotImplementedError + def select_channels(self, channel_ids: list | np.array | tuple) -> "BaseSnippets": + from .channelslice import ChannelSliceSnippets + + return ChannelSliceSnippets(self, channel_ids) + def _channel_slice(self, channel_ids, renamed_channel_ids=None): from .channelslice import ChannelSliceSnippets + import warnings + warnings.warn( + "Snippets.channel_slice will be removed in version 0.103, use `select_channels` or `rename_channels` instead.", + DeprecationWarning, + stacklevel=2, + ) sub_recording = ChannelSliceSnippets(self, channel_ids, renamed_channel_ids=renamed_channel_ids) return sub_recording @@ -148,9 +159,6 @@ def _remove_channels(self, remove_channel_ids): sub_recording = ChannelSliceSnippets(self, new_channel_ids) return sub_recording - def _frame_slice(self, start_frame, end_frame): - raise NotImplementedError - def _select_segments(self, segment_indices): from .segmentutils import SelectSegmentSnippets diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 5232539422..664eac169f 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -83,7 +83,12 @@ def default(self, obj): if isinstance(obj, np.generic): return obj.item() - if np.issctype(obj): # Cast numpy datatypes to their names + # Standard numpy dtypes like np.dtype('int32") are transformed this way + if isinstance(obj, np.dtype): + return np.dtype(obj).name + + # This will transform to a string canonical representation of the dtype (e.g. np.int32 -> 'int32') + if isinstance(obj, type) and issubclass(obj, np.generic): return np.dtype(obj).name if isinstance(obj, np.ndarray): diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index c32baf9d59..8b1b293543 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -709,7 +709,7 @@ def get_chunk_with_margin( case zero padding is used, in the second case np.pad is called with mod="reflect". """ - length = rec_segment.get_num_samples() + length = int(rec_segment.get_num_samples()) if channel_indices is None: channel_indices = slice(None) @@ -917,3 +917,34 @@ def get_rec_attributes(recording): dtype=recording.get_dtype(), ) return rec_attributes + + +def do_recording_attributes_match(recording1, recording2_attributes) -> bool: + """ + Check if two recordings have the same attributes + + Parameters + ---------- + recording1 : BaseRecording + The first recording object + recording2_attributes : dict + The recording attributes to test against + + Returns + ------- + bool + True if the recordings have the same attributes + """ + recording1_attributes = get_rec_attributes(recording1) + recording2_attributes = deepcopy(recording2_attributes) + recording1_attributes.pop("properties") + recording2_attributes.pop("properties") + + return ( + np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"]) + and recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"] + and recording1_attributes["num_channels"] == recording2_attributes["num_channels"] + and recording1_attributes["num_samples"] == recording2_attributes["num_samples"] + and recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"] + and recording1_attributes["dtype"] == recording2_attributes["dtype"] + ) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index c3881cc1f8..959b7f8c43 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -156,7 +156,7 @@ def __init__(self, parent_segments, sampling_frequency, ignore_times=True): BaseRecordingSegment.__init__(self, **time_kwargs) self.parent_segments = parent_segments self.all_length = [rec_seg.get_num_samples() for rec_seg in self.parent_segments] - self.cumsum_length = np.cumsum([0] + self.all_length) + self.cumsum_length = [0] + [sum(self.all_length[: i + 1]) for i in range(len(self.all_length))] self.total_length = int(np.sum(self.all_length)) def get_num_samples(self): diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index c541634e98..46d02099d5 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -22,7 +22,7 @@ from .basesorting import BaseSorting from .base import load_extractor -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes +from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match from .core_tools import check_json, retrieve_importing_provenance from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting @@ -203,6 +203,8 @@ def __init__( self.format = format self.sparsity = sparsity self.return_scaled = return_scaled + # this is used to store temporary recording + self._temporary_recording = None # extensions are not loaded at init self.extensions = dict() @@ -605,13 +607,37 @@ def load_from_zarr(cls, folder, recording=None): return sorting_analyzer + def set_temporary_recording(self, recording: BaseRecording): + """ + Sets a temporary recording object. This function can be useful to temporarily set + a "cached" recording object that is not saved in the SortingAnalyzer object to speed up + computations. Upon reloading, the SortingAnalyzer object will try to reload the recording + from the original location in a lazy way. + + + Parameters + ---------- + recording : BaseRecording + The recording object to set as temporary recording. + """ + # check that recording is compatible + assert do_recording_attributes_match(recording, self.rec_attributes), "Recording attributes do not match." + assert np.array_equal( + recording.get_channel_locations(), self.get_channel_locations() + ), "Recording channel locations do not match." + if self._recording is not None: + warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.") + self._temporary_recording = recording + def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer": """ Internal used by both save_as(), copy() and select_units() which are more or less the same. """ if self.has_recording(): - recording = self.recording + recording = self._recording + elif self.has_temporary_recording(): + recording = self._temporary_recording else: recording = None @@ -728,9 +754,9 @@ def is_read_only(self) -> bool: @property def recording(self) -> BaseRecording: - if not self.has_recording(): + if not self.has_recording() and not self.has_temporary_recording(): raise ValueError("SortingAnalyzer could not load the recording") - return self._recording + return self._temporary_recording or self._recording @property def channel_ids(self) -> np.ndarray: @@ -747,6 +773,9 @@ def unit_ids(self) -> np.ndarray: def has_recording(self) -> bool: return self._recording is not None + def has_temporary_recording(self) -> bool: + return self._temporary_recording is not None + def is_sparse(self) -> bool: return self.sparsity is not None diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 8dda9136cc..066d79b6b4 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -2,7 +2,7 @@ import numpy as np import json -from dataclasses import dataclass, field, astuple +from dataclasses import dataclass, field, astuple, replace from probeinterface import Probe from pathlib import Path from .sparsity import ChannelSparsity @@ -97,8 +97,20 @@ def __post_init__(self): # Initialize sparsity object if self.channel_ids is None: self.channel_ids = np.arange(self.num_channels) + else: + self.channel_ids = np.asarray(self.channel_ids) + assert ( + len(self.channel_ids) == self.num_channels + ), f"length of channel ids {len(self.channel_ids)} must be equal to the number of channels {self.num_channels}" + if self.unit_ids is None: self.unit_ids = np.arange(self.num_units) + else: + self.unit_ids = np.asarray(self.unit_ids) + assert ( + self.unit_ids.size == self.num_units + ), f"length of units ids {self.unit_ids.size} must be equal to the number of units {self.num_units}" + if self.sparsity_mask is not None: self.sparsity = ChannelSparsity( mask=self.sparsity_mask, @@ -128,6 +140,53 @@ def __repr__(self): return repr_str + def select_units(self, unit_ids) -> Templates: + """ + Return a new Templates object with only the selected units. + + Parameters + ---------- + unit_ids : list + List of unit IDs to select. + """ + unit_ids_list = list(self.unit_ids) + unit_indices = np.array([unit_ids_list.index(unit_id) for unit_id in unit_ids], dtype=int) + sliced_sparsity_mask = None if self.sparsity_mask is None else self.sparsity_mask[unit_indices] + + # Data class method to only change selected fields + return replace( + self, + templates_array=self.templates_array[unit_indices], + sparsity_mask=sliced_sparsity_mask, + unit_ids=unit_ids, + check_for_consistent_sparsity=False, + ) + + def select_channels(self, channel_ids) -> Templates: + """ + Return a new Templates object with only the selected channels. + This operation can be useful to remove bad channels for hybrid recording + generation. + + Parameters + ---------- + channel_ids : list + List of channel IDs to select. + """ + assert not self.are_templates_sparse(), "Cannot select channels on sparse templates" + channel_ids_list = list(self.channel_ids) + channel_indices = np.array([channel_ids_list.index(channel_id) for channel_id in channel_ids]) + sliced_sparsity_mask = None if self.sparsity_mask is None else self.sparsity_mask[:, channel_indices] + + # Data class method to only change selected fields + return replace( + self, + templates_array=self.templates_array[:, :, channel_indices], + sparsity_mask=sliced_sparsity_mask, + channel_ids=channel_ids, + check_for_consistent_sparsity=False, + ) + def to_sparse(self, sparsity): # Turn a dense representation of templates into a sparse one, given some sparsity. # Note that nothing prevent Templates tobe empty after sparsification if the sparse mask have no channels for some units diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index eb6cf7ac12..682881af8a 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -361,5 +361,30 @@ def test_select_channels(): assert np.array_equal(selected_channel_ids, ["a", "c"]) +def test_time_slice(): + # Case with sampling frequency + sampling_frequency = 10_000.0 + recording = generate_recording(durations=[1.0], num_channels=3, sampling_frequency=sampling_frequency) + + sliced_recording_times = recording.time_slice(start_time=0.1, end_time=0.8) + sliced_recording_frames = recording.frame_slice(start_frame=1000, end_frame=8000) + + assert np.allclose(sliced_recording_times.get_traces(), sliced_recording_frames.get_traces()) + + +def test_time_slice_with_time_vector(): + + # Case with time vector + sampling_frequency = 10_000.0 + recording = generate_recording(durations=[1.0], num_channels=3, sampling_frequency=sampling_frequency) + times = 1 + np.arange(0, 10_000) / sampling_frequency + recording.set_times(times=times, segment_index=0, with_warning=False) + + sliced_recording_times = recording.time_slice(start_time=1.1, end_time=1.8) + sliced_recording_frames = recording.frame_slice(start_frame=1000, end_frame=8000) + + assert np.allclose(sliced_recording_times.get_traces(), sliced_recording_frames.get_traces()) + + if __name__ == "__main__": test_BaseRecording() diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py index f63cfb16d8..4417ea342f 100644 --- a/src/spikeinterface/core/tests/test_jsonification.py +++ b/src/spikeinterface/core/tests/test_jsonification.py @@ -122,7 +122,6 @@ def test_numpy_dtype_alises_encoding(): # People tend to use this a dtype instead of the proper classes json.dumps(np.int32, cls=SIJsonEncoder) json.dumps(np.float32, cls=SIJsonEncoder) - json.dumps(np.bool_, cls=SIJsonEncoder) # Note that np.bool was deperecated in numpy 1.20.0 def test_recording_encoding(numpy_generated_recording): diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 13e01c32da..d780932146 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -32,8 +32,13 @@ def get_dataset(): return recording, sorting -def test_SortingAnalyzer_memory(tmp_path): - recording, sorting = get_dataset() +@pytest.fixture(scope="module") +def dataset(): + return get_dataset() + + +def test_SortingAnalyzer_memory(tmp_path, dataset): + recording, sorting = dataset sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -52,8 +57,8 @@ def test_SortingAnalyzer_memory(tmp_path): assert not sorting_analyzer.return_scaled -def test_SortingAnalyzer_binary_folder(tmp_path): - recording, sorting = get_dataset() +def test_SortingAnalyzer_binary_folder(tmp_path, dataset): + recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_binary_folder" if folder.exists(): @@ -82,8 +87,8 @@ def test_SortingAnalyzer_binary_folder(tmp_path): _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) -def test_SortingAnalyzer_zarr(tmp_path): - recording, sorting = get_dataset() +def test_SortingAnalyzer_zarr(tmp_path, dataset): + recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" if folder.exists(): @@ -103,10 +108,27 @@ def test_SortingAnalyzer_zarr(tmp_path): ) -def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): +def test_SortingAnalyzer_tmp_recording(dataset): + recording, sorting = dataset + recording_cached = recording.save(mode="memory") - print() - print(sorting_analyzer) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + sorting_analyzer.set_temporary_recording(recording_cached) + assert sorting_analyzer.has_temporary_recording() + # check that saving as uses the original recording + sorting_analyzer_saved = sorting_analyzer.save_as(format="memory") + assert sorting_analyzer_saved.has_recording() + assert not sorting_analyzer_saved.has_temporary_recording() + assert isinstance(sorting_analyzer_saved.recording, type(recording)) + + recording_sliced = recording.channel_slice(recording.channel_ids[:-1]) + + # wrong channels + with pytest.raises(AssertionError): + sorting_analyzer.set_temporary_recording(recording_sliced) + + +def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): register_result_extension(DummyAnalyzerExtension) @@ -290,8 +312,10 @@ def test_extensions_sorting(): if __name__ == "__main__": tmp_path = Path("test_SortingAnalyzer") - test_SortingAnalyzer_memory(tmp_path) - test_SortingAnalyzer_binary_folder(tmp_path) - test_SortingAnalyzer_zarr(tmp_path) + dataset = _get_dataset() + test_SortingAnalyzer_memory(tmp_path, dataset) + test_SortingAnalyzer_binary_folder(tmp_path, dataset) + test_SortingAnalyzer_zarr(tmp_path, dataset) + test_SortingAnalyzer_tmp_recording(dataset) test_extension() test_extension_params() diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 34a89ea5d5..4e0a0c8567 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -8,12 +8,13 @@ def generate_test_template(template_type, is_scaled=True) -> Templates: - num_units = 2 + num_units = 3 num_samples = 5 - num_channels = 3 + num_channels = 4 templates_shape = (num_units, num_samples, num_channels) templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape) - + unit_ids = ["unit_a", "unit_b", "unit_c"] + channel_ids = ["channel1", "channel2", "channel3", "channel4"] sampling_frequency = 30_000 nbefore = 2 @@ -25,19 +26,25 @@ def generate_test_template(template_type, is_scaled=True) -> Templates: sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, + unit_ids=unit_ids, + channel_ids=channel_ids, is_scaled=is_scaled, ) elif template_type == "sparse": # sparse with sparse templates - sparsity_mask = np.array([[True, False, True], [False, True, False]]) + sparsity_mask = np.array( + [[True, False, True, True], [False, True, False, False], [True, False, True, False]], + ) sparsity = ChannelSparsity( - mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels) + mask=sparsity_mask, + unit_ids=unit_ids, + channel_ids=channel_ids, ) # Create sparse templates sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels)) - for unit_index in range(num_units): + for unit_index, unit_id in enumerate(unit_ids): template = templates_array[unit_index, ...] - sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_index) + sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_id) sparse_templates_array[unit_index, :, : sparse_template.shape[1]] = sparse_template return Templates( @@ -47,11 +54,14 @@ def generate_test_template(template_type, is_scaled=True) -> Templates: nbefore=nbefore, probe=probe, is_scaled=is_scaled, + unit_ids=unit_ids, + channel_ids=channel_ids, ) elif template_type == "sparse_with_dense_templates": # sparse with dense templates - sparsity_mask = np.array([[True, False, True], [False, True, False]]) - + sparsity_mask = np.array( + [[True, False, True, True], [False, True, False, False], [True, False, True, False]], + ) return Templates( templates_array=templates_array, sparsity_mask=sparsity_mask, @@ -59,6 +69,8 @@ def generate_test_template(template_type, is_scaled=True) -> Templates: nbefore=nbefore, probe=probe, is_scaled=is_scaled, + unit_ids=unit_ids, + channel_ids=channel_ids, ) @@ -117,6 +129,48 @@ def test_save_and_load_zarr(template_type, is_scaled, tmp_path): assert original_template == loaded_template +@pytest.mark.parametrize("is_scaled", [True, False]) +@pytest.mark.parametrize("template_type", ["dense", "sparse"]) +def test_select_units(template_type, is_scaled): + template = generate_test_template(template_type, is_scaled) + selected_unit_ids = ["unit_a", "unit_c"] + selected_unit_ids_indices = [0, 2] + + selected_template = template.select_units(selected_unit_ids) + + # Verify that the selected template has the correct number of units + assert selected_template.num_units == len(selected_unit_ids) + # Verify that the unit ids match + assert np.array_equal(selected_template.unit_ids, selected_unit_ids) + # Verify that the templates data matches + assert np.array_equal(selected_template.templates_array, template.templates_array[selected_unit_ids_indices]) + + if template.sparsity_mask is not None: + assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[selected_unit_ids_indices]) + + +@pytest.mark.parametrize("is_scaled", [True, False]) +@pytest.mark.parametrize("template_type", ["dense"]) +def test_select_channels(template_type, is_scaled): + template = generate_test_template(template_type, is_scaled) + selected_channel_ids = ["channel1", "channel3"] + selected_channel_ids_indices = [0, 2] + + selected_template = template.select_channels(selected_channel_ids) + + # Verify that the selected template has the correct number of channels + assert selected_template.num_channels == len(selected_channel_ids) + # Verify that the channel ids match + assert np.array_equal(selected_template.channel_ids, selected_channel_ids) + # Verify that the templates data matches + assert np.array_equal( + selected_template.templates_array, template.templates_array[:, :, selected_channel_ids_indices] + ) + + if template.sparsity_mask is not None: + assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[:, selected_channel_ids_indices]) + + if __name__ == "__main__": # test_json_serialization("sparse") test_json_serialization("dense") diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 9c6e17edb5..add08ddb5e 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -11,4 +11,7 @@ from .mergeunitssorting import MergeUnitsSorting, merge_units_sorting from .splitunitsorting import SplitUnitSorting, split_unit_sorting +# curation format +from .curation_format import validate_curation_dict, curation_label_to_dataframe + from .sortingview_curation import apply_sortingview_curation diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py new file mode 100644 index 0000000000..d6eded4345 --- /dev/null +++ b/src/spikeinterface/curation/curation_format.py @@ -0,0 +1,163 @@ +from itertools import combinations + + +supported_curation_format_versions = {"1"} + + +def validate_curation_dict(curation_dict): + """ + Validate that the curation dictionary given as parameter complies with the format + + The function do not return anything. This raise an error if something is wring in the format. + + Parameters + ---------- + curation_dict : dict + + """ + + # format + if "format_version" not in curation_dict: + raise ValueError("No version_format") + + if curation_dict["format_version"] not in supported_curation_format_versions: + raise ValueError( + f"Format version ({curation_dict['format_version']}) not supported. " + f"Only {supported_curation_format_versions} are valid" + ) + + # unit_ids + labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]]) + merged_units_set = set(sum(curation_dict["merged_unit_groups"], [])) + removed_units_set = set(curation_dict["removed_units"]) + + if curation_dict["unit_ids"] is not None: + # old format v0 did not contain unit_ids so this can contains None + unit_set = set(curation_dict["unit_ids"]) + if not labeled_unit_set.issubset(unit_set): + raise ValueError("Curation format: some labeled units are not in the unit list") + if not merged_units_set.issubset(unit_set): + raise ValueError("Curation format: some merged units are not in the unit list") + if not removed_units_set.issubset(unit_set): + raise ValueError("Curation format: some removed units are not in the unit list") + + all_merging_groups = [set(group) for group in curation_dict["merged_unit_groups"]] + for gp_1, gp_2 in combinations(all_merging_groups, 2): + if len(gp_1.intersection(gp_2)) != 0: + raise ValueError("Some units belong to multiple merge groups") + if len(removed_units_set.intersection(merged_units_set)) != 0: + raise ValueError("Some units were merged and deleted") + + # Check the labels exclusivity + for lbl in curation_dict["manual_labels"]: + for label_key in curation_dict["label_definitions"].keys(): + if label_key in lbl: + unit_id = lbl["unit_id"] + label_value = lbl[label_key] + if not isinstance(label_value, list): + raise ValueError(f"Curation format: manual_labels {unit_id} is invalid shoudl be a list") + + is_exclusive = curation_dict["label_definitions"][label_key]["exclusive"] + + if is_exclusive and not len(label_value) <= 1: + raise ValueError( + f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. {label_value} is invalid" + ) + + +def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_format="1"): + """ + Converts the old sortingview curation format (v0) into a curation dictionary new format (v1) + Couple of caveats: + * The list of units is not available in the original sortingview dictionary. We set it to None + * Labels can not be mutually exclusive. + * Labels have no category, so we regroup them under the "all_labels" category + + Parameters + ---------- + sortingview_dict : dict + Dictionary containing the curation information from sortingview + destination_format : str + Version of the format to use. + Default to "1" + + Returns + ------- + curation_dict: dict + A curation dictionary + """ + + assert destination_format == "1" + + merge_groups = sortingview_dict["mergeGroups"] + merged_units = sum(merge_groups, []) + if len(merged_units) > 0: + unit_id_type = int if isinstance(merged_units[0], int) else str + else: + unit_id_type = str + all_units = [] + all_labels = [] + manual_labels = [] + general_cat = "all_labels" + for unit_id_, l_labels in sortingview_dict["labelsByUnit"].items(): + all_labels.extend(l_labels) + # recorver the correct type for unit_id + unit_id = unit_id_type(unit_id_) + all_units.append(unit_id) + manual_labels.append({"unit_id": unit_id, general_cat: l_labels}) + labels_def = {"all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False}} + + curation_dict = { + "format_version": destination_format, + "unit_ids": None, + "label_definitions": labels_def, + "manual_labels": manual_labels, + "merged_unit_groups": merge_groups, + "removed_units": [], + } + + return curation_dict + + +def curation_label_to_dataframe(curation_dict): + """ + Transform the curation dict into a pandas dataframe. + For label category with exclusive=True : a column is created and values are the unique label. + For label category with exclusive=False : one column per possible is created and values are boolean. + + If exclusive=False and the same label appear several times then it raises an error. + + Parameters + ---------- + curation_dict : dict + A curation dictionary + + Returns + ------- + labels : pd.DataFrame + dataframe with labels. + """ + import pandas as pd + + labels = pd.DataFrame(index=curation_dict["unit_ids"]) + + for label_key, label_def in curation_dict["label_definitions"].items(): + if label_def["exclusive"]: + assert label_key not in labels.columns, f"{label_key} is already a column" + labels[label_key] = pd.Series(dtype=str) + labels[label_key][:] = "" + for lbl in curation_dict["manual_labels"]: + value = lbl.get(label_key, []) + if len(value) == 1: + labels.at[lbl["unit_id"], label_key] = value[0] + else: + for label_opt in label_def["label_options"]: + assert label_opt not in labels.columns, f"{label_opt} is already a column" + labels[label_opt] = pd.Series(dtype=bool) + labels[label_opt][:] = False + for lbl in curation_dict["manual_labels"]: + values = lbl.get(label_key, []) + for value in values: + labels.at[lbl["unit_id"], value] = True + + return labels diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index b31b8c39d5..c4d2a32958 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -6,6 +6,8 @@ from .curationsorting import CurationSorting +# @alessio +# TODO later : this should be reimplemented using the new curation format def apply_sortingview_curation( sorting, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=False ): diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py new file mode 100644 index 0000000000..6d132fbe97 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -0,0 +1,161 @@ +import pytest + +from pathlib import Path +import json + +from spikeinterface.curation.curation_format import ( + validate_curation_dict, + convert_from_sortingview_curation_format_v0, + curation_label_to_dataframe, +) + + +"""example = { + 'unit_ids': List[str, int], + 'label_definitions': { + 'category_key1': + { + 'label_options': List[str], + 'exclusive': bool} + }, + 'manual_labels': [ + {'unit_id': str or int, + category_key1': List[str], + } + ], + 'merged_unit_groups': List[List[unit_ids]], # one cell goes into at most one list + 'removed_units': List[unit_ids] # Can not be in the merged_units +} +""" + + +curation_ids_int = { + "format_version": "1", + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "label_definitions": { + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "putative_type": { + "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], + "exclusive": False, + }, + }, + "manual_labels": [ + {"unit_id": 1, "quality": ["good"]}, + { + "unit_id": 2, + "quality": [ + "noise", + ], + "putative_type": ["excitatory", "pyramidal"], + }, + {"unit_id": 3, "putative_type": ["inhibitory"]}, + ], + "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list + "removed_units": [31, 42], # Can not be in the merged_units +} + +curation_ids_str = { + "format_version": "1", + "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], + "label_definitions": { + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "putative_type": { + "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], + "exclusive": False, + }, + }, + "manual_labels": [ + {"unit_id": "u1", "quality": ["good"]}, + { + "unit_id": "u2", + "quality": [ + "noise", + ], + "putative_type": ["excitatory", "pyramidal"], + }, + {"unit_id": "u3", "putative_type": ["inhibitory"]}, + ], + "merged_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list + "removed_units": ["u31", "u42"], # Can not be in the merged_units +} + +# This is a failure example with duplicated merge +duplicate_merge = curation_ids_int.copy() +duplicate_merge["merged_unit_groups"] = [[3, 6, 10], [10, 14, 20]] + + +# This is a failure example with unit 3 both in removed and merged +merged_and_removed = curation_ids_int.copy() +merged_and_removed["merged_unit_groups"] = [[3, 6], [10, 14, 20]] +merged_and_removed["removed_units"] = [3, 31, 42] + +# this is a failure because unit 99 is not in the initial list +unknown_merged_unit = curation_ids_int.copy() +unknown_merged_unit["merged_unit_groups"] = [[3, 6, 99], [10, 14, 20]] + +# this is a failure because unit 99 is not in the initial list +unknown_removed_unit = curation_ids_int.copy() +unknown_removed_unit["removed_units"] = [31, 42, 99] + + +def test_curation_format_validation(): + validate_curation_dict(curation_ids_int) + validate_curation_dict(curation_ids_str) + + with pytest.raises(ValueError): + # Raised because duplicated merged units + validate_curation_dict(duplicate_merge) + with pytest.raises(ValueError): + # Raised because some units belong to merged and removed unit groups + validate_curation_dict(merged_and_removed) + with pytest.raises(ValueError): + # Some merged units are not in the unit list + validate_curation_dict(unknown_merged_unit) + with pytest.raises(ValueError): + # Raise because some removed units are not in the unit list + validate_curation_dict(unknown_removed_unit) + + +def test_to_from_json(): + + json.loads(json.dumps(curation_ids_int, indent=4)) + json.loads(json.dumps(curation_ids_str, indent=4)) + + +def test_convert_from_sortingview_curation_format_v0(): + + parent_folder = Path(__file__).parent + for filename in ( + "sv-sorting-curation.json", + "sv-sorting-curation-int.json", + "sv-sorting-curation-str.json", + "sv-sorting-curation-false-positive.json", + ): + + json_file = parent_folder / filename + with open(json_file, "r") as f: + curation_v0 = json.load(f) + # print(curation_v0) + curation_v1 = convert_from_sortingview_curation_format_v0(curation_v0) + # print(curation_v1) + validate_curation_dict(curation_v1) + + +def test_curation_label_to_dataframe(): + + df = curation_label_to_dataframe(curation_ids_int) + assert "quality" in df.columns + assert "excitatory" in df.columns + print(df) + + df = curation_label_to_dataframe(curation_ids_str) + # print(df) + + +if __name__ == "__main__": + # test_curation_format_validation() + # test_to_from_json() + # test_convert_from_sortingview_curation_format_v0() + # test_curation_label_to_dataframe() + + print(json.dumps(curation_ids_str, indent=4)) diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 0b89e96649..d75752da9e 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -4,7 +4,6 @@ import numpy as np -import neo import probeinterface from .nwbextractors import read_nwb @@ -47,6 +46,8 @@ def read_bids(folder_path): recordings.append(rec) elif file_path.suffix == ".nix": + import neo + neo_reader = neo.rawio.NIXRawIO(file_path) neo_reader.parse_header() stream_ids = neo_reader.header["signal_streams"]["id"] diff --git a/src/spikeinterface/extractors/mcsh5extractors.py b/src/spikeinterface/extractors/mcsh5extractors.py index d44b7d17a8..c55f9d47db 100644 --- a/src/spikeinterface/extractors/mcsh5extractors.py +++ b/src/spikeinterface/extractors/mcsh5extractors.py @@ -7,13 +7,6 @@ from spikeinterface.core import BaseRecording, BaseRecordingSegment from spikeinterface.core.core_tools import define_function_from_class -try: - import h5py - - HAVE_MCSH5 = True -except ImportError: - HAVE_MCSH5 = False - class MCSH5RecordingExtractor(BaseRecording): """Load a MCS H5 file as a recording extractor. @@ -32,7 +25,6 @@ class MCSH5RecordingExtractor(BaseRecording): """ extractor_name = "MCSH5Recording" - installed = HAVE_MCSH5 # check at class level if installed or not mode = "file" installation_mesg = ( "To use the MCSH5RecordingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed @@ -40,7 +32,14 @@ class MCSH5RecordingExtractor(BaseRecording): name = "mcsh5" def __init__(self, file_path, stream_id=0): - assert self.installed, self.installation_mesg + + try: + import h5py + + HAVE_MCSH5 = True + except ImportError: + raise ImportError(self.installation_mesg) + self._file_path = file_path mcs_info = openMCSH5File(self._file_path, stream_id) @@ -103,6 +102,8 @@ def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): def openMCSH5File(filename, stream_id): """Open an MCS hdf5 file, read and return the recording info.""" + import h5py + rf = h5py.File(filename, "r") stream_name = "Stream_" + str(stream_id) diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index c3a4c5ad31..5e28c4a20d 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -4,7 +4,6 @@ from packaging import version from typing import Optional -import neo from spikeinterface.core.core_tools import define_function_from_class @@ -43,9 +42,8 @@ def __init__( use_names_as_ids=False, ): neo_kwargs = self.map_to_neo_kwargs(file_path) - if version.parse(neo.__version__) > version.parse("0.12.0"): - # do not load spike because this is slow but not released yet - neo_kwargs["load_nev"] = False + neo_kwargs["load_nev"] = False # Avoid loading spikes release in neo 0.12.0 + # trick to avoid to select automatically the correct stream_id suffix = Path(file_path).suffix if ".ns" in suffix: diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 25e1432297..4f92fca988 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -1,11 +1,7 @@ from __future__ import annotations -from packaging import version - -import numpy as np from pathlib import Path -import neo import probeinterface from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 4729ccea86..1f413ae2b0 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -731,7 +731,7 @@ def _fetch_recording_segment_info_backend(self, file, cache, load_time_vector, s sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) if load_time_vector and timestamps is not None: - times_kwargs = dict(time_vector=electrical_series.timestamps) + times_kwargs = dict(time_vector=electrical_series["timestamps"]) else: times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 7fdd77e703..e65ff0adfb 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -170,10 +170,22 @@ def __init__( self.set_property(key="quality", values=cluster_info[prop_name]) else: if load_all_cluster_properties: - # pandas loads strings as objects + # pandas loads strings with empty values as objects with NaNs + prop_dtype = None if cluster_info[prop_name].values.dtype.kind == "O": - prop_dtype = type(cluster_info[prop_name].values[0]) - values_ = cluster_info[prop_name].values.astype(prop_dtype) + for value in cluster_info[prop_name].values: + if isinstance(value, (np.floating, float)) and np.isnan( + value + ): # Blank values are encoded as 'NaN'. + continue + + prop_dtype = type(value) + break + if prop_dtype is not None: + values_ = cluster_info[prop_name].values.astype(prop_dtype) + else: + # Could not find a valid dtype for the column. Skip it. + continue else: values_ = cluster_info[prop_name].values self.set_property(key=prop_name, values=values_) diff --git a/src/spikeinterface/generation/tests/test_template_database.py b/src/spikeinterface/generation/tests/test_template_database.py index 757018de89..9e2a013ad0 100644 --- a/src/spikeinterface/generation/tests/test_template_database.py +++ b/src/spikeinterface/generation/tests/test_template_database.py @@ -18,7 +18,8 @@ def test_fetch_template_object_from_database(): templates = fetch_template_object_from_database("test_templates.zarr") assert isinstance(templates, Templates) - assert templates.num_units == 100 + assert templates.num_units == 89 + assert templates.num_samples == 240 assert templates.num_channels == 384 @@ -35,7 +36,7 @@ def test_fetch_templates_database_info(): def test_query_templates_from_database(): templates_info = fetch_templates_database_info() - templates_info = templates_info.iloc[::15] + templates_info = templates_info.iloc[[1, 3, 5]] num_selected = len(templates_info) templates = query_templates_from_database(templates_info) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 79d8ca17c8..2e544d086b 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core import ChannelSparsity, get_chunk_with_margin +from spikeinterface.core import ChannelSparsity from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs from spikeinterface.core.template_tools import get_template_extremum_channel diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 281782745a..c99b2d4f3b 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -1,8 +1,8 @@ from __future__ import annotations import pytest -import numpy as np import shutil +import numpy as np from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer @@ -34,34 +34,65 @@ def get_dataset(): class AnalyzerExtensionCommonTestSuite: """ - Common tests with class approach to compute extension on several cases (3 format x 2 sparsity) - - This is done a a list of differents parameters (extension_function_params_list). - - This automatically precompute extension dependencies with default params before running computation. - - This also test the select_units() ability. + Common tests with class approach to compute extension on several cases, + format ("memory", "binary_folder", "zarr") and sparsity (True, False). + Extensions refer to the extension classes that handle the postprocessing, + for example extracting principal components or amplitude scalings. + + This base class provides a fixture which sets a recording + and sorting object onto itself, which are set up once each time + the base class is subclassed in a test environment. The recording + and sorting object are used in the creation of the `sorting_analyzer` + object used to run postprocessing routines. + + When subclassed, a test function that parametrises arguments + that are passed to the `sorting_analyzer.compute()` can be setup. + This must call `run_extension_tests()` which sets up a `sorting_analyzer` + with the relevant format and sparsity. This also automatically precomputes + extension dependencies with default params, Then, `check_one()` is called + which runs the compute function with the passed params and tests that: + + 1) the returned extractor object has data on it + 2) check `sorting_analyzer.get_extension()` does not return None + 3) the correct units are sliced with the `select_units()` function. """ - extension_class = None - extension_function_params_list = None - - @classmethod - def setUpClass(cls): - cls.recording, cls.sorting = get_dataset() - # sparsity is computed once for all cases to save processing time and force a small radius - cls.sparsity = estimate_sparsity(cls.recording, cls.sorting, method="radius", radius_um=20) - - @property - def extension_name(self): - return self.extension_class.extension_name + @pytest.fixture(autouse=True, scope="class") + def setUpClass(self, create_cache_folder): + """ + This method sets up the class once at the start of testing. It is + in scope for the lifetime of te class and is reused across all + tests that inherit from this base class to save processing time and + force a small radius. + + When setting attributes on `self` in `scope="class"` a new + class instance is used for each. In this case, we have to set + from the base object `__class__` to ensure the attributes + are available to all subclass instances. + """ + self.__class__.recording, self.__class__.sorting = get_dataset() + + self.__class__.sparsity = estimate_sparsity( + self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20 + ) + self.__class__.cache_folder = create_cache_folder + + def _prepare_sorting_analyzer(self, format, sparse, extension_class): + """ + Prepare a SortingAnalyzer object with dependencies already computed + according to format (e.g. "memory", "binary_folder", "zarr") + and sparsity (e.g. True, False). + """ + sparsity_ = self.sparsity if sparse else None - @pytest.fixture(autouse=True) - def create_cache_folder(self, tmp_path_factory): - self.cache_folder = tmp_path_factory.mktemp("cache_folder") + sorting_analyzer = self.get_sorting_analyzer( + self.recording, self.sorting, format=format, sparsity=sparsity_, name=extension_class.extension_name + ) + return sorting_analyzer def get_sorting_analyzer(self, recording, sorting, format="memory", sparsity=None, name=""): sparse = sparsity is not None + if format == "memory": folder = None elif format == "binary_folder": @@ -77,43 +108,52 @@ def get_sorting_analyzer(self, recording, sorting, format="memory", sparsity=Non return sorting_analyzer - def _prepare_sorting_analyzer(self, format, sparse): + def _prepare_sorting_analyzer(self, format, sparse, extension_class): # prepare a SortingAnalyzer object with depencies already computed sparsity_ = self.sparsity if sparse else None sorting_analyzer = self.get_sorting_analyzer( - self.recording, self.sorting, format=format, sparsity=sparsity_, name=self.extension_class.extension_name + self.recording, self.sorting, format=format, sparsity=sparsity_, name=extension_class.extension_name ) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205) - for dependency_name in self.extension_class.depend_on: + + for dependency_name in extension_class.depend_on: if "|" in dependency_name: dependency_name = dependency_name.split("|")[0] sorting_analyzer.compute(dependency_name) + return sorting_analyzer - def _check_one(self, sorting_analyzer): - if self.extension_class.need_job_kwargs: + def _check_one(self, sorting_analyzer, extension_class, params): + """ + Take a prepared sorting analyzer object, compute the extension of interest + with the passed parameters, and check the output is not empty, the extension + exists and `select_units()` method works. + """ + if extension_class.need_job_kwargs: job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) else: job_kwargs = dict() - for params in self.extension_function_params_list: - print(" params", params) - ext = sorting_analyzer.compute(self.extension_name, **params, **job_kwargs) - assert len(ext.data) > 0 - main_data = ext.get_data() + ext = sorting_analyzer.compute(extension_class.extension_name, **params, **job_kwargs) + assert len(ext.data) > 0 + main_data = ext.get_data() + assert len(main_data) > 0 - ext = sorting_analyzer.get_extension(self.extension_name) + ext = sorting_analyzer.get_extension(extension_class.extension_name) assert ext is not None some_unit_ids = sorting_analyzer.unit_ids[::2] sliced = sorting_analyzer.select_units(some_unit_ids, format="memory") assert np.array_equal(sliced.unit_ids, sorting_analyzer.unit_ids[::2]) - # print(sliced) - def test_extension(self): + def run_extension_tests(self, extension_class, params): + """ + Convenience function to perform all checks on the extension + of interest with the passed parameters. Will perform tests + for sparsity and format. + """ for sparse in (True, False): for format in ("memory", "binary_folder", "zarr"): - print() print("sparse", sparse, format) - sorting_analyzer = self._prepare_sorting_analyzer(format, sparse) - self._check_one(sorting_analyzer) + sorting_analyzer = self._prepare_sorting_analyzer(format, sparse, extension_class) + self._check_one(sorting_analyzer, extension_class, params) diff --git a/src/spikeinterface/postprocessing/tests/test_align_sorting.py b/src/spikeinterface/postprocessing/tests/test_align_sorting.py index a02e224984..d44ace3db8 100644 --- a/src/spikeinterface/postprocessing/tests/test_align_sorting.py +++ b/src/spikeinterface/postprocessing/tests/test_align_sorting.py @@ -1,8 +1,3 @@ -import pytest -import shutil - -import pytest - import numpy as np from spikeinterface import NumpySorting @@ -12,8 +7,16 @@ def test_align_sorting(): + """ + `align_sorting()` shifts, in time, the spikes belonging to a unit. + For each unit, an offset is provided and the spike peak index is shifted. + + This test creates a sorting object, then creates an 'unaligned' sorting + object in which the peaks for some of the units are shifted. Next, the `align_sorting()` + function is unused to unshift them, and the original sorting spike train + peak times compared with the corrected sorting train. + """ sorting = generate_sorting(durations=[10.0], seed=0) - print(sorting) unit_ids = sorting.unit_ids @@ -21,20 +24,27 @@ def test_align_sorting(): unit_peak_shifts[unit_ids[-1]] = 5 unit_peak_shifts[unit_ids[-2]] = -5 - # sorting to dict - d = {unit_id: sorting.get_unit_spike_train(unit_id) + unit_peak_shifts[unit_id] for unit_id in sorting.unit_ids} - sorting_unaligned = NumpySorting.from_unit_dict(d, sampling_frequency=sorting.get_sampling_frequency()) - print(sorting_unaligned) + shifted_unit_dict = { + unit_id: sorting.get_unit_spike_train(unit_id) + unit_peak_shifts[unit_id] for unit_id in sorting.unit_ids + } + sorting_unaligned = NumpySorting.from_unit_dict( + shifted_unit_dict, sampling_frequency=sorting.get_sampling_frequency() + ) sorting_aligned = align_sorting(sorting_unaligned, unit_peak_shifts) - print(sorting_aligned) - - for start_frame, end_frame in [(None, None), (10000, 50000)]: - for unit_id in unit_ids[-2:]: - st = sorting.get_unit_spike_train(unit_id) - st_clean = sorting_aligned.get_unit_spike_train(unit_id) - assert np.array_equal(st, st_clean) - -if __name__ == "__main__": - test_align_sorting() + for unit_id in unit_ids: + spiketrain_orig = sorting.get_unit_spike_train(unit_id) + spiketrain_aligned = sorting_aligned.get_unit_spike_train(unit_id) + spiketrain_unaligned = sorting_unaligned.get_unit_spike_train(unit_id) + + # check the shift induced in the test has changed the + # spiketrain as expected. + if unit_peak_shifts[unit_id] == 0: + assert np.array_equal(spiketrain_orig, spiketrain_unaligned) + else: + assert not np.array_equal(spiketrain_orig, spiketrain_unaligned) + + # Perform the key test, that after correction the spiketrain + # matches the original spiketrain for all units (shifted and unshifted). + assert np.array_equal(spiketrain_orig, spiketrain_aligned) diff --git a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py index b59aca16a8..0868f5238e 100644 --- a/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/tests/test_amplitude_scalings.py @@ -1,21 +1,29 @@ -import unittest import numpy as np - +import pytest from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeAmplitudeScalings -class AmplitudeScalingsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputeAmplitudeScalings - extension_function_params_list = [ - dict(handle_collisions=True), - dict(handle_collisions=False), - ] +class TestAmplitudeScalingsExtension(AnalyzerExtensionCommonTestSuite): + + @pytest.mark.parametrize("params", [dict(handle_collisions=True), dict(handle_collisions=False)]) + def test_extension(self, params): + self.run_extension_tests(ComputeAmplitudeScalings, params) def test_scaling_values(self): - sorting_analyzer = self._prepare_sorting_analyzer("memory", True) + """ + Amplitude finds the scaling factor for each waveform + to best match its unit template. In this test, amplitude scalings + are calculated from the `sorting_analyzer`. In the test environment, + injected waveforms are not scaled from the template and so + should only differ by Gaussian noise. Therefore the median + scaling should be close to 1. + """ + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=True, extension_class=ComputeAmplitudeScalings + ) sorting_analyzer.compute("amplitude_scalings", handle_collisions=False) spikes = sorting_analyzer.sorting.to_spike_vector() @@ -26,17 +34,4 @@ def test_scaling_values(self): mask = spikes["unit_index"] == unit_index scalings = ext.data["amplitude_scalings"][mask] median_scaling = np.median(scalings) - # print(unit_index, median_scaling) np.testing.assert_array_equal(np.round(median_scaling), 1) - - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # ax.hist(ext.data["amplitude_scalings"]) - # plt.show() - - -if __name__ == "__main__": - test = AmplitudeScalingsExtensionTest() - test.setUpClass() - test.test_extension() - test.test_scaling_values() diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 6d727e6448..eef4af10fc 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -1,6 +1,4 @@ -import unittest import numpy as np -from typing import List try: import numba @@ -14,54 +12,47 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeCorrelograms from spikeinterface.postprocessing.correlograms import compute_correlograms_on_sorting, _make_bins +import pytest -class ComputeCorrelogramsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputeCorrelograms - extension_function_params_list = [ - dict(method="numpy"), - dict(method="auto"), - ] - if HAVE_NUMBA: - extension_function_params_list.append(dict(method="numba")) +class TestComputeCorrelograms(AnalyzerExtensionCommonTestSuite): + + @pytest.mark.parametrize( + "params", + [ + dict(method="numpy"), + dict(method="auto"), + pytest.param(dict(method="numba"), marks=pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeCorrelograms, params) def test_make_bins(): + """ + Check the `_make_bins()` function that generates time bins (lags) for + the correllogram creates the expected number of bins. + """ sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) window_ms = 43.57 bin_ms = 1.6421 bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms) assert bins.size == np.floor(window_ms / bin_ms) + 1 - # print(bins, window_size, bin_size) window_ms = 60.0 bin_ms = 2.0 bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms) assert bins.size == np.floor(window_ms / bin_ms) + 1 - # print(bins, window_size, bin_size) def _test_correlograms(sorting, window_ms, bin_ms, methods): for method in methods: correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) if method == "numpy": - ref_correlograms = correlograms ref_bins = bins else: - # ~ import matplotlib.pyplot as plt - # ~ for i in range(ref_correlograms.shape[1]): - # ~ for j in range(ref_correlograms.shape[1]): - # ~ fig, ax = plt.subplots() - # ~ ax.plot(bins[:-1], ref_correlograms[i, j, :], color='green', label='numpy') - # ~ ax.plot(bins[:-1], correlograms[i, j, :], color='red', label=method) - # ~ ax.legend() - # ~ ax.set_title(f'{i} {j}') - # ~ plt.show() - - # numba and numyp do not have exactly the same output - # assert np.all(correlograms == ref_correlograms), f"Failed with method={method}" - assert np.allclose(bins, ref_bins, atol=1e-10), f"Failed with method={method}" @@ -78,15 +69,16 @@ def test_equal_results_correlograms(): def test_flat_cross_correlogram(): + """ + Check that the correlogram (num_units x num_units x num_bins) does not + vary too much across time bins (lags), for entries representing two different units. + """ sorting = generate_sorting(num_units=2, sampling_frequency=10000.0, durations=[100000.0], seed=0) methods = ["numpy"] if HAVE_NUMBA: methods.append("numba") - # ~ import matplotlib.pyplot as plt - # ~ fig, ax = plt.subplots() - for method in methods: correlograms, bins = compute_correlograms_on_sorting(sorting, window_ms=50.0, bin_ms=1.0, method=method) cc = correlograms[0, 1, :].copy() @@ -94,11 +86,6 @@ def test_flat_cross_correlogram(): assert np.all(cc > (m * 0.90)) assert np.all(cc < (m * 1.10)) - # ~ ax.plot(bins[:-1], cc, label=method) - # ~ ax.legend() - # ~ ax.set_ylim(0, np.max(correlograms) * 1.1) - # ~ plt.show() - def test_auto_equal_cross_correlograms(): """ @@ -137,18 +124,13 @@ def test_auto_equal_cross_correlograms(): else: assert np.array_equal(cc_corrected, ac) - # ~ import matplotlib.pyplot as plt - # ~ fig, ax = plt.subplots() - # ~ ax.plot(bins[:-1], cc, marker='*', color='red', label='cross-corr') - # ~ ax.plot(bins[:-1], cc_corrected, marker='*', color='orange', label='cross-corr corrected') - # ~ ax.plot(bins[:-1], ac, marker='*', color='green', label='auto-corr') - # ~ ax.set_title(method) - # ~ ax.legend() - # ~ ax.set_ylim(0, np.max(correlograms) * 1.1) - # ~ plt.show() - def test_detect_injected_correlation(): + """ + Inject 1.44 ms of correlation every 13 spikes and compute + cross-correlation. Check that the time bin lag with the peak + correlation lag is 1.44 ms (within tolerance of a sampling period). + """ methods = ["numpy"] if HAVE_NUMBA: methods.append("numba") @@ -181,24 +163,3 @@ def test_detect_injected_correlation(): sampling_period_ms = 1000.0 / sampling_frequency assert abs(peak_location_01_ms) - injected_delta_ms < sampling_period_ms assert abs(peak_location_02_ms) - injected_delta_ms < sampling_period_ms - - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # half_bin_ms = np.mean(np.diff(bins)) / 2. - # ax.plot(bins[:-1]+half_bin_ms, cc_01, marker='*', color='red', label='cross-corr 0>1') - # ax.plot(bins[:-1]+half_bin_ms, cc_10, marker='*', color='orange', label='cross-corr 1>0') - # ax.set_title(method) - # ax.legend() - # plt.show() - - -if __name__ == "__main__": - # test_make_bins() - # test_equal_results_correlograms() - # test_flat_cross_correlogram() - # test_auto_equal_cross_correlograms() - # test_detect_injected_correlation() - - test = ComputeCorrelogramsTest() - test.setUpClass() - test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_isi.py b/src/spikeinterface/postprocessing/tests/test_isi.py index 8626e56453..444e837cb4 100644 --- a/src/spikeinterface/postprocessing/tests/test_isi.py +++ b/src/spikeinterface/postprocessing/tests/test_isi.py @@ -1,12 +1,11 @@ -import unittest import numpy as np from typing import List from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -from spikeinterface.postprocessing import compute_isi_histograms, ComputeISIHistograms +from spikeinterface.postprocessing import ComputeISIHistograms from spikeinterface.postprocessing.isi import _compute_isi_histograms - +import pytest try: import numba @@ -16,38 +15,40 @@ HAVE_NUMBA = False -class ComputeISIHistogramsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputeISIHistograms - extension_function_params_list = [ - dict(method="numpy"), - dict(method="auto"), - ] - if HAVE_NUMBA: - extension_function_params_list.append(dict(method="numba")) +class TestComputeISIHistograms(AnalyzerExtensionCommonTestSuite): + + @pytest.mark.parametrize( + "params", + [ + dict(method="numpy"), + dict(method="auto"), + pytest.param(dict(method="numba"), marks=pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeISIHistograms, params) def test_compute_ISI(self): + """ + This test checks the creation of ISI histograms matches across + "numpy", "auto" and "numba" methods. Does not parameterize as requires + as list because everything tested against Numpy. The Numpy result is not + explicitly tested. + """ methods = ["numpy", "auto"] if HAVE_NUMBA: methods.append("numba") - _test_ISI(self.sorting, window_ms=60.0, bin_ms=1.0, methods=methods) - _test_ISI(self.sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) - - -def _test_ISI(sorting, window_ms: float, bin_ms: float, methods: List[str]): - for method in methods: - ISI, bins = _compute_isi_histograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) - - if method == "numpy": - ref_ISI = ISI - ref_bins = bins - else: - assert np.all(ISI == ref_ISI), f"Failed with method={method}" - assert np.allclose(bins, ref_bins, atol=1e-10), f"Failed with method={method}" + self._test_ISI(self.sorting, window_ms=60.0, bin_ms=1.0, methods=methods) + self._test_ISI(self.sorting, window_ms=43.57, bin_ms=1.6421, methods=methods) + def _test_ISI(self, sorting, window_ms: float, bin_ms: float, methods: List[str]): + for method in methods: + ISI, bins = _compute_isi_histograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method=method) -if __name__ == "__main__": - test = ComputeISIHistogramsTest() - test.setUpClass() - test.test_extension() - test.test_compute_ISI() + if method == "numpy": + ref_ISI = ISI + ref_bins = bins + else: + assert np.all(ISI == ref_ISI), f"Failed with method={method}" + assert np.allclose(bins, ref_bins, atol=1e-10), f"Failed with method={method}" diff --git a/src/spikeinterface/postprocessing/tests/test_noise_levels.py b/src/spikeinterface/postprocessing/tests/test_noise_levels.py index f334f92fa6..0f9265d00f 100644 --- a/src/spikeinterface/postprocessing/tests/test_noise_levels.py +++ b/src/spikeinterface/postprocessing/tests/test_noise_levels.py @@ -1 +1,3 @@ # "noise_levels" extensions is now in core + +# TODO: can this page now be deleted? diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index f9b847ec22..08ec32c6c2 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -1,28 +1,33 @@ -import unittest import pytest -from pathlib import Path - import numpy as np -from spikeinterface.postprocessing import ComputePrincipalComponents, compute_principal_components +from spikeinterface.postprocessing import ComputePrincipalComponents from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -DEBUG = False - +class TestPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite): -class PrincipalComponentsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputePrincipalComponents - extension_function_params_list = [ - dict(mode="by_channel_local"), - dict(mode="by_channel_global"), - # mode concatenated cannot be tested here because it do not work with sparse=True - ] + @pytest.mark.parametrize( + "params", + [ + dict(mode="by_channel_local"), + dict(mode="by_channel_global"), + # mode concatenated cannot be tested here because it do not work with sparse=True + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputePrincipalComponents, params=params) def test_mode_concatenated(self): - # this is tested outside "extension_function_params_list" because it do not support sparsity! + """ + Replicate the "extension_function_params_list" test outside of + AnalyzerExtensionCommonTestSuite because it does not support sparsity. - sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=False) + Also, add two additional checks on the dimension and n components of the output. + """ + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=False, extension_class=ComputePrincipalComponents + ) n_components = 3 sorting_analyzer.compute("principal_components", mode="concatenated", n_components=n_components) @@ -33,94 +38,126 @@ def test_mode_concatenated(self): assert pca.ndim == 2 assert pca.shape[1] == n_components - def test_get_projections(self): - - for sparse in (False, True): - - sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=sparse) - num_chans = sorting_analyzer.get_num_channels() - n_components = 2 - - sorting_analyzer.compute("principal_components", mode="by_channel_global", n_components=n_components) - ext = sorting_analyzer.get_extension("principal_components") - - for unit_id in sorting_analyzer.unit_ids: - if not sparse: - one_proj = ext.get_projections_one_unit(unit_id, sparse=False) - assert one_proj.shape[1] == n_components - assert one_proj.shape[2] == num_chans - else: - one_proj = ext.get_projections_one_unit(unit_id, sparse=False) - assert one_proj.shape[1] == n_components - assert one_proj.shape[2] == num_chans - - one_proj, chan_inds = ext.get_projections_one_unit(unit_id, sparse=True) - assert one_proj.shape[1] == n_components - assert one_proj.shape[2] < num_chans - assert one_proj.shape[2] == chan_inds.size - - some_unit_ids = sorting_analyzer.unit_ids[::2] - some_channel_ids = sorting_analyzer.channel_ids[::2] - - random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data() - - # this should be all spikes all channels - some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) - assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] == random_spikes_indices.size - assert some_projections.shape[1] == n_components - assert some_projections.shape[2] == num_chans - - # this should be some spikes all channels - some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) - assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] < random_spikes_indices.size - assert some_projections.shape[1] == n_components - assert some_projections.shape[2] == num_chans - assert 1 not in spike_unit_index - - # this should be some spikes some channels - some_projections, spike_unit_index = ext.get_some_projections( - channel_ids=some_channel_ids, unit_ids=some_unit_ids - ) - assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] < random_spikes_indices.size - assert some_projections.shape[1] == n_components - assert some_projections.shape[2] == some_channel_ids.size - assert 1 not in spike_unit_index - - def test_compute_for_all_spikes(self): - - for sparse in (True, False): - sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=sparse) + @pytest.mark.parametrize("sparse", [True, False]) + def test_get_projections(self, sparse): + """ + Test the shape of output projection score matrices are + correct when adjusting sparsity and using the + `get_some_projections()` function. We expect them + to hold, for each spike and each channel, the loading + for each of the specified number of components. + """ + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=sparse, extension_class=ComputePrincipalComponents + ) + num_chans = sorting_analyzer.get_num_channels() + n_components = 2 + + sorting_analyzer.compute("principal_components", mode="by_channel_global", n_components=n_components) + ext = sorting_analyzer.get_extension("principal_components") - num_spikes = sorting_analyzer.sorting.to_spike_vector().size + # First, check the created projections have the expected number + # of components and the expected number of channels based on sparsity. + for unit_id in sorting_analyzer.unit_ids: + if not sparse: + one_proj = ext.get_projections_one_unit(unit_id, sparse=False) + assert one_proj.shape[1] == n_components + assert one_proj.shape[2] == num_chans + else: + one_proj = ext.get_projections_one_unit(unit_id, sparse=False) + assert one_proj.shape[1] == n_components + assert one_proj.shape[2] == num_chans + + one_proj, chan_inds = ext.get_projections_one_unit(unit_id, sparse=True) + assert one_proj.shape[1] == n_components + num_channels_for_unit = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id].size + assert one_proj.shape[2] == num_channels_for_unit + assert one_proj.shape[2] == chan_inds.size + + # Next, check that the `get_some_projections()` function returns + # projections with the expected shapes when selecting subjsets + # of channel and unit IDs. + some_unit_ids = sorting_analyzer.unit_ids[::2] + some_channel_ids = sorting_analyzer.channel_ids[::2] + + random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data() + all_num_spikes = sorting_analyzer.sorting.get_total_num_spikes() + unit_ids_num_spikes = np.sum(all_num_spikes[unit_id] for unit_id in some_unit_ids) + + # this should be all spikes all channels + some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) + assert some_projections.shape[0] == spike_unit_index.shape[0] + assert spike_unit_index.shape[0] == random_spikes_indices.size + assert some_projections.shape[1] == n_components + assert some_projections.shape[2] == num_chans + + # this should be some spikes all channels + some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) + assert some_projections.shape[0] == spike_unit_index.shape[0] + assert spike_unit_index.shape[0] == unit_ids_num_spikes + assert some_projections.shape[1] == n_components + assert some_projections.shape[2] == num_chans + assert 1 not in spike_unit_index + + # this should be some spikes some channels + some_projections, spike_unit_index = ext.get_some_projections( + channel_ids=some_channel_ids, unit_ids=some_unit_ids + ) + assert some_projections.shape[0] == spike_unit_index.shape[0] + assert spike_unit_index.shape[0] == unit_ids_num_spikes + assert some_projections.shape[1] == n_components + assert some_projections.shape[2] == some_channel_ids.size + assert 1 not in spike_unit_index + + @pytest.mark.parametrize("sparse", [True, False]) + def test_compute_for_all_spikes(self, sparse): + """ + Compute the principal component scores, checking the shape + matches the number of spikes as expected. This is re-run + with n_jobs=2 and output projection score matrices + checked against n_jobs=1. + """ + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=sparse, extension_class=ComputePrincipalComponents + ) + + num_spikes = sorting_analyzer.sorting.to_spike_vector().size - n_components = 3 - sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) - ext = sorting_analyzer.get_extension("principal_components") + n_components = 3 + sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) + ext = sorting_analyzer.get_extension("principal_components") - pc_file1 = self.cache_folder / "all_pc1.npy" - ext.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) - all_pc1 = np.load(pc_file1) - assert all_pc1.shape[0] == num_spikes + pc_file1 = self.cache_folder / "all_pc1.npy" + ext.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) + all_pc1 = np.load(pc_file1) + assert all_pc1.shape[0] == num_spikes - pc_file2 = self.cache_folder / "all_pc2.npy" - ext.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) - all_pc2 = np.load(pc_file2) + pc_file2 = self.cache_folder / "all_pc2.npy" + ext.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) + all_pc2 = np.load(pc_file2) - assert np.array_equal(all_pc1, all_pc2) + assert np.array_equal(all_pc1, all_pc2) def test_project_new(self): - from sklearn.decomposition import IncrementalPCA - - sorting_analyzer = self._prepare_sorting_analyzer(format="memory", sparse=False) + """ + `project_new` projects new (unseen) waveforms onto the PCA components. + First compute principal components from existing waveforms. Then, + generate a new 'spikes' vector that includes sample_index, unit_index + and segment_index alongside some waveforms (the spike vector is required + to generate some corresponding unit IDs for the generated waveforms following + the API of principal_components.py). + + Then, check that the new projection scores matrix is the expected shape. + """ + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=False, extension_class=ComputePrincipalComponents + ) waveforms = sorting_analyzer.get_extension("waveforms").data["waveforms"] n_components = 3 sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) - ext_pca = sorting_analyzer.get_extension(self.extension_name) + ext_pca = sorting_analyzer.get_extension(ComputePrincipalComponents.extension_name) num_spike = 100 new_spikes = sorting_analyzer.sorting.to_spike_vector()[:num_spike] @@ -130,20 +167,3 @@ def test_project_new(self): assert new_proj.shape[0] == num_spike assert new_proj.shape[1] == n_components assert new_proj.shape[2] == ext_pca.data["pca_projection"].shape[2] - - -if __name__ == "__main__": - test = PrincipalComponentsExtensionTest() - test.setUpClass() - test.test_extension() - test.test_mode_concatenated() - test.test_get_projections() - test.test_compute_for_all_spikes() - test.test_project_new() - - # ext = test.sorting_analyzers["sparseTrue_memory"].get_extension("principal_components") - # pca = ext.data["pca_projection"] - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # ax.scatter(pca[:, 0, 0], pca[:, 0, 1]) - # plt.show() diff --git a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py index 8ff7666371..a68483a1b2 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_amplitudes.py @@ -1,22 +1,8 @@ -import unittest -import numpy as np - from spikeinterface.postprocessing import ComputeSpikeAmplitudes from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -class ComputeSpikeAmplitudesTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputeSpikeAmplitudes - extension_function_params_list = [ - dict(), - ] - - -if __name__ == "__main__": - test = ComputeSpikeAmplitudesTest() - test.setUpClass() - test.test_extension() +class TestComputeSpikeAmplitudes(AnalyzerExtensionCommonTestSuite): - # for k, sorting_analyzer in test.sorting_analyzers.items(): - # print(sorting_analyzer) - # print(sorting_analyzer.get_extension("spike_amplitudes").data["amplitudes"].shape) + def test_extension(self): + self.run_extension_tests(ComputeSpikeAmplitudes, params=dict()) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index d48ff3d84b..46a39d23ea 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -1,26 +1,19 @@ -import unittest -import numpy as np - from spikeinterface.postprocessing import ComputeSpikeLocations from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite +import pytest -class SpikeLocationsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputeSpikeLocations - extension_function_params_list = [ - dict( - method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True) - ), # chunk_size=10000, n_jobs=1, - dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), - dict( - method="center_of_mass", - ), - dict(method="monopolar_triangulation"), # , chunk_size=10000, n_jobs=1 - dict(method="grid_convolution"), # , chunk_size=10000, n_jobs=1 - ] - +class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): -if __name__ == "__main__": - test = SpikeLocationsExtensionTest() - test.setUpClass() - test.test_extension() + @pytest.mark.parametrize( + "params", + [ + dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)), + dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), + dict(method="center_of_mass"), + dict(method="monopolar_triangulation"), + dict(method="grid_convolution"), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeSpikeLocations, params) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 360f0f379f..694aa083cc 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,20 +1,17 @@ -import unittest - - from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeTemplateMetrics +import pytest -class TemplateMetricsTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputeTemplateMetrics - extension_function_params_list = [ - dict(), - dict(upsampling_factor=2), - dict(include_multi_channel_metrics=True), - ] - +class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): -if __name__ == "__main__": - test = TemplateMetricsTest() - test.setUpClass() - test.test_extension() + @pytest.mark.parametrize( + "params", + [ + dict(), + dict(upsampling_factor=2), + dict(include_multi_channel_metrics=True), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeTemplateMetrics, params) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 1693530454..a4de2a3a90 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -1,23 +1,23 @@ -import unittest - from spikeinterface.postprocessing.tests.common_extension_tests import ( AnalyzerExtensionCommonTestSuite, - get_dataset, ) from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity -class SimilarityExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputeTemplateSimilarity - extension_function_params_list = [ - dict(method="cosine_similarity"), - ] +class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite): - def test_check_equal_template_with_distribution_overlap(self): - recording, sorting = get_dataset() + def test_extension(self): + self.run_extension_tests(ComputeTemplateSimilarity, params=dict(method="cosine_similarity")) - sorting_analyzer = self.get_sorting_analyzer(recording=recording, sorting=sorting, sparsity=None) + def test_check_equal_template_with_distribution_overlap(self): + """ + Create a sorting object, extract its waveforms. Compare waveforms + from all pairs of units (excluding a unit against itself) + and check `check_equal_template_with_distribution_overlap()` + correctly determines they are different. + """ + sorting_analyzer = self._prepare_sorting_analyzer("memory", None, ComputeTemplateSimilarity) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms") sorting_analyzer.compute("templates") @@ -30,12 +30,5 @@ def test_check_equal_template_with_distribution_overlap(self): if unit_id0 == unit_id1: continue waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) - check_equal_template_with_distribution_overlap(waveforms0, waveforms1) - - -if __name__ == "__main__": - test = SimilarityExtensionTest() - # test.setUpClass() - # test.test_extension() - # test_check_equal_template_with_distribution_overlap() + assert not check_equal_template_with_distribution_overlap(waveforms0, waveforms1) diff --git a/src/spikeinterface/postprocessing/tests/test_unit_locations.py b/src/spikeinterface/postprocessing/tests/test_unit_locations.py index b23adf5868..c40a917a2b 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_locations.py @@ -1,21 +1,19 @@ -import unittest from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeUnitLocations +import pytest -class UnitLocationsExtensionTest(AnalyzerExtensionCommonTestSuite, unittest.TestCase): - extension_class = ComputeUnitLocations - extension_function_params_list = [ - dict(method="center_of_mass", radius_um=100), - dict(method="grid_convolution", radius_um=50), - dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), - dict(method="monopolar_triangulation", radius_um=150), - dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), - ] +class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): - -if __name__ == "__main__": - test = UnitLocationsExtensionTest() - test.setUpClass() - test.test_extension() - # test.tearDown() + @pytest.mark.parametrize( + "params", + [ + dict(method="center_of_mass", radius_um=100), + dict(method="grid_convolution", radius_um=50), + dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), + dict(method="monopolar_triangulation", radius_um=150), + dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeUnitLocations, params=params) diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index e40523e7e5..16d9955e58 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -239,7 +239,7 @@ def compute_monopolar_triangulation( contact_locations = sorting_analyzer.get_channel_locations() sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=radius_um) - templates = get_dense_templates_array(sorting_analyzer) + templates = get_dense_templates_array(sorting_analyzer, return_scaled=sorting_analyzer.return_scaled) nbefore = _get_nbefore(sorting_analyzer) if enforce_decrease: @@ -303,7 +303,7 @@ def compute_center_of_mass(sorting_analyzer, peak_sign="neg", radius_um=75, feat assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" sparsity = compute_sparsity(sorting_analyzer, peak_sign=peak_sign, method="radius", radius_um=radius_um) - templates = get_dense_templates_array(sorting_analyzer) + templates = get_dense_templates_array(sorting_analyzer, return_scaled=sorting_analyzer.return_scaled) nbefore = _get_nbefore(sorting_analyzer) unit_location = np.zeros((unit_ids.size, 2), dtype="float64") @@ -374,7 +374,7 @@ def compute_grid_convolution( contact_locations = sorting_analyzer.get_channel_locations() unit_ids = sorting_analyzer.unit_ids - templates = get_dense_templates_array(sorting_analyzer) + templates = get_dense_templates_array(sorting_analyzer, return_scaled=sorting_analyzer.return_scaled) nbefore = _get_nbefore(sorting_analyzer) nafter = templates.shape[1] - nbefore diff --git a/src/spikeinterface/preprocessing/astype.py b/src/spikeinterface/preprocessing/astype.py index da1435130c..a05610ea2e 100644 --- a/src/spikeinterface/preprocessing/astype.py +++ b/src/spikeinterface/preprocessing/astype.py @@ -14,8 +14,20 @@ class AstypeRecording(BasePreprocessor): For recording with an unsigned dtype, please use the `unsigned_to_signed` preprocessing function. - If `round` is True, will round the values to the nearest integer. - If `round` is None, will round in the case of float to integer conversion. + Parameters + ---------- + dtype : None | str | dtype, default: None + dtype of the output recording. If None, takes dtype from input `recording`. + recording : Recording + The recording extractor to be converted. + round : Bool | None, default: None + If True, will round the values to the nearest integer using `numpy.round`. + If None and dtype is an integer, will round floats to nearest integer. + + Returns + ------- + astype_recording : AstypeRecording + The converted recording extractor object """ name = "astype" diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 9569459080..f08f6404da 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -12,7 +12,7 @@ class DepthOrderRecording(ChannelSliceRecording): Parameters ---------- - recording : BaseRecording + parent_recording : BaseRecording The recording to re-order. channel_ids : list/array or None If given, a subset of channels to order locations for diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 276a8ac0b4..5d8f7107c7 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -57,29 +57,29 @@ def detect_bad_channels( The method to be used for bad channel detection std_mad_threshold : float, default: 5 The standard deviation/mad multiplier threshold - psd_hf_threshold (coeherence+psd) : float, default: 0.02 - An absolute threshold (uV^2/Hz) used as a cutoff for noise channels. + psd_hf_threshold : float, default: 0.02 + For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels. Channels with average power at >80% Nyquist larger than this threshold will be labeled as noise - dead_channel_threshold (coeherence+psd) : float, default: -0.5 - Threshold for channel coherence below which channels are labeled as dead - noisy_channel_threshold (coeherence+psd) : float, default: 1 + dead_channel_threshold : float, default: -0.5 + For coherence+psd - threshold for channel coherence below which channels are labeled as dead + noisy_channel_threshold : float, default: 1 Threshold for channel coherence above which channels are labeled as noisy (together with psd condition) - outside_channel_threshold (coeherence+psd) : float, default: -0.75 - Threshold for channel coherence above which channels at the edge of the recording are marked as outside + outside_channel_threshold : float, default: -0.75 + For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside of the brain - outside_channels_location (coeherence+psd) : "top" | "bottom" | "both", default: "top" - Location of the outside channels. If "top", only the channels at the top of the probe can be + outside_channels_location : "top" | "bottom" | "both", default: "top" + For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be marked as outside channels. If "bottom", only the channels at the bottom of the probe can be marked as outside channels. If "both", both the channels at the top and bottom of the probe can be marked as outside channels - n_neighbors (coeherence+psd) : int, default: 11 - Number of channel neighbors to compute median filter (needs to be odd) - nyquist_threshold (coeherence+psd) : float, default: 0.8 - Frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared + n_neighbors : int, default: 11 + For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd) + nyquist_threshold : float, default: 0.8 + For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared with psd_hf_threshold - direction (coeherence+psd) : "x" | "y" | "z", default: "y" - The depth dimension + direction : "x" | "y" | "z", default: "y" + For coherence+psd - the depth dimension highpass_filter_cutoff : float, default: 300 If the recording is not filtered, the cutoff frequency of the highpass filter chunk_duration_s : float, default: 0.5 diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 3f1a155d0d..6a1733c57c 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -8,14 +8,16 @@ from ..core import get_chunk_with_margin -_common_filter_docs = """**filter_kwargs : keyword arguments for parallel processing: - - * filter_order : order - The order of the filter - * filter_mode : "sos or "ba" - "sos" is bi quadratic and more stable than ab so thery are prefered. - * ftype : str - Filter type for iirdesign ("butter" / "cheby1" / ... all possible of scipy.signal.iirdesign) +_common_filter_docs = """**filter_kwargs : dict + Certain keyword arguments for `scipy.signal` filters: + filter_order : order + The order of the filter + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". """ @@ -39,21 +41,25 @@ class FilterRecording(BasePreprocessor): Type of the filter margin_ms : float, default: 5.0 Margin in ms on border to avoid border effect - filter_mode : "sos" | "ba", default: "sos" - Filter form of the filter coefficients: - - second-order sections ("sos") - - numerator/denominator : ("ba") - coef : array or None, default: None + coeff : array | None, default: None Filter coefficients in the filter_mode form. dtype : dtype or None, default: None The dtype of the returned traces. If None, the dtype of the parent recording is used - {} + add_reflect_padding : Bool, default False + If True, uses a left and right margin during calculation. + filter_order : order + The order of the filter for `scipy.signal.iirfilter` + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients for `scipy.signal.iirfilter`: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". Returns ------- filter_recording : FilterRecording The filtered recording extractor object - """ name = "filter" @@ -179,6 +185,7 @@ class BandpassFilterRecording(FilterRecording): dtype : dtype or None The dtype of the returned traces. If None, the dtype of the parent recording is used {} + Returns ------- filter_recording : BandpassFilterRecording @@ -213,6 +220,7 @@ class HighpassFilterRecording(FilterRecording): dtype : dtype or None The dtype of the returned traces. If None, the dtype of the parent recording is used {} + Returns ------- filter_recording : HighpassFilterRecording @@ -240,7 +248,11 @@ class NotchFilterRecording(BasePreprocessor): The target frequency in Hz of the notch filter q : int The quality factor of the notch filter - {} + dtype : None | dtype, default: None + dtype of recording. If None, will take from `recording` + margin_ms : float, default: 5.0 + Margin in ms on border to avoid border effect + Returns ------- filter_recording : NotchFilterRecording @@ -284,6 +296,9 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") +bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) +highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs) + def fix_dtype(recording, dtype): if dtype is None: diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index cefe4d4d7a..8023bd4367 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -1,14 +1,14 @@ from __future__ import annotations -import time -from pathlib import Path - import numpy as np import json +from pathlib import Path +import time from spikeinterface.core import get_noise_levels, fix_job_kwargs from spikeinterface.core.job_tools import _shared_job_kwargs_doc from spikeinterface.core.core_tools import SIJsonEncoder +from spikeinterface.core.job_tools import _shared_job_kwargs_doc motion_options_preset = { # This preset should be the most acccurate @@ -68,7 +68,7 @@ weight_with_amplitude=False, ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, "nonrigid_fast_and_accurate": { @@ -127,7 +127,7 @@ weight_with_amplitude=False, ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, # This preset is a super fast rigid estimation with center of mass @@ -152,7 +152,7 @@ rigid=True, ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, # This preset try to mimic kilosort2.5 motion estimator @@ -186,7 +186,7 @@ win_shape="rect", ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, # empty preset @@ -276,9 +276,8 @@ def correct_motion( recording_corrected : Recording The motion corrected recording motion_info : dict - Optional output if `output_motion_info=True` + Optional output if `output_motion_info=True`. The key "motion" holds the Motion object. """ - # local import are important because "sortingcomponents" is not important by default from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods from spikeinterface.sortingcomponents.peak_selection import select_peaks @@ -377,21 +376,15 @@ def correct_motion( np.save(folder / "peak_locations.npy", peak_locations) t0 = time.perf_counter() - motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs) + motion = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs) t1 = time.perf_counter() run_times["estimate_motion"] = t1 - t0 - recording_corrected = InterpolateMotionRecording( - recording, motion, temporal_bins, spatial_bins, **interpolate_motion_kwargs - ) + recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) if folder is not None: (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - - np.save(folder / "temporal_bins.npy", temporal_bins) - np.save(folder / "motion.npy", motion) - if spatial_bins is not None: - np.save(folder / "spatial_bins.npy", spatial_bins) + motion.save(folder / "motion") if output_motion_info: motion_info = dict( @@ -399,8 +392,6 @@ def correct_motion( run_times=run_times, peaks=peaks, peak_locations=peak_locations, - temporal_bins=temporal_bins, - spatial_bins=spatial_bins, motion=motion, ) return recording_corrected, motion_info @@ -419,6 +410,8 @@ def correct_motion( def load_motion_info(folder): + from spikeinterface.sortingcomponents.motion_utils import Motion + folder = Path(folder) motion_info = {} @@ -429,11 +422,13 @@ def load_motion_info(folder): with open(folder / "run_times.json") as f: motion_info["run_times"] = json.load(f) - array_names = ("peaks", "peak_locations", "temporal_bins", "spatial_bins", "motion") + array_names = ("peaks", "peak_locations") for name in array_names: if (folder / f"{name}.npy").exists(): motion_info[name] = np.load(folder / f"{name}.npy") else: motion_info[name] = None + motion_info["motion"] = Motion.load(folder / "motion") + return motion_info diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 44b9ac9937..e537be4694 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -54,7 +54,7 @@ class NormalizeByQuantileRecording(BasePreprocessor): Median for the output distribution q1 : float, default: 0.01 Lower quantile used for measuring the scale - q1 : float, default: 0.99 + q2 : float, default: 0.99 Upper quantile used for measuring the mode : "by_channel" | "pool_channel", default: "by_channel" If "by_channel" each channel is rescaled independently. diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 02fc9b1206..ca93d58364 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -31,6 +31,9 @@ class PhaseShiftRecording(BasePreprocessor): inter_sample_shift : None or numpy array, default: None If "inter_sample_shift" is not in recording properties, we can externally provide one. + dtype : None | str | dtype, default: None + Dtype of input and output `recording` objects. + Returns ------- diff --git a/src/spikeinterface/preprocessing/preprocessing_tools.py b/src/spikeinterface/preprocessing/preprocessing_tools.py index c0b80c349b..942478fd71 100644 --- a/src/spikeinterface/preprocessing/preprocessing_tools.py +++ b/src/spikeinterface/preprocessing/preprocessing_tools.py @@ -80,7 +80,7 @@ def get_spatial_interpolation_kernel( elif method == "idw": distances = scipy.spatial.distance.cdist(source_location, target_location, metric="euclidean") - interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype="float64") + interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype=dtype) for c in range(target_location.shape[0]): ind_sorted = np.argsort(distances[:, c]) chan_closest = ind_sorted[:num_closest] @@ -97,7 +97,7 @@ def get_spatial_interpolation_kernel( elif method == "nearest": distances = scipy.spatial.distance.cdist(source_location, target_location, metric="euclidean") - interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype="float64") + interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype=dtype) for c in range(target_location.shape[0]): ind_closest = np.argmin(distances[:, c]) interpolation_kernel[ind_closest, c] = 1.0 diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index cc110118a5..4843df5444 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -28,7 +28,7 @@ class ResampleRecording(BasePreprocessor): The recording extractor to be re-referenced resample_rate : int The resampling frequency - margin : float, default: 100.0 + margin_ms : float, default: 100.0 Margin in ms for computations, will be used to decrease edge effects. dtype : dtype or None, default: None The dtype of the returned traces. If None, the dtype of the parent recording is used. diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 5f70bfbb40..74d370b3a9 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -25,7 +25,9 @@ class SilencedPeriodsRecording(BasePreprocessor): One list per segment of tuples (start_frame, end_frame) to silence noise_levels : array Noise levels if already computed - + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. mode : "zeros" | "noise, default: "zeros" Determines what periods are replaced by. Can be one of the following: diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index 5c843e7c0b..0dd75fd476 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -8,14 +8,7 @@ import spikeinterface.extractors as se from spikeinterface.core import generate_recording import spikeinterface.widgets as sw - -try: - import spikeglx - import neurodsp.voltage as voltage - - HAVE_IBL_NPIX = True -except ImportError: - HAVE_IBL_NPIX = False +import importlib.util ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -31,7 +24,10 @@ # ---------------------------------------------------------------------------------------------------------------------- -@pytest.mark.skipif(not HAVE_IBL_NPIX or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install") +@pytest.mark.skipif( + importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") or ON_GITHUB, + reason="Only local. Requires ibl-neuropixel install", +) @pytest.mark.parametrize("lagc", [False, 1, 300]) def test_highpass_spatial_filter_real_data(lagc): """ @@ -56,6 +52,9 @@ def test_highpass_spatial_filter_real_data(lagc): use DEBUG = true to visualise. """ + import spikeglx + import neurodsp.voltage as voltage + options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None) print(options) @@ -146,6 +145,8 @@ def get_ibl_si_data(): """ Set fixture to session to ensure origional data is not changed. """ + import spikeglx + local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") ibl_recording = spikeglx.Reader( local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index ad073e40aa..1189f04f7d 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -6,17 +6,10 @@ import spikeinterface.preprocessing as spre import spikeinterface.extractors as se from spikeinterface.core.generate import generate_recording +import importlib.util -try: - import spikeglx - import neurodsp.voltage as voltage - - HAVE_IBL_NPIX = True -except ImportError: - HAVE_IBL_NPIX = False ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) - DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -30,7 +23,10 @@ # ------------------------------------------------------------------------------- -@pytest.mark.skipif(not HAVE_IBL_NPIX or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install") +@pytest.mark.skipif( + importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") or ON_GITHUB, + reason="Only local. Requires ibl-neuropixel install", +) def test_compare_real_data_with_ibl(): """ Test SI implementation of bad channel interpolation against native IBL. @@ -43,6 +39,9 @@ def test_compare_real_data_with_ibl(): si_scaled_recordin.get_traces(0) is also close to 1e-2. """ # Download and load data + import spikeglx + import neurodsp.voltage as voltage + local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") ibl_recording = spikeglx.Reader( @@ -80,7 +79,10 @@ def test_compare_real_data_with_ibl(): assert np.mean(is_close) > 0.999 -@pytest.mark.skipif(not HAVE_IBL_NPIX, reason="Requires ibl-neuropixel install") +@pytest.mark.skipif( + importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") is not None, + reason="Requires ibl-neuropixel install", +) @pytest.mark.parametrize("num_channels", [32, 64]) @pytest.mark.parametrize("sigma_um", [1.25, 40]) @pytest.mark.parametrize("p", [0, -0.5, 1, 5]) @@ -90,6 +92,8 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan Perform an extended test across a range of function inputs to check IBL and SI interpolation results match. """ + import neurodsp.voltage as voltage + recording = generate_recording(num_channels=num_channels, durations=[1]) # distribute default probe locations across 4 shanks if set diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index e79fda1ad8..a298b41d8f 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -1,14 +1,11 @@ -import pytest -from pathlib import Path - import shutil +from pathlib import Path +import numpy as np +import pytest from spikeinterface.core import generate_recording - from spikeinterface.preprocessing import correct_motion, load_motion_info -import numpy as np - def test_estimate_and_correct_motion(create_cache_folder): cache_folder = create_cache_folder @@ -18,6 +15,7 @@ def test_estimate_and_correct_motion(create_cache_folder): folder = cache_folder / "estimate_and_correct_motion" if folder.exists(): shutil.rmtree(folder) + rec_corrected = correct_motion(rec, folder=folder) print(rec_corrected) @@ -26,5 +24,5 @@ def test_estimate_and_correct_motion(create_cache_folder): if __name__ == "__main__": - print(correct_motion.__doc__) - # test_estimate_and_correct_motion() + # print(correct_motion.__doc__) + test_estimate_and_correct_motion() diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 2f87065d9f..8c52626703 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -183,7 +183,7 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose): # custom check params params = cls._check_params(recording, output_folder, params) # common check : filter warning - if recording.is_filtered and cls._check_apply_filter_in_params(params) and verbose: + if recording.is_filtered() and cls._check_apply_filter_in_params(params) and verbose: print(f"Warning! The recording is already filtered, but {cls.sorter_name} filter is enabled") # dump parameters inside the folder with json diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 6fa68de190..cf6933c9e6 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -120,7 +120,6 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - from mountainsort5.util import create_cached_recording recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) if recording is None: diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index f2c385b718..b5df0f1059 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -1,9 +1,7 @@ from __future__ import annotations -from operator import is_ from .si_based import ComponentsBasedSorter -import os import shutil import numpy as np @@ -11,7 +9,6 @@ from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.template import Templates -from spikeinterface.core.template_tools import get_template_extremum_amplitude from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion from spikeinterface.sortingcomponents.tools import cache_preprocessing @@ -316,7 +313,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.preprocessing.motion import load_motion_info motion_info = load_motion_info(motion_folder) - merging_params["maximum_distance_um"] = max(50, 2 * np.abs(motion_info["motion"]).max()) + motion = motion_info["motion"] + max_motion = max( + np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement)) + ) + merging_params["maximum_distance_um"] = max(50, 2 * max_motion) # peak_sign = params['detection'].get('peak_sign', 'neg') # best_amplitudes = get_template_extremum_amplitude(templates, peak_sign=peak_sign) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 3212f95e7f..55ef21de9d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -1,20 +1,22 @@ from __future__ import annotations -import time +import json from pathlib import Path +import pickle +import time import numpy as np from spikeinterface.core import get_noise_levels +from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis +from spikeinterface.sortingcomponents.motion_estimation import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis - - from spikeinterface.widgets import plot_probe_map +from spikeinterface.sortingcomponents.motion_utils import Motion + # import MEArec as mr # TODO : plot_peaks @@ -28,8 +30,8 @@ def get_gt_motion_from_unit_displacement( unit_displacements, displacement_sampling_frequency, unit_locations, - temporal_bins, - spatial_bins, + temporal_bins_s, + spatial_bins_um, direction_dim=1, ): import scipy.interpolate @@ -37,20 +39,24 @@ def get_gt_motion_from_unit_displacement( unit_displacements = unit_displacements[:, :, direction_dim] times = np.arange(unit_displacements.shape[0]) / displacement_sampling_frequency f = scipy.interpolate.interp1d(times, unit_displacements, axis=0) - unit_displacements = f(temporal_bins) + unit_displacements = f(temporal_bins_s.clip(times[0], times[-1])) # spatial interpolataion of units discplacement - if spatial_bins.shape[0] == 1: + if spatial_bins_um.shape[0] == 1: # rigid - gt_motion = np.mean(unit_displacements, axis=1)[:, None] + gt_displacement = np.mean(unit_displacements, axis=1)[:, None] else: # non rigid - gt_motion = np.zeros((temporal_bins.size, spatial_bins.size)) - for t in range(temporal_bins.shape[0]): + gt_displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size)) + for t in range(temporal_bins_s.shape[0]): f = scipy.interpolate.interp1d( unit_locations[:, direction_dim], unit_displacements[t, :], fill_value="extrapolate" ) - gt_motion[t, :] = f(spatial_bins) + gt_displacement[t, :] = f(spatial_bins_um) + + gt_motion = Motion( + gt_displacement, temporal_bins_s, spatial_bins_um, direction="xyz"[direction_dim], interpolation_method="linear" + ) return gt_motion @@ -92,9 +98,7 @@ def run(self, **job_kwargs): t2 = time.perf_counter() peak_locations = localize_peaks(self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs) t3 = time.perf_counter() - motion, temporal_bins, spatial_bins = estimate_motion( - self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"] - ) + motion = estimate_motion(self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"]) t4 = time.perf_counter() step_run_times = dict( @@ -106,43 +110,37 @@ def run(self, **job_kwargs): self.result["step_run_times"] = step_run_times self.result["raw_motion"] = motion - self.result["temporal_bins"] = temporal_bins - self.result["spatial_bins"] = spatial_bins def compute_result(self, **result_params): raw_motion = self.result["raw_motion"] - temporal_bins = self.result["temporal_bins"] - spatial_bins = self.result["spatial_bins"] gt_motion = get_gt_motion_from_unit_displacement( self.unit_displacements, self.displacement_sampling_frequency, self.unit_locations, - temporal_bins, - spatial_bins, + raw_motion.temporal_bins_s[0], + raw_motion.spatial_bins_um, direction_dim=self.direction_dim, ) # align globally gt_motion and motion to avoid offsets motion = raw_motion.copy() - motion += np.median(gt_motion - motion) + motion.displacement[0] += np.median(gt_motion.displacement[0] - motion.displacement[0]) self.result["gt_motion"] = gt_motion self.result["motion"] = motion _run_key_saved = [ - ("raw_motion", "npy"), - ("temporal_bins", "npy"), - ("spatial_bins", "npy"), + ("raw_motion", "Motion"), ("step_run_times", "pickle"), ] _result_key_saved = [ ( "gt_motion", - "npy", + "Motion", ), ( "motion", - "npy", + "Motion", ), ] @@ -189,20 +187,20 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p # dirft ax = ax1 = fig.add_subplot(gs[2:7]) ax1.sharey(ax0) - temporal_bins = bench.result["temporal_bins"] - spatial_bins = bench.result["spatial_bins"] + # temporal_bins_s = bench.result["temporal_bins_s"] + # spatial_bins_um = bench.result["spatial_bins_um"] gt_motion = bench.result["gt_motion"] motion = bench.result["motion"] # for i in range(self.gt_unit_positions.shape[1]): - # ax.plot(temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + # ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") - for i in range(gt_motion.shape[1]): - depth = spatial_bins[i] + for i in range(gt_motion.displacement[0].shape[1]): + depth = motion.spatial_bins_um[i] if gt_drift: - ax.plot(temporal_bins, gt_motion[:, i] + depth, color="green", lw=4) + ax.plot(motion.temporal_bins_s[0], gt_motion.displacement[0][:, i] + depth, color="green", lw=4) if tested_drift: - ax.plot(temporal_bins, motion[:, i] + depth, color="cyan", lw=2) + ax.plot(motion.temporal_bins_s[0], motion.displacement[0][:, i] + depth, color="cyan", lw=2) ax.set_xlabel("time (s)") _simpleaxis(ax) @@ -241,14 +239,14 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): gt_motion = bench.result["gt_motion"] motion = bench.result["motion"] - temporal_bins = bench.result["temporal_bins"] - spatial_bins = bench.result["spatial_bins"] + # temporal_bins_s = bench.result["temporal_bins_s"] + # spatial_bins_um = bench.result["spatial_bins_um"] fig = plt.figure(figsize=figsize) gs = fig.add_gridspec(2, 2) - errors = gt_motion - motion + errors = gt_motion.displacement[0] - motion.displacement[0] channel_positions = bench.recording.get_channel_locations() probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() @@ -259,7 +257,12 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): aspect="auto", interpolation="nearest", origin="lower", - extent=(temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1]), + extent=( + motion.temporal_bins_s[0][0], + motion.temporal_bins_s[0][-1], + motion.spatial_bins_um[0], + motion.spatial_bins_um[-1], + ), ) plt.colorbar(im, ax=ax, label="error") ax.set_ylabel("depth (um)") @@ -270,7 +273,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax = fig.add_subplot(gs[1, 0]) mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) - ax.plot(temporal_bins, mean_error) + ax.plot(motion.temporal_bins_s[0], mean_error) ax.set_xlabel("time (s)") ax.set_ylabel("error") _simpleaxis(ax) @@ -279,7 +282,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax = fig.add_subplot(gs[1, 1]) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - ax.plot(spatial_bins, depth_error) + ax.plot(motion.spatial_bins_um, depth_error) ax.axvline(probe_y_min, color="k", ls="--", alpha=0.5) ax.axvline(probe_y_max, color="k", ls="--", alpha=0.5) ax.set_xlabel("depth (um)") @@ -289,6 +292,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax.set_ylim(0, lim) def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)): + import matplotlib.pyplot as plt if case_keys is None: case_keys = list(self.cases.keys()) @@ -304,17 +308,17 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) gt_motion = bench.result["gt_motion"] motion = bench.result["motion"] - temporal_bins = bench.result["temporal_bins"] - spatial_bins = bench.result["spatial_bins"] + # temporal_bins_s = bench.result["temporal_bins_s"] + # spatial_bins_um = bench.result["spatial_bins_um"] # c = colors[count] if colors is not None else None c = colors[key] - errors = gt_motion - motion + errors = gt_motion.displacement[0] - motion.displacement[0] mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - axes[0].plot(temporal_bins, mean_error, lw=1, label=label, color=c) + axes[0].plot(motion.temporal_bins_s[0], mean_error, lw=1, label=label, color=c) parts = axes[1].violinplot(mean_error, [count], showmeans=True) if c is not None: for pc in parts["bodies"]: @@ -324,7 +328,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) if k != "bodies": # for line in parts[k]: parts[k].set_color(c) - axes[2].plot(spatial_bins, depth_error, label=label, color=c) + axes[2].plot(motion.spatial_bins_um, depth_error, label=label, color=c) ax0 = ax = axes[0] ax.set_xlabel("Time [s]") @@ -361,8 +365,8 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # "peaks", # "selected_peaks", # "motion", -# "temporal_bins", -# "spatial_bins", +# "temporal_bins_s", +# "spatial_bins_um", # "peak_locations", # "gt_motion", # ) @@ -438,7 +442,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # self.recording, self.selected_peaks, **self.localize_kwargs, **self.job_kwargs # ) # t3 = time.perf_counter() -# self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( +# self.motion, self.temporal_bins_s, self.spatial_bins_um = estimate_motion( # self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs # ) @@ -463,7 +467,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # def run_estimate_motion(self): # # usefull to re run only the motion estimate with peak localization # t3 = time.perf_counter() -# self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( +# self.motion, self.temporal_bins_s, self.spatial_bins_um = estimate_motion( # self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs # ) # t4 = time.perf_counter() @@ -479,7 +483,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # self.save_to_folder() # def compute_gt_motion(self): -# self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins) +# self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins_s) # template_locations = np.array(mr.load_recordings(self.mearec_filename).template_locations) # assert len(template_locations.shape) == 3 @@ -489,18 +493,18 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # unit_motions = self.gt_unit_positions - unit_mid_positions # # unit_positions = np.mean(self.gt_unit_positions, axis=0) -# if self.spatial_bins is None: +# if self.spatial_bins_um is None: # self.gt_motion = np.mean(unit_motions, axis=1)[:, None] # channel_positions = self.recording.get_channel_locations() # probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() # center = (probe_y_min + probe_y_max) // 2 -# self.spatial_bins = np.array([center]) +# self.spatial_bins_um = np.array([center]) # else: # # time, units # self.gt_motion = np.zeros_like(self.motion) # for t in range(self.gt_unit_positions.shape[0]): # f = scipy.interpolate.interp1d(unit_mid_positions, unit_motions[t, :], fill_value="extrapolate") -# self.gt_motion[t, :] = f(self.spatial_bins) +# self.gt_motion[t, :] = f(self.spatial_bins_um) # def plot_true_drift(self, scaling_probe=1.5, figsize=(15, 10), axes=None): # if axes is None: @@ -534,11 +538,11 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax = axes[1] # for i in range(self.gt_unit_positions.shape[1]): -# ax.plot(self.temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") +# ax.plot(self.temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") # for i in range(self.gt_motion.shape[1]): -# depth = self.spatial_bins[i] -# ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=4) +# depth = self.spatial_bins_um[i] +# ax.plot(self.temporal_bins_s, self.gt_motion[:, i] + depth, color="green", lw=4) # # ax.set_ylim(ymin, ymax) # ax.set_xlabel("time (s)") @@ -617,15 +621,15 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) # if show_drift: -# if self.spatial_bins is None: +# if self.spatial_bins_um is None: # center = (probe_y_min + probe_y_max) // 2 -# ax.plot(self.temporal_bins, self.gt_motion[:, 0] + center, color="green", lw=1.5) -# ax.plot(self.temporal_bins, self.motion[:, 0] + center, color="orange", lw=1.5) +# ax.plot(self.temporal_bins_s, self.gt_motion[:, 0] + center, color="green", lw=1.5) +# ax.plot(self.temporal_bins_s, self.motion[:, 0] + center, color="orange", lw=1.5) # else: # for i in range(self.gt_motion.shape[1]): -# depth = self.spatial_bins[i] -# ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=1.5) -# ax.plot(self.temporal_bins, self.motion[:, i] + depth, color="orange", lw=1.5) +# depth = self.spatial_bins_um[i] +# ax.plot(self.temporal_bins_s, self.gt_motion[:, i] + depth, color="green", lw=1.5) +# ax.plot(self.temporal_bins_s, self.motion[:, i] + depth, color="orange", lw=1.5) # if show_histogram: # ax2 = fig.add_subplot(gs[3]) @@ -669,10 +673,9 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # peak_locations_corrected = correct_motion_on_peaks( # self.selected_peaks, # self.peak_locations, -# self.recording.sampling_frequency, # self.motion, -# self.temporal_bins, -# self.spatial_bins, +# self.temporal_bins_s, +# self.spatial_bins_um, # direction="y", # ) # if axes is None: @@ -734,18 +737,18 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # colors = plt.colormaps["jet"].resampled(n) # for i in range(0, n, step): # ax = axs[0] -# ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) +# ax.plot(self.temporal_bins_s, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) # ax.plot( -# self.temporal_bins, +# self.temporal_bins_s, # self.motion[:, i], # lw=1.5, # ls="-", # color=colors(i), -# label=f"{self.spatial_bins[i]:0.1f}", +# label=f"{self.spatial_bins_um[i]:0.1f}", # ) # ax = axs[1] -# ax.plot(self.temporal_bins, self.motion[:, i] - self.gt_motion[:, i], lw=1.5, ls="-", color=colors(i)) +# ax.plot(self.temporal_bins_s, self.motion[:, i] - self.gt_motion[:, i], lw=1.5, ls="-", color=colors(i)) # ax = axs[0] # ax.set_title(self.title) @@ -774,7 +777,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # aspect="auto", # interpolation="nearest", # origin="lower", -# extent=(self.temporal_bins[0], self.temporal_bins[-1], self.spatial_bins[0], self.spatial_bins[-1]), +# extent=(self.temporal_bins_s[0], self.temporal_bins_s[-1], self.spatial_bins_um[0], self.spatial_bins_um[-1]), # ) # plt.colorbar(im, ax=ax, label="error") # ax.set_ylabel("depth (um)") @@ -785,7 +788,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax = fig.add_subplot(gs[1, 0]) # mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) -# ax.plot(self.temporal_bins, mean_error) +# ax.plot(self.temporal_bins_s, mean_error) # ax.set_xlabel("time (s)") # ax.set_ylabel("error") # _simpleaxis(ax) @@ -794,7 +797,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax = fig.add_subplot(gs[1, 1]) # depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) -# ax.plot(self.spatial_bins, depth_error) +# ax.plot(self.spatial_bins_um, depth_error) # ax.axvline(probe_y_min, color="k", ls="--", alpha=0.5) # ax.axvline(probe_y_max, color="k", ls="--", alpha=0.5) # ax.set_xlabel("depth (um)") @@ -816,7 +819,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) # depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) -# axes[0].plot(benchmark.temporal_bins, mean_error, lw=1, label=benchmark.title, color=c) +# axes[0].plot(benchmark.temporal_bins_s, mean_error, lw=1, label=benchmark.title, color=c) # parts = axes[1].violinplot(mean_error, [count], showmeans=True) # if c is not None: # for pc in parts["bodies"]: @@ -826,7 +829,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # if k != "bodies": # # for line in parts[k]: # parts[k].set_color(c) -# axes[2].plot(benchmark.spatial_bins, depth_error, label=benchmark.title, color=c) +# axes[2].plot(benchmark.spatial_bins_um, depth_error, label=benchmark.title, color=c) # ax0 = ax = axes[0] # ax.set_xlabel("Time [s]") @@ -875,10 +878,10 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # interpolation="nearest", # origin="lower", # extent=( -# benchmark.temporal_bins[0], -# benchmark.temporal_bins[-1], -# benchmark.spatial_bins[0], -# benchmark.spatial_bins[-1], +# benchmark.temporal_bins_s[0], +# benchmark.temporal_bins_s[-1], +# benchmark.spatial_bins_um[0], +# benchmark.spatial_bins_um[-1], # ), # ) # fig.colorbar(im, ax=ax, label="error") @@ -896,11 +899,11 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # def plot_motions_several_benchmarks(benchmarks): # fig, ax = plt.subplots(figsize=(15, 5)) -# ax.plot(list(benchmarks)[0].temporal_bins, list(benchmarks)[0].gt_motion[:, 0], lw=2, c="k", label="real motion") +# ax.plot(list(benchmarks)[0].temporal_bins_s, list(benchmarks)[0].gt_motion[:, 0], lw=2, c="k", label="real motion") # for count, benchmark in enumerate(benchmarks): -# ax.plot(benchmark.temporal_bins, benchmark.motion.mean(1), lw=1, c=f"C{count}", label=benchmark.title) +# ax.plot(benchmark.temporal_bins_s, benchmark.motion.mean(1), lw=1, c=f"C{count}", label=benchmark.title) # ax.fill_between( -# benchmark.temporal_bins, +# benchmark.temporal_bins_s, # benchmark.motion.mean(1) - benchmark.motion.std(1), # benchmark.motion.mean(1) + benchmark.motion.std(1), # color=f"C{count}", diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index a515424648..a6ff05fc55 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -44,9 +44,7 @@ def run(self, **job_kwargs): recording = self.drifting_recording elif self.params["recording_source"] == "corrected": correct_motion_kwargs = self.params["correct_motion_kwargs"] - recording = InterpolateMotionRecording( - self.drifting_recording, self.motion, self.temporal_bins, self.spatial_bins, **correct_motion_kwargs - ) + recording = InterpolateMotionRecording(self.drifting_recording, self.motion, **correct_motion_kwargs) else: raise ValueError("recording_source") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index b2cf56eb9c..e9f128993d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -406,6 +406,8 @@ def _save_keys(self, saved_keys, folder): pickle.dump(self.result[k], f) elif format == "sorting": self.result[k].save(folder=folder / k, format="numpy_folder", overwrite=True) + elif format == "Motion": + self.result[k].save(folder=folder / k) elif format == "zarr_templates": self.result[k].to_zarr(folder / k) elif format == "sorting_analyzer": @@ -440,6 +442,10 @@ def load_folder(cls, folder): from spikeinterface.core import load_extractor result[k] = load_extractor(folder / k) + elif format == "Motion": + from spikeinterface.sortingcomponents.motion_utils import Motion + + result[k] = Motion.load(folder / k) elif format == "zarr_templates": from spikeinterface.core.template import Templates diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index 696531b221..526cc2e92f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -65,8 +65,10 @@ def test_benchmark_motion_estimaton(create_cache_folder): # plots study.plot_true_drift() + study.plot_drift() study.plot_errors() study.plot_summary_errors() + import matplotlib.pyplot as plt plt.show() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index 4b7264a9de..6d80d027f2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -48,8 +48,11 @@ def test_benchmark_motion_interpolation(create_cache_folder): spatial_bins, direction_dim=1, ) + # print(gt_motion) + + # import matplotlib.pyplot as plt # fig, ax = plt.subplots() - # ax.imshow(gt_motion.T) + # ax.imshow(gt_motion.displacement[0].T) # plt.show() cases = {} @@ -130,6 +133,8 @@ def test_benchmark_motion_interpolation(create_cache_folder): study.plot_sorting_accuracy(mode="depth", mode_best_merge=False) study.plot_sorting_accuracy(mode="depth", mode_best_merge=True) + import matplotlib.pyplot as plt + plt.show() diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index c172e90fd8..0dc71d789b 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -4,7 +4,7 @@ import numpy as np -from spikeinterface.core import get_noise_levels, get_channel_distances, get_chunk_with_margin, get_random_data_chunks +from spikeinterface.core import get_noise_levels, get_channel_distances, get_random_data_chunks from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive from spikeinterface.core.template import Templates diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 9eb5415316..3134d68681 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -1,7 +1,11 @@ from __future__ import annotations -import numpy as np from tqdm.auto import tqdm, trange +import numpy as np + + +from .motion_utils import Motion +from .tools import make_multi_method_doc try: import torch @@ -11,8 +15,6 @@ except ImportError: HAVE_TORCH = False -from .tools import make_multi_method_doc - def estimate_motion( recording, @@ -55,7 +57,7 @@ def estimate_motion( **histogram section** direction: "x" | "y" | "z", default: "y" - Dimension on which the motion is estimated + Dimension on which the motion is estimated. "y" is depth along the probe. bin_duration_s: float, default: 10 Bin duration in second bin_um: float, default: 10 @@ -105,19 +107,8 @@ def estimate_motion( Returns ------- - motion: numpy array 2d - Motion estimate in um. - Shape (temporal bins, spatial bins) - motion.shape[0] = temporal_bins.shape[0] - motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) - If upsample_to_histogram_bin, motion.shape[1] corresponds to spatial - bins given by bin_um. - temporal_bins: numpy.array 1d - temporal bins (bin center) - spatial_bins: numpy.array 1d - Windows center. - spatial_bins.shape[0] == motion.shape[1] - If rigid then spatial_bins.shape[0] == 1 + motion: Motion object + The motion object. extra_check: dict Optional output if `output_extra_check=True` This dict contain histogram, pairwise_displacement usefull for ploting. @@ -148,7 +139,7 @@ def estimate_motion( # run method method_class = estimate_motion_methods[method] - motion, temporal_bins = method_class.run( + motion_array, temporal_bins = method_class.run( recording, peaks, peak_locations, @@ -164,27 +155,30 @@ def estimate_motion( ) # replace nan by zeros - motion[np.isnan(motion)] = 0 + np.nan_to_num(motion_array, copy=False) if post_clean: - motion = clean_motion_vector( - motion, temporal_bins, bin_duration_s, speed_threshold=speed_threshold, sigma_smooth_s=sigma_smooth_s + motion_array = clean_motion_vector( + motion_array, temporal_bins, bin_duration_s, speed_threshold=speed_threshold, sigma_smooth_s=sigma_smooth_s ) if upsample_to_histogram_bin is None: upsample_to_histogram_bin = not rigid if upsample_to_histogram_bin: - extra_check["motion"] = motion + extra_check["motion_array"] = motion_array extra_check["non_rigid_window_centers"] = non_rigid_window_centers non_rigid_windows = np.array(non_rigid_windows) non_rigid_windows /= non_rigid_windows.sum(axis=0, keepdims=True) non_rigid_window_centers = spatial_bin_edges[:-1] + bin_um / 2 - motion = motion @ non_rigid_windows + motion_array = motion_array @ non_rigid_windows + + # TODO handle multi segment + motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) if output_extra_check: - return motion, temporal_bins, non_rigid_window_centers, extra_check + return motion, extra_check else: - return motion, temporal_bins, non_rigid_window_centers + return motion class DecentralizedRegistration: @@ -342,7 +336,7 @@ def run( extra_check["spatial_hist_bin_edges"] = spatial_hist_bin_edges # temporal bins are bin center - temporal_bins = temporal_hist_bin_edges[:-1] + bin_duration_s // 2.0 + temporal_bins = 0.5 * (temporal_hist_bin_edges[1:] + temporal_hist_bin_edges[:-1]) motion = np.zeros((temporal_bins.size, len(non_rigid_windows)), dtype=np.float64) windows_iter = non_rigid_windows @@ -690,16 +684,15 @@ def make_2d_motion_histogram( spatial_bin_edges 1d array with spatial bin edges """ - fs = recording.get_sampling_frequency() - num_samples = recording.get_num_samples(segment_index=0) - bin_sample_size = int(bin_duration_s * fs) - sample_bin_edges = np.arange(0, num_samples + bin_sample_size, bin_sample_size) - temporal_bin_edges = sample_bin_edges / fs + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s) if spatial_bin_edges is None: spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) arr = np.zeros((peaks.size, 2), dtype="float64") - arr[:, 0] = peaks["sample_index"] + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) arr[:, 1] = peak_locations[direction] if weight_with_amplitude: @@ -707,11 +700,11 @@ def make_2d_motion_histogram( else: weights = None - motion_histogram, edges = np.histogramdd(arr, bins=(sample_bin_edges, spatial_bin_edges), weights=weights) + motion_histogram, edges = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges), weights=weights) # average amplitude in each bin if weight_with_amplitude: - bin_counts, _ = np.histogramdd(arr, bins=(sample_bin_edges, spatial_bin_edges)) + bin_counts, _ = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges)) bin_counts[bin_counts == 0] = 1 motion_histogram = motion_histogram / bin_counts @@ -766,11 +759,10 @@ def make_3d_motion_histograms( spatial_bin_edges 1d array with spatial bin edges """ - fs = recording.get_sampling_frequency() - num_samples = recording.get_num_samples(segment_index=0) - bin_sample_size = int(bin_duration_s * fs) - sample_bin_edges = np.arange(0, num_samples + bin_sample_size, bin_sample_size) - temporal_bin_edges = sample_bin_edges / fs + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s) if spatial_bin_edges is None: spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) @@ -785,14 +777,14 @@ def make_3d_motion_histograms( ) arr = np.zeros((peaks.size, 3), dtype="float64") - arr[:, 0] = peaks["sample_index"] + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) arr[:, 1] = peak_locations[direction] arr[:, 2] = abs_peaks_log_norm motion_histograms, edges = np.histogramdd( arr, bins=( - sample_bin_edges, + temporal_bin_edges, spatial_bin_edges, amplitude_bin_edges, ), @@ -825,7 +817,6 @@ def compute_pairwise_displacement( """ Compute pairwise displacement """ - from scipy import sparse from scipy import linalg assert conv_engine in ("torch", "numpy"), f"'conv_engine' must be 'torch' or 'numpy'" diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 5e3733b363..32bb7634e9 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -1,40 +1,26 @@ from __future__ import annotations import numpy as np - - from spikeinterface.core.core_tools import define_function_from_class -from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.preprocessing import get_spatial_interpolation_kernel +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from spikeinterface.preprocessing.filter import fix_dtype -def correct_motion_on_peaks( - peaks, - peak_locations, - sampling_frequency, - motion, - temporal_bins, - spatial_bins, - direction="y", -): +def correct_motion_on_peaks(peaks, peak_locations, motion, recording): """ Given the output of estimate_motion(), apply inverse motion on peak locations. Parameters ---------- - peaks: np.array + peaks : np.array peaks vector - peak_locations: np.array + peak_locations : np.array peaks location vector - sampling_frequency: np.array - sampling_frequency of the recording - motion: np.array 2D - motion.shape[0] equal temporal_bins.shape[0] - motion.shape[1] equal 1 when "rigid" motion equal temporal_bins.shape[0] when "non-rigid" - temporal_bins: np.array - Temporal bins in second. - spatial_bins: np.array - Bins for non-rigid motion. If spatial_bins.sahpe[0] == 1 then rigid motion is used. + motion : Motion + The motion object. + recording : Recording + The recording object. This is used to convert sample indices to times. Returns ------- @@ -42,21 +28,17 @@ def correct_motion_on_peaks( Motion-corrected peak locations """ corrected_peak_locations = peak_locations.copy() - import scipy.interpolate - - spike_times = peaks["sample_index"] / sampling_frequency - if spatial_bins.shape[0] == 1: - # rigid motion interpolation 1D - f = scipy.interpolate.interp1d(temporal_bins, motion[:, 0], bounds_error=False, fill_value="extrapolate") - shift = f(spike_times) - corrected_peak_locations[direction] -= shift - else: - # non rigid motion = interpolation 2D - f = scipy.interpolate.RegularGridInterpolator( - (temporal_bins, spatial_bins), motion, method="linear", bounds_error=False, fill_value=None + + for segment_index in range(motion.num_segments): + times_s = recording.sample_index_to_time(peaks["sample_index"], segment_index=segment_index) + i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) + + spike_times = times_s[i0:i1] + spike_locs = peak_locations[motion.direction][i0:i1] + spike_displacement = motion.get_displacement_at_time_and_depth( + spike_times, spike_locs, segment_index=segment_index ) - shift = f(np.c_[spike_times, peak_locations[direction]]) - corrected_peak_locations[direction] -= shift + corrected_peak_locations[i0:i1][motion.direction] -= spike_displacement return corrected_peak_locations @@ -66,12 +48,12 @@ def interpolate_motion_on_traces( times, channel_locations, motion, - temporal_bins, - spatial_bins, - direction=1, + segment_index=None, channel_inds=None, + interpolation_time_bin_centers_s=None, spatial_interpolation_method="kriging", spatial_interpolation_kwargs={}, + dtype=None, ): """ Apply inverse motion with spatial interpolation on traces. @@ -82,20 +64,19 @@ def interpolate_motion_on_traces( ---------- traces : np.array Trace snippet (num_samples, num_channels) + times : np.array + Sample times in seconds for the frames of the traces snippet channel_location: np.array 2d Channel location with shape (n, 2) or (n, 3) - motion: np.array 2D - motion.shape[0] equal temporal_bins.shape[0] - motion.shape[1] equal 1 when "rigid" motion - equal temporal_bins.shape[0] when "none rigid" - temporal_bins: np.array - Temporal bins in second. - spatial_bins: None or np.array - Bins for non-rigid motion. If None, rigid motion is used - direction: int in (0, 1, 2) - Dimension of shift in channel_locations. + motion: Motion + The motion object. + segment_index: int or None + The segment index. channel_inds: None or list If not None, interpolate only a subset of channels. + interpolation_time_bin_centers_s : None or np.array + Manually specify the time bins which the interpolation happens + in for this segment. If None, these are the motion estimate's time bins. spatial_interpolation_method: "idw" | "kriging", default: "kriging" The spatial interpolation method used to interpolate the channel locations: * idw : Inverse Distance Weighing @@ -105,40 +86,62 @@ def interpolate_motion_on_traces( Returns ------- - channel_motions: np.array - Shift over time by channel - Shape (times.shape[0], channel_location.shape[0]) + traces_corrected: np.array + Motion-corrected trace snippet, (num_samples, num_channels) """ # assert HAVE_NUMBA assert times.shape[0] == traces.shape[0] + if dtype is None: + dtype = traces.dtype + if dtype.kind != "f": + raise ValueError(f"Can't interpolate_motion with dtype {dtype}.") + if traces.dtype != dtype: + traces = traces.astype(dtype) + + if segment_index is None: + if motion.num_segments == 1: + segment_index = 0 + else: + raise ValueError("Several segment need segment_index=") + if channel_inds is None: traces_corrected = np.zeros(traces.shape, dtype=traces.dtype) else: channel_inds = np.asarray(channel_inds) traces_corrected = np.zeros((traces.shape[0], channel_inds.size), dtype=traces.dtype) - # regroup times by closet temporal_bins - bin_inds = _get_closest_ind(temporal_bins, times) + total_num_chans = channel_locations.shape[0] - # inperpolation kernel will be the same per temporal bin - for bin_ind in np.unique(bin_inds): - # Step 1 : channel motion - if spatial_bins.shape[0] == 1: - # rigid motion : same motion for all channels - channel_motions = motion[bin_ind, 0] - else: - # non rigid : interpolation channel motion for this temporal bin - import scipy.interpolate + # -- determine the blocks of frames that will land in the same interpolation time bin + time_bins = interpolation_time_bin_centers_s + if time_bins is None: + time_bins = motion.temporal_bins_s[segment_index] + bin_s = time_bins[1] - time_bins[0] + bins_start = time_bins[0] - 0.5 * bin_s + # nearest bin center for each frame? + bin_inds = (times - bins_start) // bin_s + bin_inds = bin_inds.astype(int) + # the time bins may not cover the whole set of times in the recording, + # so we need to clip these indices to the valid range + np.clip(bin_inds, 0, time_bins.size, out=bin_inds) - f = scipy.interpolate.interp1d( - spatial_bins, motion[bin_ind, :], kind="linear", axis=0, bounds_error=False, fill_value="extrapolate" - ) - locs = channel_locations[:, direction] - channel_motions = f(locs) + # -- what are the possibilities here anyway? + bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) + + # inperpolation kernel will be the same per temporal bin + interp_times = np.empty(total_num_chans) + current_start_index = 0 + for bin_ind in bins_here: + bin_time = time_bins[bin_ind] + interp_times.fill(bin_time) + channel_motions = motion.get_displacement_at_time_and_depth( + interp_times, + channel_locations[:, motion.dim], + segment_index=segment_index, + ) channel_locations_moved = channel_locations.copy() - channel_locations_moved[:, direction] += channel_motions - # channel_locations_moved[:, direction] -= channel_motions + channel_locations_moved[:, motion.dim] += channel_motions if channel_inds is not None: channel_locations_moved = channel_locations_moved[channel_inds] @@ -146,24 +149,35 @@ def interpolate_motion_on_traces( drift_kernel = get_spatial_interpolation_kernel( channel_locations, channel_locations_moved, - dtype="float32", + dtype=dtype, method=spatial_interpolation_method, **spatial_interpolation_kwargs, ) - i0 = np.searchsorted(bin_inds, bin_ind, side="left") - i1 = np.searchsorted(bin_inds, bin_ind, side="right") + # keep this for DEBUG + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.matshow(drift_kernel) + # ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}") + # plt.show() + + # quickly find the end of this bin, which is also the start of the next + next_start_index = current_start_index + np.searchsorted( + bin_inds[current_start_index:], bin_ind + 1, side="left" + ) + in_bin = slice(current_start_index, next_start_index) # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) - traces_corrected[i0:i1] = traces[i0:i1] @ drift_kernel + np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin]) + current_start_index = next_start_index return traces_corrected # if HAVE_NUMBA: -# # @numba.jit(parallel=False) +# # @numba.jit(parallel=False) # @numba.jit(parallel=True) # def my_sparse_dot(data_in, data_out, sparse_chans, weights): # """ @@ -178,7 +192,7 @@ def interpolate_motion_on_traces( # num_samples = data_in.shape[0] # num_chan_out = data_out.shape[1] # num_sparse = sparse_chans.shape[1] -# # for sample_index in range(num_samples): +# # for sample_index in range(num_samples): # for sample_index in numba.prange(num_samples): # for out_chan in range(num_chan_out): # v = 0 @@ -205,24 +219,25 @@ def _get_closest_ind(array, values): class InterpolateMotionRecording(BasePreprocessor): """ - Recording that corrects motion on-the-fly given a motion vector estimation (rigid or non-rigid). - This internally applies a spatial interpolation on the original traces after reversing the motion. - `estimate_motion()` must be called before this to estimate the motion vector. + Interpolate the input recording's traces to correct for motion, according to the + motion estimate object `motion`. The interpolation is carried out "lazily" / on the fly + by applying a spatial interpolation on the original traces to estimate their values + at the positions of the probe's channels after being shifted inversely to the motion. + + To get a Motion object, use `interpolate_motion()`. + + By default, each frame is spatially interpolated by the motion at the nearest motion + estimation time bin -- in other words, the temporal resolution of the motion correction + is the same as the motion estimation's. However, this behavior can be changed by setting + `interpolation_time_bin_centers_s` or `interpolation_time_bin_size_s` below. In that case, + the motion estimate will be interpolated to match the interpolation time bins. Parameters ---------- recording: Recording The parent recording. - motion: np.array 2D - The motion signal obtained with `estimate_motion()` - motion.shape[0] must correspond to temporal_bins.shape[0] - motion.shape[1] is 1 when "rigid" motion and spatial_bins.shape[0] when "non-rigid" - temporal_bins: np.array - Temporal bins in second. - spatial_bins: None or np.array - Bins for non-rigid motion. If None, rigid motion is used - direction: 0 | 1 | 2, default: 1 - Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z) + motion: Motion + The motion object spatial_interpolation_method: "kriging" | "idw" | "nearest", default: "kriging" The spatial interpolation method used to interpolate the channel locations. See `spikeinterface.preprocessing.get_spatial_interpolation_kernel()` for more details. @@ -239,10 +254,22 @@ class InterpolateMotionRecording(BasePreprocessor): Number of closest channels used by "idw" method for interpolation. border_mode: "remove_channels" | "force_extrapolate" | "force_zeros", default: "remove_channels" Control how channels are handled on border: - * "remove_channels": remove channels on the border, the recording has less channels * "force_extrapolate": keep all channel and force extrapolation (can lead to strange signal) * "force_zeros": keep all channel but set zeros when outside (force_extrapolate=False) + interpolation_time_bin_centers_s: np.array or list of np.array, optional + Spatially interpolate each frame according to the displacement estimate at its closest + bin center in this array. If not supplied, this is set to the motion estimate's time bin + centers. If it's supplied, the motion estimate is interpolated to these bin centers. + If you have a multi-segment recording, pass a list of these, one per segment. + interpolation_time_bin_size_s: float, optional + Similar to the previous argument: interpolation_time_bin_centers_s will be constructed + by bins spaced by interpolation_time_bin_size_s. This is ignored if interpolation_time_bin_centers_s + is supplied. + dtype : str or np.dtype, optional + Interpolation needs to convert to a floating dtype. If dtype is supplied, that will be used. + If the input recording is already floating and dtype=None, then its dtype is used by default. + If the input recording is integer, then float32 is used by default. Returns ------- @@ -250,60 +277,52 @@ class InterpolateMotionRecording(BasePreprocessor): Recording after motion correction """ - name = "correct_motion" + name = "interpolate_motion" def __init__( self, recording, motion, - temporal_bins, - spatial_bins, - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=1, num_closest=3, + interpolation_time_bin_centers_s=None, + interpolation_time_bin_size_s=None, + dtype=None, + **spatial_interpolation_kwargs, ): - assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" - - # force as arrays - temporal_bins = np.asarray(temporal_bins) - motion = np.asarray(motion) - spatial_bins = np.asarray(spatial_bins) + # assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" channel_locations = recording.get_channel_locations() - assert channel_locations.ndim >= direction, ( - f"'direction' {direction} not available. " f"Channel locations have {channel_locations.ndim} dimensions." + assert channel_locations.ndim >= motion.dim, ( + f"'direction' {motion.direction} not available. " + f"Channel locations have {channel_locations.ndim} dimensions." + ) + spatial_interpolation_kwargs = dict( + sigma_um=sigma_um, p=p, num_closest=num_closest, **spatial_interpolation_kwargs ) - spatial_interpolation_kwargs = dict(sigma_um=sigma_um, p=p, num_closest=num_closest) if border_mode == "remove_channels": - locs = channel_locations[:, direction] - l0, l1 = np.min(channel_locations[:, direction]), np.max(channel_locations[:, direction]) + locs = channel_locations[:, motion.dim] + l0, l1 = np.min(locs), np.max(locs) - # compute max and min motion (with interpolation) - # and check if channels are inside + # check if channels stay inside the probe extents for all segments channel_inside = np.ones(locs.shape[0], dtype="bool") - for operator in (np.max, np.min): - if spatial_bins.shape[0] == 1: - best_motions = operator(motion[:, 0]) - else: - # non rigid : interpolation channel motion for this temporal bin - import scipy.spatial - import scipy.interpolate - - f = scipy.interpolate.interp1d( - spatial_bins, - operator(motion[:, :], axis=0), - kind="linear", - axis=0, - bounds_error=False, - fill_value="extrapolate", - ) - best_motions = f(locs) - channel_inside &= ((locs + best_motions) >= l0) & ((locs + best_motions) <= l1) - - (channel_inds,) = np.nonzero(channel_inside) + for segment_index in range(recording.get_num_segments()): + # evaluate the positions of all channels over all time bins + channel_displacements = motion.get_displacement_at_time_and_depth( + times_s=motion.temporal_bins_s[segment_index], + locations_um=locs, + grid=True, + ) + channel_locations_moved = locs[:, None] + channel_displacements + # check if these remain inside of the probe + seg_inside = channel_locations_moved.clip(l0, l1) == channel_locations_moved + seg_inside = seg_inside.all(axis=1) + channel_inside &= seg_inside + + channel_inds = np.flatnonzero(channel_inside) channel_ids = recording.channel_ids[channel_inds] spatial_interpolation_kwargs["force_extrapolate"] = False elif border_mode == "force_extrapolate": @@ -317,7 +336,14 @@ def __init__( else: raise ValueError("Wrong border_mode") - BasePreprocessor.__init__(self, recording, channel_ids=channel_ids) + if dtype is None: + if recording.dtype.kind == "f": + dtype = recording.dtype + else: + raise ValueError(f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.") + + dtype_ = fix_dtype(recording, dtype) + BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) if border_mode == "remove_channels": # change the wiring of the probe @@ -327,32 +353,48 @@ def __init__( contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") self.set_property("contact_vector", contact_vector) - for parent_segment in recording._recording_segments: + # handle manual interpolation_time_bin_centers_s + # the case where interpolation_time_bin_size_s is set is handled per-segment below + if interpolation_time_bin_centers_s is None: + if interpolation_time_bin_size_s is None: + interpolation_time_bin_centers_s = motion.temporal_bins_s + + for segment_index, parent_segment in enumerate(recording._recording_segments): + # finish the per-segment part of the time bin logic + if interpolation_time_bin_centers_s is None: + # in this case, interpolation_time_bin_size_s is set. + s_end = parent_segment.get_num_samples() + t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) + halfbin = interpolation_time_bin_size_s / 2.0 + segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) + else: + segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] + rec_segment = InterpolateMotionRecordingSegment( parent_segment, channel_locations, motion, - temporal_bins, - spatial_bins, - direction, spatial_interpolation_method, spatial_interpolation_kwargs, channel_inds, + segment_index, + segment_interpolation_time_bins_s, + dtype=dtype_, ) self.add_recording_segment(rec_segment) self._kwargs = dict( recording=recording, motion=motion, - temporal_bins=temporal_bins, - spatial_bins=spatial_bins, - direction=direction, border_mode=border_mode, spatial_interpolation_method=spatial_interpolation_method, sigma_um=sigma_um, p=p, num_closest=num_closest, + interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, + dtype=dtype_.str, ) + self._kwargs.update(spatial_interpolation_kwargs) class InterpolateMotionRecordingSegment(BasePreprocessorSegment): @@ -361,61 +403,51 @@ def __init__( parent_recording_segment, channel_locations, motion, - temporal_bins, - spatial_bins, - direction, spatial_interpolation_method, spatial_interpolation_kwargs, channel_inds, + segment_index, + interpolation_time_bin_centers_s, + dtype="float32", ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.channel_locations = channel_locations - self.motion = motion - self.temporal_bins = temporal_bins - self.spatial_bins = spatial_bins - self.direction = direction self.spatial_interpolation_method = spatial_interpolation_method self.spatial_interpolation_kwargs = spatial_interpolation_kwargs self.channel_inds = channel_inds + self.segment_index = segment_index + self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s + self.dtype = dtype + self.motion = motion def get_traces(self, start_frame, end_frame, channel_indices): if self.time_vector is not None: - raise NotImplementedError( - "time_vector for InterpolateMotionRecording do not work because temporal_bins start from 0" - ) - # times = np.asarray(self.time_vector[start_frame:end_frame]) + raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") if start_frame is None: start_frame = 0 if end_frame is None: end_frame = self.get_num_samples() - times = np.arange(end_frame - start_frame, dtype="float64") - times /= self.sampling_frequency - t0 = start_frame / self.sampling_frequency - # if self.t_start is not None: - # t0 = t0 + self.t_start - times += t0 - + times = self.parent_recording_segment.sample_index_to_time(np.arange(start_frame, end_frame)) traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices=slice(None)) - - trace2 = interpolate_motion_on_traces( + traces = traces.astype(self.dtype) + traces = interpolate_motion_on_traces( traces, times, self.channel_locations, self.motion, - self.temporal_bins, - self.spatial_bins, - direction=self.direction, + segment_index=self.segment_index, channel_inds=self.channel_inds, spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, + interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, ) if channel_indices is not None: - trace2 = trace2[:, channel_indices] + traces = traces[:, channel_indices] - return trace2 + return traces -interpolate_motion = define_function_from_class(source_class=InterpolateMotionRecording, name="correct_motion") +interpolate_motion = define_function_from_class(source_class=InterpolateMotionRecording, name="interpolate_motion") diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py new file mode 100644 index 0000000000..26d4b35b1a --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -0,0 +1,230 @@ +import json +from pathlib import Path + +import numpy as np +import spikeinterface +from spikeinterface.core.core_tools import check_json + + +class Motion: + """ + Motion of the tissue relative the probe. + + Parameters + ---------- + displacement : numpy array 2d or list of + Motion estimate in um. + List is the number of segment. + For each semgent : + * shape (temporal bins, spatial bins) + * motion.shape[0] = temporal_bins.shape[0] + * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) + temporal_bins_s : numpy.array 1d or list of + temporal bins (bin center) + spatial_bins_um : numpy.array 1d + Windows center. + spatial_bins_um.shape[0] == displacement.shape[1] + If rigid then spatial_bins_um.shape[0] == 1 + direction : str, default: 'y' + Direction of the motion. + interpolation_method : str + How to determine the displacement between bin centers? See the docs + for scipy.interpolate.RegularGridInterpolator for options. + """ + + def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y", interpolation_method="linear"): + if isinstance(displacement, np.ndarray): + self.displacement = [displacement] + assert isinstance(temporal_bins_s, np.ndarray) + self.temporal_bins_s = [temporal_bins_s] + else: + assert isinstance(displacement, (list, tuple)) + self.displacement = displacement + self.temporal_bins_s = temporal_bins_s + + assert isinstance(spatial_bins_um, np.ndarray) + self.spatial_bins_um = spatial_bins_um + + self.num_segments = len(self.displacement) + self.interpolators = None + self.interpolation_method = interpolation_method + + self.direction = direction + self.dim = ["x", "y", "z"].index(direction) + self.check_properties() + + def check_properties(self): + assert all(d.ndim == 2 for d in self.displacement) + assert all(t.ndim == 1 for t in self.temporal_bins_s) + assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) + + def __repr__(self): + nbins = self.spatial_bins_um.shape[0] + if nbins == 1: + rigid_txt = "rigid" + else: + rigid_txt = f"non-rigid - {nbins} spatial bins" + + interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] + txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments" + return txt + + def make_interpolators(self): + from scipy.interpolate import RegularGridInterpolator + + self.interpolators = [ + RegularGridInterpolator( + (self.temporal_bins_s[j], self.spatial_bins_um), self.displacement[j], method=self.interpolation_method + ) + for j in range(self.num_segments) + ] + self.temporal_bounds = [(t[0], t[-1]) for t in self.temporal_bins_s] + self.spatial_bounds = (self.spatial_bins_um.min(), self.spatial_bins_um.max()) + + def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index=None, grid=False): + """Evaluate the motion estimate at times and positions + + Evaluate the motion estimate, returning the (linearly interpolated) estimated displacement + at the given times and locations. + + Parameters + ---------- + times_s: np.array + locations_um: np.array + Either this is a one-dimensional array (a vector of positions along self.dimension), or + else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1. + segment_index: int, optional + grid : bool + If grid=False, the default, then times_s and locations_um should have the same one-dimensional + shape, and the returned displacement[i] is the displacement at time times_s[i] and location + locations_um[i]. + If grid=True, times_s and locations_um determine a grid of positions to evaluate the displacement. + Then the returned displacement[i,j] is the displacement at depth locations_um[i] and time times_s[j]. + + Returns + ------- + displacement : np.array + A displacement per input location, of shape times_s.shape if grid=False and (locations_um.size, times_s.size) + if grid=True. + """ + if self.interpolators is None: + self.make_interpolators() + + if segment_index is None: + if self.num_segments == 1: + segment_index = 0 + else: + raise ValueError("Several segment need segment_index=") + + times_s = np.asarray(times_s) + locations_um = np.asarray(locations_um) + + if locations_um.ndim == 1: + locations_um = locations_um + elif locations_um.ndim == 2: + locations_um = locations_um[:, self.dim] + else: + assert False + + times_s = times_s.clip(*self.temporal_bounds[segment_index]) + locations_um = locations_um.clip(*self.spatial_bounds) + + if grid: + # construct a grid over which to evaluate the displacement + locations_um, times_s = np.meshgrid(locations_um, times_s, indexing="ij") + out_shape = times_s.shape + locations_um = locations_um.ravel() + times_s = times_s.ravel() + else: + # usual case: input is a point cloud + assert locations_um.shape == times_s.shape + assert times_s.ndim == 1 + out_shape = times_s.shape + + points = np.column_stack((times_s, locations_um)) + displacement = self.interpolators[segment_index](points) + # reshape to grid domain shape if necessary + displacement = displacement.reshape(out_shape) + + return displacement + + def to_dict(self): + return dict( + displacement=self.displacement, + temporal_bins_s=self.temporal_bins_s, + spatial_bins_um=self.spatial_bins_um, + interpolation_method=self.interpolation_method, + ) + + def save(self, folder): + folder = Path(folder) + folder.mkdir(exist_ok=False, parents=True) + + info_file = folder / f"spikeinterface_info.json" + info = dict( + version=spikeinterface.__version__, + dev_mode=spikeinterface.DEV_MODE, + object="Motion", + num_segments=self.num_segments, + direction=self.direction, + interpolation_method=self.interpolation_method, + ) + with open(info_file, mode="w") as f: + json.dump(check_json(info), f, indent=4) + + np.save(folder / "spatial_bins_um.npy", self.spatial_bins_um) + + for segment_index in range(self.num_segments): + np.save(folder / f"displacement_seg{segment_index}.npy", self.displacement[segment_index]) + np.save(folder / f"temporal_bins_s_seg{segment_index}.npy", self.temporal_bins_s[segment_index]) + + @classmethod + def load(cls, folder): + folder = Path(folder) + + info_file = folder / f"spikeinterface_info.json" + err_msg = f"Motion.load(folder): the folder {folder} does not contain a Motion object." + if not info_file.exists(): + raise IOError(err_msg) + + with open(info_file, "r") as f: + info = json.load(f) + if "object" not in info or info["object"] != "Motion": + raise IOError(err_msg) + + direction = info["direction"] + interpolation_method = info["interpolation_method"] + spatial_bins_um = np.load(folder / "spatial_bins_um.npy") + displacement = [] + temporal_bins_s = [] + for segment_index in range(info["num_segments"]): + displacement.append(np.load(folder / f"displacement_seg{segment_index}.npy")) + temporal_bins_s.append(np.load(folder / f"temporal_bins_s_seg{segment_index}.npy")) + + return cls( + displacement, + temporal_bins_s, + spatial_bins_um, + direction=direction, + interpolation_method=interpolation_method, + ) + + def __eq__(self, other): + for segment_index in range(self.num_segments): + if not np.allclose(self.displacement[segment_index], other.displacement[segment_index]): + return False + if not np.allclose(self.temporal_bins_s[segment_index], other.temporal_bins_s[segment_index]): + return False + + if not np.allclose(self.spatial_bins_um, other.spatial_bins_um): + return False + + return True + + def copy(self): + return Motion( + self.displacement.copy(), + self.temporal_bins_s.copy(), + self.spatial_bins_um.copy(), + interpolation_method=self.interpolation_method, + ) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 11218a688f..b6f7709d27 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -4,7 +4,7 @@ import copy -from typing import Tuple, Union, List, Dict, Any, Optional, Callable +from typing import Tuple, List, Optional import numpy as np @@ -24,7 +24,6 @@ ) from spikeinterface.postprocessing.unit_locations import get_convolution_weights -from ..core import get_chunk_with_margin from .tools import make_multi_method_doc @@ -43,11 +42,10 @@ except ImportError: HAVE_TORCH = False + """ TODO: * remove the wrapper class and move all implementation to instance - * - """ diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index fcae485af9..23faea2d79 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -204,7 +204,7 @@ def compute(self, traces, peaks, waveforms): wf = waveforms[idx][:, :, chan_inds] if self.feature == "ptp": - wf_data = wf.ptp(axis=1) + wf_data = np.ptp(wf, axis=1) elif self.feature == "mean": wf_data = wf.mean(axis=1) elif self.feature == "energy": @@ -293,7 +293,7 @@ def compute(self, traces, peaks, waveforms): wf = waveforms[i, :][:, chan_inds] if self.feature == "ptp": - wf_data = wf.ptp(axis=0) + wf_data = np.ptp(wf, axis=0) elif self.feature == "energy": wf_data = np.linalg.norm(wf, axis=0) elif self.feature == "peak_voltage": diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 597eee7a99..af62ba52ec 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -1,17 +1,11 @@ -import pytest - -import shutil +from pathlib import Path import numpy as np - -from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion - -from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording +import pytest from spikeinterface.core.node_pipeline import ExtractDenseWaveforms - +from spikeinterface.sortingcomponents.motion_estimation import estimate_motion +from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass - from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -159,34 +153,27 @@ def test_estimate_motion(setup_module): ) kwargs.update(cases_kwargs) - job_kwargs = dict(progress_bar=False) - - motion, temporal_bins, spatial_bins, extra_check = estimate_motion( - recording, peaks, peak_locations, **kwargs, **job_kwargs - ) - + motion, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs) motions[name] = motion - assert temporal_bins.shape[0] == motion.shape[0] - assert spatial_bins.shape[0] == motion.shape[1] - if cases_kwargs["rigid"]: - assert motion.shape[1] == 1 + assert motion.displacement[0].shape[1] == 1 else: - assert motion.shape[1] > 1 + assert motion.displacement[0].shape[1] > 1 - # Test saving to disk - corrected_rec = InterpolateMotionRecording( - recording, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate" - ) - rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording") - if rec_folder.exists(): - shutil.rmtree(rec_folder) - corrected_rec.save(folder=rec_folder) + # # Test saving to disk + # corrected_rec = InterpolateMotionRecording( + # recording, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate" + # ) + # rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording") + # if rec_folder.exists(): + # shutil.rmtree(rec_folder) + # corrected_rec.save(folder=rec_folder) if DEBUG: fig, ax = plt.subplots() - ax.plot(temporal_bins, motion) + seg_index = 0 + ax.plot(motion.temporal_bins_s[0], motion.displacement[seg_index]) # motion_histogram = extra_check['motion_histogram'] # spatial_hist_bins = extra_check['spatial_hist_bin_edges'] @@ -206,33 +193,25 @@ def test_estimate_motion(setup_module): plt.show() # same params with differents engine should be the same - motion0, motion1 = motions["rigid / decentralized / torch"], motions["rigid / decentralized / numpy"] - assert (motion0 == motion1).all() + motion0 = motions["rigid / decentralized / torch"] + motion1 = motions["rigid / decentralized / numpy"] + assert motion0 == motion1 - motion0, motion1 = ( - motions["rigid / decentralized / torch / time_horizon_s"], - motions["rigid / decentralized / numpy / time_horizon_s"], - ) - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["rigid / decentralized / torch / time_horizon_s"] + motion1 = motions["rigid / decentralized / numpy / time_horizon_s"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - motion0, motion1 = motions["non-rigid / decentralized / torch"], motions["non-rigid / decentralized / numpy"] - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["non-rigid / decentralized / torch"] + motion1 = motions["non-rigid / decentralized / numpy"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - motion0, motion1 = ( - motions["non-rigid / decentralized / torch / time_horizon_s"], - motions["non-rigid / decentralized / numpy / time_horizon_s"], - ) - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["non-rigid / decentralized / torch / time_horizon_s"] + motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - motion0, motion1 = ( - motions["non-rigid / decentralized / torch / spatial_prior"], - motions["non-rigid / decentralized / numpy / spatial_prior"], - ) - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["non-rigid / decentralized / torch / spatial_prior"] + motion1 = motions["non-rigid / decentralized / numpy / spatial_prior"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) if __name__ == "__main__": diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index de22ee010d..cb26560272 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -1,30 +1,39 @@ -import numpy as np +from pathlib import Path +import numpy as np +import pytest +import spikeinterface.core as sc +from spikeinterface import download_dataset from spikeinterface.sortingcomponents.motion_interpolation import ( + InterpolateMotionRecording, correct_motion_on_peaks, + interpolate_motion, interpolate_motion_on_traces, - InterpolateMotionRecording, ) - +from spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset def make_fake_motion(rec): - # make a fake motion vector + # make a fake motion object duration = rec.get_total_duration() locs = rec.get_channel_locations() temporal_bins = np.arange(0.5, duration - 0.49, 0.5) spatial_bins = np.arange(locs[:, 1].min(), locs[:, 1].max(), 100) - motion = np.zeros((temporal_bins.size, spatial_bins.size)) - motion[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None] + displacement = np.zeros((temporal_bins.size, spatial_bins.size)) + displacement[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None] - return motion, temporal_bins, spatial_bins + motion = Motion([displacement], [temporal_bins], spatial_bins, direction="y") + + return motion def test_correct_motion_on_peaks(): rec, sorting = make_dataset() peaks = sorting.to_spike_vector() - motion, temporal_bins, spatial_bins = make_fake_motion(rec) + print(peaks.dtype) + motion = make_fake_motion(rec) + # print(motion) # fake locations peak_locations = np.zeros((peaks.size), dtype=[("x", "float32"), ("y", "float")]) @@ -32,26 +41,25 @@ def test_correct_motion_on_peaks(): corrected_peak_locations = correct_motion_on_peaks( peaks, peak_locations, - rec.sampling_frequency, motion, - temporal_bins, - spatial_bins, - direction="y", + rec, ) # print(corrected_peak_locations) assert np.any(corrected_peak_locations["y"] != 0) # import matplotlib.pyplot as plt # fig, ax = plt.subplots() - # ax.plot(times[peaks['sample_index']], corrected_peak_locations['y']) - # ax.plot(temporal_bins, motion[:, 1]) + # segment_index = 0 + # times = rec.get_times(segment_index=segment_index) + # ax.scatter(times[peaks['sample_index']], corrected_peak_locations['y']) + # ax.plot(motion.temporal_bins_s[segment_index], motion.displacement[segment_index][:, 1]) # plt.show() def test_interpolate_motion_on_traces(): rec, sorting = make_dataset() - motion, temporal_bins, spatial_bins = make_fake_motion(rec) + motion = make_fake_motion(rec) channel_locations = rec.get_channel_locations() @@ -64,28 +72,60 @@ def test_interpolate_motion_on_traces(): times, channel_locations, motion, - temporal_bins, - spatial_bins, - direction=1, channel_inds=None, spatial_interpolation_method=method, - spatial_interpolation_kwargs={}, + # spatial_interpolation_kwargs={}, + spatial_interpolation_kwargs={"force_extrapolate": True}, ) assert traces.shape == traces_corrected.shape assert traces.dtype == traces_corrected.dtype +def test_interpolation_simple(): + # a recording where a 1 moves at 1 chan per second. 30 chans 10 frames. + # there will be 9 chans of drift, so we add 9 chans of padding to the bottom + nt = nc0 = 10 # these need to be the same for this test + nc1 = nc0 + nc0 - 1 + traces = np.zeros((nt, nc1), dtype="float32") + traces[:, :nc0] = np.eye(nc0) + rec = sc.NumpyRecording(traces, sampling_frequency=1) + rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)]) + + true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1)) + rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest") + traces_corrected = rec_corrected.get_traces() + assert traces_corrected.shape == (nc0, nc0) + assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) + assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + + # let's try a new version where we interpolate too slowly + rec_corrected = interpolate_motion( + rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 + ) + traces_corrected = rec_corrected.get_traces() + assert traces_corrected.shape == (nc0, nc0) + # what happens with nearest here? + # well... due to rounding towards the nearest even number, the motion (which at + # these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest + # neighbor back and forth between the first and second channels + assert np.all(traces_corrected[::2, 0] == 1) + assert np.all(traces_corrected[1::2, 0] == 0) + assert np.all(traces_corrected[1::2, 1] == 1) + assert np.all(traces_corrected[::2, 1] == 0) + assert np.all(traces_corrected[:, 2:] == 0) + + def test_InterpolateMotionRecording(): rec, sorting = make_dataset() - motion, temporal_bins, spatial_bins = make_fake_motion(rec) + motion = make_fake_motion(rec) - rec2 = InterpolateMotionRecording(rec, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate") + rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_extrapolate") assert rec2.channel_ids.size == 32 - rec2 = InterpolateMotionRecording(rec, motion, temporal_bins, spatial_bins, border_mode="force_zeros") + rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_zeros") assert rec2.channel_ids.size == 32 - rec2 = InterpolateMotionRecording(rec, motion, temporal_bins, spatial_bins, border_mode="remove_channels") + rec2 = InterpolateMotionRecording(rec, motion, border_mode="remove_channels") assert rec2.channel_ids.size == 24 for ch_id in (0, 1, 14, 15, 16, 17, 30, 31): assert ch_id not in rec2.channel_ids @@ -106,6 +146,7 @@ def test_InterpolateMotionRecording(): if __name__ == "__main__": - test_correct_motion_on_peaks() - test_interpolate_motion_on_traces() + # test_correct_motion_on_peaks() + # test_interpolate_motion_on_traces() + test_interpolation_simple() test_InterpolateMotionRecording() diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py new file mode 100644 index 0000000000..0b67be39c0 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -0,0 +1,86 @@ +import pickle +import shutil +from pathlib import Path + +import numpy as np +import pytest +from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.generation import make_one_displacement_vector + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "sortingcomponents" +else: + cache_folder = Path("cache_folder") / "sortingcomponents" + + +def make_fake_motion(): + displacement_sampling_frequency = 5.0 + spatial_bins_um = np.array([100.0, 200.0, 300.0, 400.0]) + + displacement_vector = make_one_displacement_vector( + drift_mode="zigzag", + duration=50.0, + amplitude_factor=1.0, + displacement_sampling_frequency=displacement_sampling_frequency, + period_s=25.0, + ) + temporal_bins_s = np.arange(displacement_vector.size) / displacement_sampling_frequency + displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size)) + + n = spatial_bins_um.size + for i in range(n): + displacement[:, i] = displacement_vector * ((i + 1) / n) + + motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y") + + return motion + + +def test_Motion(): + + temporal_bins_s = np.arange(0.0, 10.0, 1.0) + spatial_bins_um = np.array([100.0, 200.0]) + + displacement = np.zeros((temporal_bins_s.shape[0], spatial_bins_um.shape[0])) + displacement[:, :] = np.linspace(-20, 20, temporal_bins_s.shape[0])[:, np.newaxis] + + motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y") + assert motion.interpolators is None + + # serialize with pickle before interpolation fit + motion2 = pickle.loads(pickle.dumps(motion)) + assert motion2.interpolators is None + # serialize with pickle after interpolation fit + motion2.make_interpolators() + assert motion2.interpolators is not None + motion2 = pickle.loads(pickle.dumps(motion2)) + assert motion2.interpolators is not None + + # to/from dict + motion2 = Motion(**motion.to_dict()) + assert motion == motion2 + assert motion2.interpolators is None + + # do interpolate + displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120.0, 80.0, 150.0]) + # print(displacement) + assert displacement.shape[0] == 3 + # check clip + assert displacement[2] == 20.0 + + # interpolate grid + displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, 15, 19], [150.0, 80.0], grid=True) + assert displacement.shape == (2, 5) + assert displacement[0, 2] == 20.0 + + # save/load to folder + folder = cache_folder / "motion_saved" + if folder.exists(): + shutil.rmtree(folder) + motion.save(folder) + motion2 = Motion.load(folder) + assert motion == motion2 + + +if __name__ == "__main__": + test_Motion() diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py index 4f55030283..79a9603b8d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py @@ -37,7 +37,7 @@ def test_waveform_thresholder_ptp( recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs ) - data = tresholded_waveforms.ptp(axis=1) / noise_levels + data = np.ptp(tresholded_waveforms, axis=1) / noise_levels assert np.all(data[data != 0] > 3) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index cf0d22c0c8..cc45dd3e40 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -103,7 +103,7 @@ def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache if mode == "memory": if HAVE_PSUTIL: assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1[" - memory_usage = memory_limit * psutil.virtual_memory()[4] + memory_usage = memory_limit * psutil.virtual_memory().available if recording.get_total_memory_size() < memory_usage: recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) else: diff --git a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py index 76d72f3b08..b4c54be6ad 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py +++ b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py @@ -78,7 +78,7 @@ def __init__( def compute(self, traces, peaks, waveforms): if self.feature == "ptp": - wf_data = waveforms.ptp(axis=1) / self.noise_levels + wf_data = np.ptp(waveforms, axis=1) / self.noise_levels elif self.feature == "mean": wf_data = waveforms.mean(axis=1) / self.noise_levels elif self.feature == "energy": diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index efbf6f3f32..ac73c57249 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -189,7 +189,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.next_data_plot = data_plot.copy() cm = 1 / 2.54 - we = data_plot["sorting_analyzer"] + analyzer = data_plot["sorting_analyzer"] width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -202,8 +202,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - self.unit_selector = UnitSelector(we.unit_ids) - self.unit_selector.value = list(we.unit_ids)[:1] + self.unit_selector = UnitSelector(analyzer.unit_ids) + self.unit_selector.value = list(analyzer.unit_ids)[:1] self.checkbox_histograms = W.Checkbox( value=data_plot["plot_histograms"], @@ -215,7 +215,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector, self.checkbox_histograms, ], - layout=W.Layout(align_items="center", width="4cm", height="100%"), + layout=W.Layout(align_items="center", width="100%", height="100%"), ) self.widget = W.AppLayout( diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 110555dd6a..fc0c91423d 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -7,14 +7,113 @@ class MotionWidget(BaseWidget): """ - Plot unit depths + Plot the Motion object + + Parameters + ---------- + motion : Motion + The motion object + segment_index : None | int + If Motion is multi segment, the must be not None + mode : "auto" | "line" | "map" + How to plot map or lines. + "auto" make it automatic if the number of depth is too high. + """ + + def __init__( + self, + motion, + segment_index=None, + mode="line", + motion_lim=None, + backend=None, + **backend_kwargs, + ): + if isinstance(motion, dict): + raise ValueError( + "The API has changed, plot_motion() used Motion object now, maybe you want plot_motion_info(motion_info)" + ) + + if segment_index is None: + if len(motion.displacement) == 1: + segment_index = 0 + else: + raise ValueError("plot motion : the Motion object is multi segment you must provide segment_index=XX") + + plot_data = dict( + motion=motion, + segment_index=segment_index, + mode=mode, + motion_lim=motion_lim, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from matplotlib.colors import Normalize + + dp = to_attr(data_plot) + + motion = data_plot["motion"] + segment_index = data_plot["segment_index"] + + assert backend_kwargs["axes"] is None + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + displacement = motion.displacement[dp.segment_index] + temporal_bins_s = motion.temporal_bins_s[dp.segment_index] + depth = motion.spatial_bins_um + + if dp.motion_lim is None: + motion_lim = np.max(np.abs(displacement)) * 1.05 + else: + motion_lim = dp.motion_lim + + ax = self.ax + fig = self.figure + if dp.mode == "line": + ax.plot(temporal_bins_s, displacement, alpha=0.2, color="black") + ax.plot(temporal_bins_s, np.mean(displacement, axis=1), color="C0") + ax.set_xlabel("Times [s]") + ax.set_ylabel("motion [um]") + elif dp.mode == "map": + im = ax.imshow( + displacement.T, + interpolation="nearest", + aspect="auto", + origin="lower", + extent=(temporal_bins_s[0], temporal_bins_s[-1], depth[0], depth[-1]), + cmap="PiYG", + ) + im.set_clim(-motion_lim, motion_lim) + + cbar = fig.colorbar(im) + cbar.ax.set_ylabel("motion [um]") + ax.set_xlabel("Times [s]") + ax.set_ylabel("Depth [um]") + + +class MotionInfoWidget(BaseWidget): + """ + Plot motion information from the motion_info dict returned by correct_motion(). + This plot: + * the motion iself + * the peak depth vs time before correction + * the peak depth vs time after correction Parameters ---------- motion_info : dict The motion info return by correct_motion() or load back with load_motion_info() + segment_index : int, default: None + The segment index to display. recording : RecordingExtractor, default: None The recording extractor object (only used to get "real" times) + segment_index : int, default: 0 + The segment index to display. sampling_frequency : float, default: None The sampling frequency (needed if recording is None) depth_lim : tuple or None, default: None @@ -36,6 +135,7 @@ class MotionWidget(BaseWidget): def __init__( self, motion_info, + segment_index=None, recording=None, depth_lim=None, motion_lim=None, @@ -47,11 +147,20 @@ def __init__( backend=None, **backend_kwargs, ): + + motion = motion_info["motion"] + if segment_index is None: + if len(motion.displacement) == 1: + segment_index = 0 + else: + raise ValueError("plot motion : teh Motion object is multi segment you must provide segmentindex=XX") + times = recording.get_times() if recording is not None else None plot_data = dict( sampling_frequency=motion_info["parameters"]["sampling_frequency"], times=times, + segment_index=segment_index, depth_lim=depth_lim, motion_lim=motion_lim, color_amplitude=color_amplitude, @@ -59,6 +168,7 @@ def __init__( amplitude_cmap=amplitude_cmap, amplitude_clim=amplitude_clim, amplitude_alpha=amplitude_alpha, + recording=recording, **motion_info, ) @@ -73,16 +183,29 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) - assert backend_kwargs["axes"] is None - assert backend_kwargs["ax"] is None + assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" + assert backend_kwargs["ax"] is None, "ax argument is not allowed in MotionWidget" self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) fig = self.figure fig.clear() - is_rigid = dp.motion.shape[1] == 1 + is_rigid = dp.motion.spatial_bins_um.shape[0] == 1 - gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3) + motion = dp.motion + + displacement = motion.displacement[dp.segment_index] + temporal_bins_s = motion.temporal_bins_s[dp.segment_index] + spatial_bins_um = motion.spatial_bins_um + + if dp.motion_lim is None: + motion_lim = np.max(np.abs(displacement)) * 1.05 + else: + motion_lim = dp.motion_lim + + is_rigid = displacement.shape[1] == 1 + + gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.5) ax0 = fig.add_subplot(gs[0, 0]) ax1 = fig.add_subplot(gs[0, 1]) ax2 = fig.add_subplot(gs[1, 0]) @@ -91,31 +214,23 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.sharex(ax0) ax1.sharey(ax0) - if dp.motion_lim is None: - motion_lim = np.max(np.abs(dp.motion)) * 1.05 - else: - motion_lim = dp.motion_lim - if dp.times is None: - temporal_bins_plot = dp.temporal_bins + # temporal_bins_plot = dp.temporal_bins x = dp.peaks["sample_index"] / dp.sampling_frequency else: # use real times and adjust temporal bins with t_start - temporal_bins_plot = dp.temporal_bins + dp.times[0] + # temporal_bins_plot = dp.temporal_bins + dp.times[0] x = dp.times[dp.peaks["sample_index"]] corrected_location = correct_motion_on_peaks( dp.peaks, dp.peak_locations, - dp.sampling_frequency, dp.motion, - dp.temporal_bins, - dp.spatial_bins, - direction="y", + dp.recording, ) - y = dp.peak_locations["y"] - y2 = corrected_location["y"] + y = dp.peak_locations[motion.direction] + y2 = corrected_location[motion.direction] if dp.scatter_decimate is not None: x = x[:: dp.scatter_decimate] y = y[:: dp.scatter_decimate] @@ -149,37 +264,38 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax0.set_ylim(*dp.depth_lim) ax0.set_title("Peak depth") ax0.set_xlabel("Times [s]") - ax0.set_ylabel("Depth [um]") + ax0.set_ylabel("Depth [$\\mu$m]") ax1.scatter(x, y2, s=1, **color_kwargs) ax1.set_xlabel("Times [s]") - ax1.set_ylabel("Depth [um]") + ax1.set_ylabel("Depth [$\\mu$m]") ax1.set_title("Corrected peak depth") - ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black") - ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0") + ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black") + ax2.plot(temporal_bins_s, np.mean(displacement, axis=1), color="C0") ax2.set_ylim(-motion_lim, motion_lim) - ax2.set_ylabel("Motion [um]") + ax2.set_ylabel("Motion [$\\mu$m]") + ax2.set_xlabel("Times [s]") ax2.set_title("Motion vectors") axes = [ax0, ax1, ax2] if not is_rigid: im = ax3.imshow( - dp.motion.T, + displacement.T, aspect="auto", origin="lower", extent=( - temporal_bins_plot[0], - temporal_bins_plot[-1], - dp.spatial_bins[0], - dp.spatial_bins[-1], + temporal_bins_s[0], + temporal_bins_s[-1], + spatial_bins_um[0], + spatial_bins_um[-1], ), ) im.set_clim(-motion_lim, motion_lim) cbar = fig.colorbar(im) - cbar.ax.set_xlabel("motion [um]") + cbar.ax.set_ylabel("Motion [$\\mu$m]") ax3.set_xlabel("Times [s]") - ax3.set_ylabel("Depth [um]") + ax3.set_ylabel("Depth [$\\mu$m]") ax3.set_title("Motion vectors") axes.append(ax3) self.axes = np.array(axes) diff --git a/src/spikeinterface/widgets/potential_merges.py b/src/spikeinterface/widgets/potential_merges.py new file mode 100644 index 0000000000..be882209b8 --- /dev/null +++ b/src/spikeinterface/widgets/potential_merges.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import numpy as np +from warnings import warn + +from .base import BaseWidget, default_backend_kwargs + +from .amplitudes import AmplitudesWidget +from .crosscorrelograms import CrossCorrelogramsWidget +from .unit_templates import UnitTemplatesWidget + +from .utils import get_some_colors + +from ..core.sortinganalyzer import SortingAnalyzer + + +class PotentialMergesWidget(BaseWidget): + """ + Plots potential merges + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The input sorting analyzer + potential_merges : list of lists or tuples + List of potential merges (see `spikeinterface.curation.get_potential_auto_merges`) + segment_index : int + The segment index to display + max_spike_samples : int or None, default: None + The maximum number of spikes to display per unit + """ + + def __init__( + self, + sorting_analyzer: SortingAnalyzer, + potential_merges: list, + unit_colors: list = None, + segment_index: int = 0, + max_spikes_per_unit: int = 100, + backend=None, + **backend_kwargs, + ): + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + + self.check_extensions(sorting_analyzer, ["templates", "spike_amplitudes", "correlograms"]) + + unique_merge_units = np.unique([u for merge in potential_merges for u in merge]) + if unit_colors is None: + unit_colors = get_some_colors(sorting_analyzer.unit_ids) + + plot_data = dict( + sorting_analyzer=sorting_analyzer, + potential_merges=potential_merges, + unit_colors=unit_colors, + segment_index=segment_index, + max_spikes_per_unit=max_spikes_per_unit, + unique_merge_units=unique_merge_units, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + from math import lcm + import matplotlib.pyplot as plt + + # import ipywidgets.widgets as widgets + import ipywidgets.widgets as W + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, ScaleWidget, WidenNarrowWidget + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + analyzer = data_plot["sorting_analyzer"] + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] * 3 + + ratios = [0.2, 0.8] + + with plt.ioff(): + output = W.Output() + with output: + self.figure = plt.figure( + figsize=((ratios[1] * width_cm) * cm, height_cm * cm), + constrained_layout=True, + ) + plt.show() + # find max number of merges: + self.gs = None + self.axes_amplitudes = None + self.ax_templates = None + self.ax_probe = None + self.axes_cc = None + + # Instantiate sub-widgets + self.w_amplitudes = AmplitudesWidget( + analyzer, + unit_colors=data_plot["unit_colors"], + unit_ids=data_plot["unique_merge_units"], + plot_histograms=True, + plot_legend=False, + immediate_plot=False, + ) + self.w_templates = UnitTemplatesWidget( + analyzer, + unit_ids=data_plot["unique_merge_units"], + unit_colors=data_plot["unit_colors"], + plot_legend=False, + immediate_plot=False, + ) + self.w_crosscorrelograms = CrossCorrelogramsWidget( + analyzer, + unit_ids=data_plot["unique_merge_units"], + min_similarity_for_correlograms=0, + unit_colors=data_plot["unit_colors"], + immediate_plot=False, + ) + + options = ["-".join([str(u) for u in m]) for m in data_plot["potential_merges"]] + value = options[0] + self.unit_selector_label = W.Label(value="Potential merges:") + self.unit_selector = W.Dropdown(options=options, value=value, layout=W.Layout(width="80%")) + self.previous_num_merges = len(data_plot["potential_merges"][0]) + self.scaler = ScaleWidget(value=1.0, layout=W.Layout(width="80%")) + self.widen_narrow = WidenNarrowWidget(value=1.0, layout=W.Layout(width="80%")) + + left_sidebar = W.VBox( + [self.unit_selector_label, self.unit_selector, self.scaler, self.widen_narrow], + layout=W.Layout(width="100%"), + ) + + self.widget = W.AppLayout( + center=self.figure.canvas, + left_sidebar=left_sidebar, + pane_widths=ratios + [0], + ) + + if len(np.unique([len(m) for m in self.data_plot["potential_merges"]])) == 1: + # in this case we multiply the number of columns by 3 to have 2/3 of the space for the templates + ncols = 3 * len(self.data_plot["potential_merges"][0]) + else: + ncols = lcm(*[len(m) for m in self.data_plot["potential_merges"]]) + right_axes = int(ncols * 2 / 3) + self.ncols = ncols + self.right_axes = right_axes + + # a first update + self._update_plot(None) + + self.unit_selector.observe(self._update_plot, names="value", type="change") + self.scaler.observe(self._update_plot, names="value", type="change") + self.widen_narrow.observe(self._update_plot, names="value", type="change") + + if backend_kwargs["display"]: + display(self.widget) + + def _update_gs(self, merge_units): + import matplotlib.gridspec as gridspec + + # we create a vertical grid with 1 row between the 3 first plots + n_units = len(merge_units) + ncols = self.ncols + right_axes = self.right_axes + unit_len_in_gs = self.ncols // n_units + nrows = ncols * 3 + 2 + + if self.gs is not None and self.previous_num_merges == len(merge_units): + self.ax_templates.clear() + self.ax_probe.clear() + for ax in self.axes_amplitudes: + ax.clear() + for ax in self.axes_cc.flatten(): + ax.clear() + else: + self.figure.clear() + self.gs = gridspec.GridSpec(nrows, ncols, figure=self.figure) + self.ax_templates = self.figure.add_subplot(self.gs[:ncols, :right_axes]) + self.ax_probe = self.figure.add_subplot(self.gs[:ncols, right_axes:]) + row_offset = ncols + 1 + ax_amplitudes_ts = self.figure.add_subplot(self.gs[row_offset : row_offset + ncols, :right_axes]) + ax_amplitudes_hist = self.figure.add_subplot(self.gs[row_offset : row_offset + ncols, right_axes:]) + self.axes_amplitudes = [ax_amplitudes_ts, ax_amplitudes_hist] + row_offset += ncols + 1 + self.axes_cc = [] + for i in range(0, n_units): + for j in range(0, n_units): + self.axes_cc.append( + self.figure.add_subplot( + self.gs[ + row_offset + (unit_len_in_gs) * i : row_offset + (unit_len_in_gs) * (i + 1), + j * unit_len_in_gs : (j + 1) * unit_len_in_gs, + ] + ) + ) + self.axes_cc = np.array(self.axes_cc).reshape((n_units, n_units)) + self.previous_num_merges = len(merge_units) + + def _update_plot(self, change=None): + + merge_units = self.unit_selector.value + sorting_analyzer = self.data_plot["sorting_analyzer"] + channel_locations = sorting_analyzer.get_channel_locations() + unit_ids = sorting_analyzer.unit_ids + + # unroll the merges + unit_ids_str = [str(u) for u in unit_ids] + plot_unit_ids = [] + for m in merge_units.split("-"): + plot_unit_ids.append(unit_ids[unit_ids_str.index(m)]) + self._update_gs(plot_unit_ids) + + backend_kwargs_mpl = default_backend_kwargs["matplotlib"].copy() + backend_kwargs_mpl.pop("axes") + backend_kwargs_mpl.pop("ax") + + amplitude_data_plot = self.w_amplitudes.data_plot.copy() + amplitude_data_plot["unit_ids"] = plot_unit_ids + self.w_amplitudes.plot_matplotlib(amplitude_data_plot, ax=None, axes=self.axes_amplitudes, **backend_kwargs_mpl) + + unit_template_data_plot = self.w_templates.data_plot.copy() + unit_template_data_plot["unit_ids"] = plot_unit_ids + unit_template_data_plot["same_axis"] = True + unit_template_data_plot["set_title"] = False + unit_template_data_plot["scale"] = self.scaler.value + unit_template_data_plot["widen_narrow_scale"] = self.widen_narrow.value + # update templates and shading + templates_ext = sorting_analyzer.get_extension("templates") + unit_template_data_plot["templates"] = templates_ext.get_templates(unit_ids=plot_unit_ids, operator="average") + unit_template_data_plot["templates_shading"] = self.w_templates._get_template_shadings( + plot_unit_ids, self.w_templates.data_plot["templates_percentile_shading"] + ) + self.w_templates.plot_matplotlib(unit_template_data_plot, ax=self.ax_templates, axes=None, **backend_kwargs_mpl) + self.ax_templates.axis("off") + self.w_templates._plot_probe(self.ax_probe, channel_locations, plot_unit_ids) + crosscorrelograms_data_plot = self.w_crosscorrelograms.data_plot.copy() + crosscorrelograms_data_plot["unit_ids"] = plot_unit_ids + merge_unit_indices = np.flatnonzero(np.isin(self.data_plot["unique_merge_units"], plot_unit_ids)) + updated_correlograms = crosscorrelograms_data_plot["correlograms"] + updated_correlograms = updated_correlograms[merge_unit_indices][:, merge_unit_indices] + crosscorrelograms_data_plot["correlograms"] = updated_correlograms + self.w_crosscorrelograms.plot_matplotlib( + crosscorrelograms_data_plot, axes=self.axes_cc, ax=None, **backend_kwargs_mpl + ) + self.figure.canvas.draw() + self.figure.canvas.flush_events() diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index fdc937dc25..e841a1c93b 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -2,6 +2,8 @@ import pytest import os +import numpy as np + if __name__ != "__main__": try: import matplotlib @@ -578,6 +580,38 @@ def test_plot_multicomparison(self): _, axes = plt.subplots(len(mcmp.object_list), 1) sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) + def test_plot_motion(self): + from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion + + motion = make_fake_motion() + + possible_backends = list(sw.MotionWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_motion(motion, backend=backend, mode="line") + sw.plot_motion(motion, backend=backend, mode="map") + + def test_plot_motion_info(self): + from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion + + motion = make_fake_motion() + rng = np.random.default_rng(seed=2205) + peak_locations = np.zeros(self.peaks.size, dtype=[("x", "float64"), ("y", "float64")]) + peak_locations["y"] = rng.uniform(motion.spatial_bins_um[0], motion.spatial_bins_um[-1], size=self.peaks.size) + + motion_info = dict( + motion=motion, + parameters=dict(sampling_frequency=30000.0), + run_times=dict(), + peaks=self.peaks, + peak_locations=peak_locations, + ) + + possible_backends = list(sw.MotionWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_motion_info(motion_info, recording=self.recording, backend=backend) + if __name__ == "__main__": # unittest.main() @@ -592,7 +626,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_traces() # mytest.test_plot_spikes_on_traces() # mytest.test_plot_unit_waveforms() - mytest.test_plot_spikes_on_traces() + # mytest.test_plot_spikes_on_traces() # mytest.test_plot_unit_depths() # mytest.test_plot_autocorrelograms() # mytest.test_plot_crosscorrelograms() @@ -612,6 +646,8 @@ def test_plot_multicomparison(self): # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() # mytest.test_plot_sorting_summary() + # mytest.test_plot_motion() + mytest.test_plot_motion_info() plt.show() # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index add8c820b8..b046e55fbf 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -252,7 +252,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: if dp.same_axis: backend_kwargs["num_axes"] = 1 - backend_kwargs["ncols"] = None + backend_kwargs["ncols"] = 1 else: backend_kwargs["num_axes"] = len(dp.unit_ids) backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids)) @@ -487,11 +487,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_plot(None) - - self.unit_selector.observe(self._update_plot, names="value", type="change") - self.scaler.observe(self._update_plot, names="value", type="change") - self.widen_narrow.observe(self._update_plot, names="value", type="change") for w in ( + self.unit_selector, + self.scaler, + self.widen_narrow, self.same_axis_button, self.plot_templates_button, self.template_shading_button, @@ -592,30 +591,38 @@ def _update_plot(self, change): ax.axis("off") # update probe plot - self.ax_probe.plot( + self._plot_probe( + self.ax_probe, + channel_locations, + unit_ids, + ) + fig_probe = self.ax_probe.get_figure() + + self.fig_wf.canvas.draw() + self.fig_wf.canvas.flush_events() + fig_probe.canvas.draw() + fig_probe.canvas.flush_events() + + def _plot_probe(self, ax, channel_locations, unit_ids): + # update probe plot + ax.plot( channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 ) - self.ax_probe.axis("off") - self.ax_probe.axis("equal") + ax.axis("off") + ax.axis("equal") # TODO this could be done with probeinterface plotting plotting tools!! for unit in unit_ids: - channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] - self.ax_probe.plot( + channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit] + ax.plot( channel_locations[channel_inds, 0], channel_locations[channel_inds, 1], ls="", marker="o", markersize=3, - color=self.next_data_plot["unit_colors"][unit], + color=self.data_plot["unit_colors"][unit], ) - self.ax_probe.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) - fig_probe = self.ax_probe.get_figure() - - self.fig_wf.canvas.draw() - self.fig_wf.canvas.flush_events() - fig_probe.canvas.draw() - fig_probe.canvas.flush_events() + ax.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) def get_waveforms_scales(templates, channel_locations, nbefore, x_offset_units=False, widen_narrow_scale=1.0): diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 8677c788a2..ac0676e4c7 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -3,14 +3,6 @@ import numpy as np -try: - import distinctipy - - HAVE_DISTINCTIPY = True -except ImportError: - HAVE_DISTINCTIPY = False - - def get_some_colors( keys, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None, margin=None ): @@ -48,6 +40,13 @@ def get_some_colors( except ImportError: HAVE_MPL = False + try: + import distinctipy + + HAVE_DISTINCTIPY = True + except ImportError: + HAVE_DISTINCTIPY = False + assert color_engine in ("auto", "distinctipy", "matplotlib", "colorsys") possible_formats = ("RGBA",) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 12985d366f..e31f0e0444 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -401,7 +401,7 @@ def __init__(self, unit_ids, **kwargs): options=self.unit_ids, value=self.unit_ids, disabled=False, - layout=W.Layout(height="100%", width="80%", align="center"), + layout=W.Layout(height="100%", width="3cm", align="center"), ) super(W.VBox, self).__init__(children=[label, self.selector], **kwargs) diff --git a/src/spikeinterface/widgets/utils_matplotlib.py b/src/spikeinterface/widgets/utils_matplotlib.py index 825245750f..ceb7605d25 100644 --- a/src/spikeinterface/widgets/utils_matplotlib.py +++ b/src/spikeinterface/widgets/utils_matplotlib.py @@ -1,6 +1,5 @@ from __future__ import annotations -import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -12,13 +11,10 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, if figure is not None: assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None" if num_axes is None: - if "ipympl" not in matplotlib.get_backend(): - ax = figure.add_subplot(111) - else: - ax = figure.add_subplot(111) + ax = figure.add_subplot(111) axes = np.array([[ax]]) else: - assert ncols is not None + assert ncols is not None, "ncols must be provided when num_axes is provided" axes = [] nrows = int(np.ceil(num_axes / ncols)) axes = np.full((nrows, ncols), fill_value=None, dtype=object) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index b3c1820276..d6df59b0f3 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,9 +10,10 @@ from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget -from .motion import MotionWidget +from .motion import MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget +from .potential_merges import PotentialMergesWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget from .rasters import RasterWidget @@ -44,10 +45,12 @@ CrossCorrelogramsWidget, ISIDistributionWidget, MotionWidget, + MotionInfoWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget, MultiCompGraphWidget, PeakActivityMapWidget, + PotentialMergesWidget, ProbeMapWidget, QualityMetricsWidget, RasterWidget, @@ -115,10 +118,12 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget +plot_motion_info = MotionInfoWidget plot_multicomparison_agreement = MultiCompGlobalAgreementWidget plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget plot_multicomparison_graph = MultiCompGraphWidget plot_peak_activity = PeakActivityMapWidget +plot_potential_merges = PotentialMergesWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget plot_rasters = RasterWidget