
Commit

Merge branch 'main' into speed-up-amp-scalings
alejoe91 authored Sep 27, 2023
2 parents 8e4b43a + c5bafd1 commit 537ebbb
Showing 28 changed files with 371 additions and 154 deletions.
5 changes: 3 additions & 2 deletions doc/conf.py
@@ -118,14 +118,15 @@
'examples_dirs': ['../examples/modules_gallery'],
'gallery_dirs': ['modules_gallery', ], # path where to save gallery generated examples
'subsection_order': ExplicitOrder([
'../examples/modules_gallery/core/',
'../examples/modules_gallery/extractors/',
'../examples/modules_gallery/core',
'../examples/modules_gallery/extractors',
'../examples/modules_gallery/qualitymetrics',
'../examples/modules_gallery/comparison',
'../examples/modules_gallery/widgets',
]),
'within_subsection_order': FileNameSortKey,
'ignore_pattern': '/generate_',
'nested_sections': False,
}

intersphinx_mapping = {
1 change: 1 addition & 0 deletions doc/how_to/index.rst
@@ -7,3 +7,4 @@ How to guides
get_started
analyse_neuropixels
handle_drift
load_matlab_data
100 changes: 100 additions & 0 deletions doc/how_to/load_matlab_data.rst
@@ -0,0 +1,100 @@
Exporting MATLAB Data to Binary & Loading in SpikeInterface
===========================================================

In this tutorial, we will walk through the process of exporting data from MATLAB in a binary format and subsequently loading it using SpikeInterface in Python.

Exporting Data from MATLAB
--------------------------

Begin by ensuring your data structure is correct. Organize your data matrix so that the first dimension corresponds to samples/time and the second to channels.
As an illustration, here is MATLAB code that creates a random dataset and writes it to a binary file.

.. code-block:: matlab

    % Define the size of your data
    numSamples = 1000;
    numChannels = 384;

    % Generate random data as an example
    data = rand(numSamples, numChannels);

    % Write the data to a binary file
    fileID = fopen('your_data_as_a_binary.bin', 'wb');
    fwrite(fileID, data, 'double');
    fclose(fileID);

.. note::

    In your own script, replace the random data generation with your actual dataset.

Loading Data in SpikeInterface
------------------------------

After executing the above MATLAB code, a binary file named ``your_data_as_a_binary.bin`` will be created in your MATLAB working directory. To load this file in Python, you'll need its full path.

Use the following Python script to load the binary data into SpikeInterface:

.. code-block:: python

    import spikeinterface as si
    from pathlib import Path

    # Define file path
    # For Linux or macOS:
    file_path = Path("/The/Path/To/Your/Data/your_data_as_a_binary.bin")
    # For Windows:
    # file_path = Path(r"c:\path\to\your\data\your_data_as_a_binary.bin")

    # Confirm file existence
    assert file_path.is_file(), f"Error: {file_path} is not a valid file. Please check the path."

    # Define recording parameters
    sampling_frequency = 30_000.0  # Adjust according to your MATLAB dataset
    num_channels = 384  # Adjust according to your MATLAB dataset
    dtype = "float64"  # MATLAB's double corresponds to NumPy's float64

    # Load data using SpikeInterface
    recording = si.read_binary(file_path, sampling_frequency=sampling_frequency,
                               num_channels=num_channels, dtype=dtype)

    # Confirm that the data was loaded correctly by checking that the shape matches the MATLAB data
    print(recording.get_num_frames(), recording.get_num_channels())

Follow the steps above to seamlessly import your MATLAB data into SpikeInterface. Once loaded, you can harness the full power of SpikeInterface for data processing, including filtering, spike sorting, and more.
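
For instance, once the recording is loaded you can apply preprocessing with the :code:`spikeinterface.preprocessing` module. The snippet below is a minimal sketch; the band-pass cutoff frequencies are illustrative values, not a recommendation for your data.

.. code-block:: python

    import spikeinterface.preprocessing as spre

    # Band-pass filter the loaded recording (illustrative cutoffs)
    recording_filtered = spre.bandpass_filter(recording, freq_min=300.0, freq_max=6000.0)
    print(recording_filtered)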

Common Pitfalls & Tips
----------------------

1. **Data Shape**: Make sure the first dimension of your MATLAB data matrix is samples/time and the second is channels. If time is in the second dimension, pass ``time_axis=1`` to ``si.read_binary()`` (see the sketch after this list).
2. **File Path**: Always double-check that the file path in Python points to the binary file you exported from MATLAB.
3. **Data Type Consistency**: Ensure data types between MATLAB and Python are consistent. MATLAB's ``double`` is equivalent to NumPy's ``float64``.
4. **Sampling Frequency**: Set the appropriate sampling frequency in Hz for SpikeInterface.
5. **Transition to Python**: Moving from MATLAB to Python can be challenging. If you are new to Python, consider reviewing the `NumPy for MATLAB users <https://numpy.org/doc/stable/user/numpy-for-matlab-users.html>`_ guide.
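
As a sketch of the first point, suppose the matrix was written to disk channels-first (i.e. ``numChannels x numSamples`` in MATLAB); ``time_axis=1`` tells ``read_binary`` that time runs along the second axis. The values below reuse the parameters defined earlier.

.. code-block:: python

    # Hypothetical case: the binary file stores the matrix as (channels, samples)
    recording_t1 = si.read_binary(
        file_path,
        sampling_frequency=sampling_frequency,
        num_channels=num_channels,
        dtype=dtype,
        time_axis=1,  # interpret the second axis as time
    )

    # The number of frames and channels should match the MATLAB matrix again
    print(recording_t1.get_num_frames(), recording_t1.get_num_channels())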

Using gains and offsets for integer data
----------------------------------------

Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units, you can apply a gain and an offset.
In SpikeInterface, you can use the ``gain_to_uV`` and ``offset_to_uV`` parameters, since traces are handled in microvolts (uV). Both parameters can be passed to the ``read_binary`` function.
If your MATLAB data is stored as ``int16`` and you know the gain and offset, you can use the following code to load the data:

.. code-block:: python

    sampling_frequency = 30_000.0  # Adjust according to your MATLAB dataset
    num_channels = 384  # Adjust according to your MATLAB dataset
    dtype_int = 'int16'  # Adjust according to your MATLAB dataset
    gain_to_uV = 0.195  # Adjust according to your MATLAB dataset
    offset_to_uV = 0  # Adjust according to your MATLAB dataset

    recording = si.read_binary(file_path, sampling_frequency=sampling_frequency,
                               num_channels=num_channels, dtype=dtype_int,
                               gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV)

    recording.get_traces(return_scaled=True)  # Return traces in microvolts (uV)

This equips the recording object to convert the data to float values in uV using the :code:`get_traces()` method with the :code:`return_scaled` parameter set to :code:`True`.

.. note::

    The gain and offset parameters are usually format-dependent, and you will need to find the correct values for your data format. You can load your data without gain and offset, but the traces will then be returned as integer values rather than in uV.
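
As a quick sanity check (a sketch reusing the variables defined above), you can verify that the scaled traces are simply the raw integer values transformed by the gain and offset:

.. code-block:: python

    import numpy as np

    traces_raw = recording.get_traces(return_scaled=False)  # int16 values as stored on disk
    traces_uV = recording.get_traces(return_scaled=True)    # float values in microvolts
    assert np.allclose(traces_uV, traces_raw.astype("float32") * gain_to_uV + offset_to_uV)
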
3 changes: 1 addition & 2 deletions doc/modules/core.rst
@@ -547,8 +547,7 @@ workflow.
In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikeinterface.core.NumpyRecording`,
:py:class:`~spikeinterface.core.NumpySorting`, :py:class:`~spikeinterface.core.NumpyEvent`, and
:py:class:`~spikeinterface.core.NumpySnippets`. These object behave exactly like normal SpikeInterface objects,
but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported.
In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`).
but they are not bound to a file.

Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting` which is very similar to
Similar to :py:class:`~spikeinterface.core.NumpySorting` but with an unerlying SharedMemory which is usefull for
10 changes: 6 additions & 4 deletions src/spikeinterface/comparison/hybrid.py
@@ -39,7 +39,7 @@ class HybridUnitsRecording(InjectTemplatesRecording):
The refractory period of the injected spike train (in ms).
injected_sorting_folder: str | Path | None
If given, the injected sorting is saved to this folder.
It must be specified if injected_sorting is None or not dumpable.
It must be specified if injected_sorting is None or not serialisable to file.
Returns
-------
@@ -84,7 +84,8 @@ def __init__(
)
# save injected sorting if necessary
self.injected_sorting = injected_sorting
if not self.injected_sorting.check_if_json_serializable():
if not self.injected_sorting.check_serializablility("json"):
# TODO later : also use pickle
assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

@@ -137,7 +138,7 @@ class HybridSpikesRecording(InjectTemplatesRecording):
this refractory period.
injected_sorting_folder: str | Path | None
If given, the injected sorting is saved to this folder.
It must be specified if injected_sorting is None or not dumpable.
It must be specified if injected_sorting is None or not serializable to file.
Returns
-------
@@ -180,7 +181,8 @@ def __init__(
self.injected_sorting = injected_sorting

# save injected sorting if necessary
if not self.injected_sorting.check_if_json_serializable():
if not self.injected_sorting.check_serializablility("json"):
# TODO later : also use pickle
assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

22 changes: 19 additions & 3 deletions src/spikeinterface/comparison/multicomparisons.py
@@ -1,6 +1,7 @@
from pathlib import Path
import json
import pickle
import warnings

import numpy as np

@@ -180,9 +181,16 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou
return sorting

def save_to_folder(self, save_folder):
warnings.warn(
"save_to_folder() is deprecated. "
"You should save and load the multi sorting comparison object using pickle."
"\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb')))))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))",
DeprecationWarning,
stacklevel=2,
)
for sorting in self.object_list:
assert (
sorting.check_if_json_serializable()
assert sorting.check_serializablility(
"json"
), "MultiSortingComparison.save_to_folder() need json serializable sortings"

save_folder = Path(save_folder)
@@ -205,6 +213,13 @@ def save_to_folder(self, save_folder):

@staticmethod
def load_from_folder(folder_path):
warnings.warn(
"load_from_folder() is deprecated. "
"You should save and load the multi sorting comparison object using pickle."
"\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb')))))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))",
DeprecationWarning,
stacklevel=2,
)
folder_path = Path(folder_path)
with (folder_path / "kwargs.json").open() as f:
kwargs = json.load(f)
@@ -244,7 +259,8 @@ def __init__(

BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

self._is_json_serializable = False
self._serializablility["json"] = False
self._serializablility["pickle"] = True

if len(unit_ids) > 0:
for k in ("agreement_number", "avg_agreement", "unit_ids"):
73 changes: 40 additions & 33 deletions src/spikeinterface/core/base.py
@@ -57,8 +57,7 @@ def __init__(self, main_ids: Sequence) -> None:
# * number of units for sorting
self._properties = {}

self._is_dumpable = True
self._is_json_serializable = True
self._serializablility = {"memory": True, "json": True, "pickle": True}

# extractor specific list of pip extra requirements
self.extra_requirements = []
@@ -471,24 +470,33 @@ def clone(self) -> "BaseExtractor":
clone = BaseExtractor.from_dict(d)
return clone

def check_if_dumpable(self):
"""Check if the object is dumpable, including nested objects.
def check_serializablility(self, type):
kwargs = self._kwargs
for value in kwargs.values():
# here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
if isinstance(value, BaseExtractor):
if not value.check_serializablility(type=type):
return False
elif isinstance(value, list):
for v in value:
if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type):
return False
elif isinstance(value, dict):
for v in value.values():
if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type):
return False
return self._serializablility[type]

def check_if_memory_serializable(self):
"""
Check if the object is serializable to memory with pickle, including nested objects.
Returns
-------
bool
True if the object is dumpable, False otherwise.
True if the object is memory serializable, False otherwise.
"""
kwargs = self._kwargs
for value in kwargs.values():
# here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
if isinstance(value, BaseExtractor):
return value.check_if_dumpable()
elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor):
return all([v.check_if_dumpable() for v in value])
elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor):
return all([v.check_if_dumpable() for k, v in value.items()])
return self._is_dumpable
return self.check_serializablility("memory")

def check_if_json_serializable(self):
"""
@@ -499,16 +507,13 @@ def check_if_json_serializable(self):
bool
True if the object is json serializable, False otherwise.
"""
kwargs = self._kwargs
for value in kwargs.values():
# here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
if isinstance(value, BaseExtractor):
return value.check_if_json_serializable()
elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor):
return all([v.check_if_json_serializable() for v in value])
elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor):
return all([v.check_if_json_serializable() for k, v in value.items()])
return self._is_json_serializable
# we keep this for backward compatilibity or not ????
# is this needed ??? I think no.
return self.check_serializablility("json")

def check_if_pickle_serializable(self):
# is this needed ??? I think no.
return self.check_serializablility("pickle")

@staticmethod
def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path:
@@ -557,7 +562,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No
if str(file_path).endswith(".json"):
self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata)
elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata)
self.dump_to_pickle(file_path, folder_metadata=folder_metadata)
else:
raise ValueError("Dump: file must .json or .pkl")

@@ -576,7 +581,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
"""
assert self.check_if_json_serializable(), "The extractor is not json serializable"
assert self.check_serializablility("json"), "The extractor is not json serializable"

# Writing paths as relative_to requires recursively expanding the dict
if relative_to:
@@ -616,7 +621,7 @@ def dump_to_pickle(
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
"""
assert self.check_if_dumpable(), "The extractor is not dumpable"
assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle"

dump_dict = self.to_dict(
include_annotations=True,
@@ -653,8 +658,8 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo
d = pickle.load(f)
else:
raise ValueError(f"Impossible to load {file_path}")
if "warning" in d and "not dumpable" in d["warning"]:
print("The extractor was not dumpable")
if "warning" in d:
print("The extractor was not serializable to file")
return None
extractor = BaseExtractor.from_dict(d, base_folder=base_folder)
return extractor
@@ -814,10 +819,12 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs):

# dump provenance
provenance_file = folder / f"provenance.json"
if self.check_if_json_serializable():
if self.check_serializablility("json"):
self.dump(provenance_file)
else:
provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8")
provenance_file.write_text(
json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8"
)

self.save_metadata_to_folder(folder)

@@ -911,7 +918,7 @@ def save_to_zarr(

zarr_root = zarr.open(zarr_path_init, mode="w", storage_options=storage_options)

if self.check_if_dumpable():
if self.check_if_json_serializable():
zarr_root.attrs["provenance"] = check_json(self.to_dict())
else:
zarr_root.attrs["provenance"] = None
4 changes: 4 additions & 0 deletions src/spikeinterface/core/generate.py
@@ -1084,6 +1084,8 @@ def __init__(
dtype = parent_recording.dtype if parent_recording is not None else templates.dtype
BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype)

# Important : self._serializablility is not change here because it will depend on the sorting parents itself.

n_units = len(sorting.unit_ids)
assert len(templates) == n_units
self.spike_vector = sorting.to_spike_vector()
@@ -1459,5 +1461,7 @@ def generate_ground_truth_recording(
)
recording.annotate(is_filtered=True)
recording.set_probe(probe, in_place=True)
recording.set_channel_gains(1.0)
recording.set_channel_offsets(0.0)

return recording, sorting
6 changes: 3 additions & 3 deletions src/spikeinterface/core/job_tools.py
@@ -167,11 +167,11 @@ def ensure_n_jobs(recording, n_jobs=1):
print(f"Python {sys.version} does not support parallel processing")
n_jobs = 1

if not recording.check_if_dumpable():
if not recording.check_if_memory_serializable():
if n_jobs != 1:
raise RuntimeError(
"Recording is not dumpable and can't be processed in parallel. "
"You can use the `recording.save()` function to make it dumpable or set 'n_jobs' to 1."
"Recording is not serializable to memory and can't be processed in parallel. "
"You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1."
)

return n_jobs

0 comments on commit 537ebbb
