
Commit

Merge branch 'main' into expose_attempts_in_plexon2
h-mayorquin authored Sep 16, 2024
2 parents bd06b6d + 55c7de1 commit a4cc12e
Showing 29 changed files with 1,077 additions and 226 deletions.
10 changes: 9 additions & 1 deletion doc/development/development.rst
@@ -189,9 +189,17 @@ so that the user knows what the options are.
Miscelleaneous Stylistic Conventions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

#. Avoid using abreviations in variable names (e.g., use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables.
#. Avoid using abbreviations in variable names (e.g. use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables.
#. Use index as singular and indices for plural following the NumPy convention. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids)
#. We use file_path and folder_path (instead of file_name and folder_name) for clarity.
#. For creating headers to divide sections of code we use the following convention (see issue `#3019 <https://github.com/SpikeInterface/spikeinterface/issues/3019>`_):


.. code:: python

    #########################################
    # A header
    #########################################
How to build the documentation
3 changes: 0 additions & 3 deletions pyproject.toml
@@ -94,7 +94,6 @@ preprocessing = [
full = [
"h5py",
"pandas",
"xarray",
"scipy",
"scikit-learn",
"networkx",
@@ -148,7 +147,6 @@ test = [
"pytest-dependency",
"pytest-cov",

"xarray",
"huggingface_hub",

# preprocessing
@@ -193,7 +191,6 @@ docs = [
"pandas", # in the modules gallery comparison tutorial
"hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
"numba", # For many postprocessing functions
"xarray", # For use of SortingAnalyzer zarr format
"networkx",
# Download data
"pooch>=1.8.2",
19 changes: 17 additions & 2 deletions src/spikeinterface/core/baserecording.py
@@ -768,7 +768,13 @@ def get_binary_description(self):
raise NotImplementedError

def binary_compatible_with(
self, dtype=None, time_axis=None, file_paths_lenght=None, file_offset=None, file_suffix=None
self,
dtype=None,
time_axis=None,
file_paths_length=None,
file_offset=None,
file_suffix=None,
file_paths_lenght=None,
):
"""
Check is the recording is binary compatible with some constrain on
@@ -779,6 +785,15 @@ def binary_compatible_with(
* file_offset
* file_suffix
"""

# spelling typo need to fix
if file_paths_lenght is not None:
warnings.warn(
"`file_paths_lenght` is deprecated and will be removed in 0.103.0 please use `file_paths_length`"
)
if file_paths_length is None:
file_paths_length = file_paths_lenght

if not self.is_binary_compatible():
return False

@@ -790,7 +805,7 @@ def binary_compatible_with(
if time_axis is not None and time_axis != d["time_axis"]:
return False

if file_paths_lenght is not None and file_paths_lenght != len(d["file_paths"]):
if file_paths_length is not None and file_paths_length != len(d["file_paths"]):
return False

if file_offset is not None and file_offset != d["file_offset"]:
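The deprecation shim above keeps the old misspelled keyword working while steering callers to the corrected name. A self-contained sketch of the same pattern (names here are illustrative, not the SpikeInterface API itself):

    import warnings

    # Sketch of the shim: the misspelled keyword still works but emits a
    # DeprecationWarning and is forwarded to the corrected parameter.
    def binary_compatible_with_sketch(file_paths_length=None, file_paths_lenght=None):
        if file_paths_lenght is not None:
            warnings.warn(
                "`file_paths_lenght` is deprecated and will be removed in 0.103.0, "
                "please use `file_paths_length`",
                DeprecationWarning,
            )
            if file_paths_length is None:
                file_paths_length = file_paths_lenght
        return file_paths_length

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert binary_compatible_with_sketch(file_paths_lenght=2) == 2  # old spelling forwarded
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)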
27 changes: 15 additions & 12 deletions src/spikeinterface/core/baserecordingsnippets.py
@@ -145,6 +145,11 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
else:
raise ValueError("must give Probe or ProbeGroup or list of Probe")

# check that the probe do not overlap
num_probes = len(probegroup.probes)
if num_probes > 1:
check_probe_do_not_overlap(probegroup.probes)

# handle not connected channels
assert all(
probe.device_channel_indices is not None for probe in probegroup.probes
@@ -234,7 +239,7 @@ def set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False)

warning_msg = (
"`set_probes` is now a private function and the public function will be "
"removed in 0.103.0. Please use `set_probe` or `set_probegroups` instead"
"removed in 0.103.0. Please use `set_probe` or `set_probegroup` instead"
)

warn(warning_msg, category=DeprecationWarning, stacklevel=2)
@@ -348,17 +353,15 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy"):
if channel_ids is None:
channel_ids = self.get_channel_ids()
channel_indices = self.ids_to_indices(channel_ids)
if self.get_property("contact_vector") is not None:
if len(self.get_probes()) == 1:
probe = self.get_probe()
positions = probe.contact_positions[channel_indices]
else:
all_probes = self.get_probes()
# check that multiple probes are non-overlapping
check_probe_do_not_overlap(all_probes)
all_positions = np.vstack([probe.contact_positions for probe in all_probes])
positions = all_positions[channel_indices]
return select_axes(positions, axes)
contact_vector = self.get_property("contact_vector")
if contact_vector is not None:
# here we bypass the probe reconstruction so this works both for probe and probegroup
ndim = len(axes)
all_positions = np.zeros((contact_vector.size, ndim), dtype="float64")
for i, dim in enumerate(axes):
all_positions[:, i] = contact_vector[dim]
positions = all_positions[channel_indices]
return positions
else:
locations = self.get_property("location")
if locations is None:
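A minimal, self-contained sketch of the lookup introduced above (the coordinates are made up): the contact_vector structured array that probeinterface attaches as a recording property already carries per-contact "x" and "y" fields, so locations can be read directly for both a Probe and a ProbeGroup without rebuilding Probe objects:

    import numpy as np

    # Illustrative contact_vector: one row per contact, with coordinate fields.
    contact_vector = np.zeros(4, dtype=[("x", "float64"), ("y", "float64")])
    contact_vector["x"] = [0.0, 0.0, 20.0, 20.0]
    contact_vector["y"] = [0.0, 20.0, 0.0, 20.0]

    channel_indices = np.array([2, 3])  # channels we want locations for
    axes = "xy"

    positions = np.zeros((contact_vector.size, len(axes)), dtype="float64")
    for i, dim in enumerate(axes):
        positions[:, i] = contact_vector[dim]

    print(positions[channel_indices])  # [[20.  0.] [20. 20.]]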
12 changes: 7 additions & 5 deletions src/spikeinterface/core/recording_tools.py
@@ -131,9 +131,12 @@ def write_binary_recording(
data_size_bytes = dtype_size_bytes * num_frames * num_channels
file_size_bytes = data_size_bytes + byte_offset

file = open(file_path, "wb+")
file.truncate(file_size_bytes)
file.close()
# Create an empty file with file_size_bytes
with open(file_path, "wb+") as file:
# The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408)
file.seek(file_size_bytes - 1)
file.write(b"\0")

assert Path(file_path).is_file()

# use executor (loop or workers)
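A standalone sketch of the preallocation trick above (the file size is arbitrary): seeking to the last byte and writing a single null byte extends the file to its full size up front, avoiding the truncate() call that proved slow on Windows (#3408):

    import os
    import tempfile

    file_size_bytes = 4 * 1024 * 1024  # 4 MiB, arbitrary example size

    # Preallocate by seeking to the final byte and writing one null byte.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.seek(file_size_bytes - 1)
        f.write(b"\0")
        file_path = f.name

    assert os.path.getsize(file_path) == file_size_bytes
    os.remove(file_path)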
@@ -888,11 +891,10 @@ def check_probe_do_not_overlap(probes):

for j in range(i + 1, len(probes)):
probe_j = probes[j]

if np.any(
np.array(
[
x_bounds_i[0] < cp[0] < x_bounds_i[1] and y_bounds_i[0] < cp[1] < y_bounds_i[1]
x_bounds_i[0] <= cp[0] <= x_bounds_i[1] and y_bounds_i[0] <= cp[1] <= y_bounds_i[1]
for cp in probe_j.contact_positions
]
)
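The one-character change above (strict < to inclusive <=) means a contact sitting exactly on another probe's bounding-box edge is now flagged as overlapping. A small illustration with made-up coordinates:

    import numpy as np

    x_bounds_i = (0.0, 100.0)  # bounding box of probe i (illustrative)
    y_bounds_i = (0.0, 100.0)
    contact_positions_j = np.array([[100.0, 50.0]])  # contact exactly on the right edge

    overlapping = any(
        x_bounds_i[0] <= cp[0] <= x_bounds_i[1] and y_bounds_i[0] <= cp[1] <= y_bounds_i[1]
        for cp in contact_positions_j
    )
    print(overlapping)  # True with <=; the previous strict < reported False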
72 changes: 52 additions & 20 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -11,6 +11,7 @@
import shutil
import warnings
import importlib
from packaging.version import parse
from time import perf_counter

import numpy as np
@@ -579,6 +580,20 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):

zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options)

si_info = zarr_root.attrs["spikeinterface_info"]
if parse(si_info["version"]) < parse("0.101.1"):
# v0.101.0 did not have a consolidate metadata step after computing extensions.
# Here we try to consolidate the metadata and throw a warning if it fails.
try:
zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options)
zarr.consolidate_metadata(zarr_root_a.store)
except Exception as e:
warnings.warn(
"The zarr store was not properly consolidated prior to v0.101.1. "
"This may lead to unexpected behavior in loading extensions. "
"Please consider re-generating the SortingAnalyzer object."
)

# load internal sorting in memory
sorting = NumpySorting.from_sorting(
ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options),
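For context on the re-consolidation step above, a minimal sketch of what zarr.consolidate_metadata does (zarr v2-style API assumed; the local store path is illustrative): after new groups or attributes are written, consolidating rewrites the .zmetadata key so that a later zarr.open_consolidated call sees everything:

    import numpy as np
    import zarr

    # Illustrative local store; a real SortingAnalyzer folder has more structure.
    root = zarr.open("consolidation_example.zarr", mode="a")
    ext_group = root.require_group("extensions/random_data")
    ext_group.create_dataset("values", data=np.arange(4), overwrite=True)

    # Without this step, open_consolidated() would not see the new group.
    zarr.consolidate_metadata(root.store)

    reloaded = zarr.open_consolidated("consolidation_example.zarr", mode="r")
    print(list(reloaded["extensions/random_data"].keys()))  # ["values"]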
@@ -1111,9 +1126,16 @@ def get_probe(self):
def get_channel_locations(self) -> np.ndarray:
# important note : contrary to recording
# this give all channel locations, so no kwargs like channel_ids and axes
all_probes = self.get_probegroup().probes
all_positions = np.vstack([probe.contact_positions for probe in all_probes])
return all_positions
probegroup = self.get_probegroup()
probe_as_numpy_array = probegroup.to_numpy(complete=True)
# we need to sort by device_channel_indices to ensure the order of locations is correct
probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])]
ndim = probegroup.ndim
locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64")
# here we only loop through xy because only 2d locations are supported
for i, dim in enumerate(["x", "y"][:ndim]):
locations[:, i] = probe_as_numpy_array[dim]
return locations

def channel_ids_to_indices(self, channel_ids) -> np.ndarray:
all_channel_ids = list(self.rec_attributes["channel_ids"])
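A hedged sketch of the new lookup in get_channel_locations (it assumes probeinterface's generate_multi_columns_probe helper; the scrambled wiring is contrived): ProbeGroup.to_numpy(complete=True) flattens every contact into one structured array, and sorting by "device_channel_indices" restores channel order before the coordinate fields are read:

    import numpy as np
    from probeinterface import ProbeGroup, generate_multi_columns_probe

    probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=4)
    probe.set_device_channel_indices(np.arange(probe.get_contact_count())[::-1])  # scrambled wiring

    probegroup = ProbeGroup()
    probegroup.add_probe(probe)

    arr = probegroup.to_numpy(complete=True)
    arr = arr[np.argsort(arr["device_channel_indices"])]  # back into channel order
    locations = np.column_stack([arr["x"], arr["y"]])
    print(locations.shape)  # (num_channels, 2)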
@@ -1963,12 +1985,14 @@ def load_data(self):
if "dict" in ext_data_.attrs:
ext_data = ext_data_[0]
elif "dataframe" in ext_data_.attrs:
import xarray
import pandas as pd

ext_data = xarray.open_zarr(
ext_data_.store, group=f"{extension_group.name}/{ext_data_name}"
).to_pandas()
ext_data.index.rename("", inplace=True)
index = ext_data_["index"]
ext_data = pd.DataFrame(index=index)
for col in ext_data_.keys():
if col != "index":
ext_data.loc[:, col] = ext_data_[col][:]
ext_data = ext_data.convert_dtypes()
elif "object" in ext_data_.attrs:
ext_data = ext_data_[0]
else:
@@ -2024,12 +2048,21 @@ def run(self, save=True, **kwargs):
if save and not self.sorting_analyzer.is_read_only():
self._save_run_info()
self._save_data(**kwargs)
if self.format == "zarr":
import zarr

zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

def save(self, **kwargs):
self._save_params()
self._save_importing_provenance()
self._save_data(**kwargs)
self._save_run_info()
self._save_data(**kwargs)

if self.format == "zarr":
import zarr

zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

def _save_data(self, **kwargs):
if self.format == "memory":
@@ -2089,12 +2122,12 @@ def _save_data(self, **kwargs):
elif isinstance(ext_data, np.ndarray):
extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame):
ext_data.to_xarray().to_zarr(
store=extension_group.store,
group=f"{extension_group.name}/{ext_data_name}",
mode="a",
)
extension_group[ext_data_name].attrs["dataframe"] = True
df_group = extension_group.create_group(ext_data_name)
# first we save the index
df_group.create_dataset(name="index", data=ext_data.index.to_numpy())
for col in ext_data.columns:
df_group.create_dataset(name=col, data=ext_data[col].to_numpy())
df_group.attrs["dataframe"] = True
else:
# any object
try:
@@ -2104,8 +2137,6 @@
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
extension_group[ext_data_name].attrs["object"] = True
# we need to re-consolidate
zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)
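A self-contained round-trip sketch of the new DataFrame storage shown in load_data and _save_data above (store path and group names are illustrative; zarr v2-style API assumed): each column and the index become plain zarr datasets inside a group flagged with a "dataframe" attribute, which removes the xarray dependency:

    import pandas as pd
    import zarr

    df = pd.DataFrame({"amplitude": [1.0, 2.0], "snr": [5.0, 7.5]}, index=[101, 102])

    # Write: one dataset per column plus the index, flagged as a dataframe.
    root = zarr.open("metrics_example.zarr", mode="w")
    df_group = root.create_group("template_metrics")
    df_group.create_dataset(name="index", data=df.index.to_numpy())
    for col in df.columns:
        df_group.create_dataset(name=col, data=df[col].to_numpy())
    df_group.attrs["dataframe"] = True

    # Read: rebuild the frame column by column, mirroring the writing code.
    loaded_group = zarr.open("metrics_example.zarr", mode="r")["template_metrics"]
    loaded = pd.DataFrame(index=loaded_group["index"][:])
    for col in loaded_group.keys():
        if col != "index":
            loaded.loc[:, col] = loaded_group[col][:]
    print(loaded.convert_dtypes())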

def _reset_extension_folder(self):
"""
@@ -2233,9 +2264,10 @@ def get_pipeline_nodes(self):
return self._get_pipeline_nodes()

def get_data(self, *args, **kwargs):
assert self.run_info[
"run_completed"
], f"You must run the extension {self.extension_name} before retrieving data"
if self.run_info is not None:
assert self.run_info[
"run_completed"
], f"You must run the extension {self.extension_name} before retrieving data"
assert len(self.data) > 0, "Extension has been run but no data found."
return self._get_data(*args, **kwargs)

(Diff truncated; the remaining changed files are not shown in this view.)
