Skip to content

Commit

Permalink
Merge pull request #2136 from alejoe91/probeinterface-update
Browse files Browse the repository at this point in the history
Fix open ephys probe loading and unify probeinterface import syntax
  • Loading branch information
samuelgarcia authored Oct 30, 2023
2 parents 508c4de + 1f37858 commit 4bacfe3
Show file tree
Hide file tree
Showing 12 changed files with 40 additions and 31 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_channelslicerecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import numpy as np

import probeinterface as pi
import probeinterface

from spikeinterface.core import ChannelSliceRecording, BinaryRecordingExtractor

Expand Down Expand Up @@ -58,7 +58,7 @@ def test_ChannelSliceRecording():
assert np.all(traces[:, 1] == 0)

# with probe and after save()
probe = pi.generate_linear_probe(num_elec=num_chan)
probe = probeinterface.generate_linear_probe(num_elec=num_chan)
probe.set_device_channel_indices(np.arange(num_chan))
rec_p = rec.set_probe(probe)
rec_sliced3 = ChannelSliceRecording(rec_p, channel_ids=[0, 2], renamed_channel_ids=[3, 4])
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/bids.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

import neo
from probeinterface import read_BIDS_probe
import probeinterface

from .nwbextractors import read_nwb
from .neoextractors import read_nix
Expand Down Expand Up @@ -60,7 +60,7 @@ def read_bids(folder_path):


def _read_probe_group(folder, bids_name, recording_channel_ids):
probegroup = read_BIDS_probe(folder)
probegroup = probeinterface.read_BIDS_probe(folder)

# make maps between : channel_id and contact_id using _channels.tsv
import pandas as pd
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/cbin_ibl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

import probeinterface as pi
import probeinterface

from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_name="ap"):
self.set_channel_offsets(offsets)

if not load_sync_channel:
probe = pi.read_spikeglx(meta_file)
probe = probeinterface.read_spikeglx(meta_file)

if probe.shank_ids is not None:
self.set_probe(probe, in_place=True, group_mode="by_shank")
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/iblstreamingrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path

import numpy as np
import probeinterface as pi
import probeinterface

from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.core.core_tools import define_function_from_class
Expand Down Expand Up @@ -165,7 +165,7 @@ def __init__(

# set probe
if not load_sync_channel:
probe = pi.read_spikeglx(meta_file)
probe = probeinterface.read_spikeglx(meta_file)

if probe.shank_ids is not None:
self.set_probe(probe, in_place=True, group_mode="by_shank")
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/neoextractors/biocam.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

import probeinterface as pi
import probeinterface

from spikeinterface.core.core_tools import define_function_from_class

Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
probe_kwargs["mea_pitch"] = mea_pitch
if electrode_width is not None:
probe_kwargs["electrode_width"] = electrode_width
probe = pi.read_3brain(file_path, **probe_kwargs)
probe = probeinterface.read_3brain(file_path, **probe_kwargs)
self.set_probe(probe, in_place=True)
self.set_property("row", self.get_property("contact_vector")["row"])
self.set_property("col", self.get_property("contact_vector")["col"])
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/neoextractors/maxwell.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from pathlib import Path

import probeinterface as pi
import probeinterface

from spikeinterface import BaseEvent, BaseEventSegment
from spikeinterface.core.core_tools import define_function_from_class
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
well_name = self.stream_id
# rec_name auto set by neo
rec_name = self.neo_reader.rec_name
probe = pi.read_maxwell(file_path, well_name=well_name, rec_name=rec_name)
probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name)
self.set_probe(probe, in_place=True)
self.set_property("electrode", self.get_property("contact_vector")["electrode"])
self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name))
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/neoextractors/mearec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

import probeinterface as pi
import probeinterface

from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor

Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self, file_path: Union[str, Path], all_annotations: bool = False):

self.extra_requirements.append("mearec")

probe = pi.read_mearec(file_path)
probe = probeinterface.read_mearec(file_path)
probe.annotations["mearec_name"] = str(probe.annotations["mearec_name"])
self.set_probe(probe, in_place=True)
self.annotate(is_filtered=True)
Expand Down
23 changes: 15 additions & 8 deletions src/spikeinterface/extractors/neoextractors/openephys.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
"""
There are two extractors for data saved by the Open Ephys GUI
* OpenEphysLegacyRecordingExtractor: reads the original "Open Ephys" data format
* OpenEphysBinaryRecordingExtractor: reads the new default "Binary" format
See https://open-ephys.github.io/gui-docs/User-Manual/Recording-data/index.html
for more info.
"""

from pathlib import Path

import numpy as np
import warnings

import probeinterface as pi
import probeinterface

from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor, NeoBaseEventExtractor

from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts


def drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs):
# Temporary function until neo version 0.13.0 is released
from packaging.version import Version
from importlib.metadata import version as lib_version

# Temporary function until neo version 0.13.0 is released
neo_version = lib_version("neo")
# The possibility of ignoring timestamps errors is not present in neo <= 0.12.0
if Version(neo_version) <= Version("0.12.0"):
Expand Down Expand Up @@ -178,7 +176,9 @@ def __init__(
settings_file = self.neo_reader.folder_structure[record_node]["experiments"][exp_id]["settings_file"]

if Path(settings_file).is_file():
probe = pi.read_openephys(settings_file=settings_file, stream_name=stream_name, raise_error=False)
probe = probeinterface.read_openephys(
settings_file=settings_file, stream_name=stream_name, raise_error=False
)
else:
probe = None

Expand All @@ -187,9 +187,16 @@ def __init__(
self.set_probe(probe, in_place=True, group_mode="by_shank")
else:
self.set_probe(probe, in_place=True)
probe_name = probe.annotations["probe_name"]

# this handles a breaking change in probeinterface after v0.2.18
# in the new version, the Neuropixels model name is stored in the "model_name" annotation,
# rather than in the "probe_name" annotation
model_name = probe.annotations.get("model_name", None)
if model_name is None:
model_name = probe.annotations["probe_name"]

# load num_channels_per_adc depending on probe type
if "2.0" in probe_name:
if "2.0" in model_name:
num_channels_per_adc = 16
num_cycles_in_adc = 16
total_channels = 384
Expand All @@ -203,7 +210,7 @@ def __init__(
sample_shifts = get_neuropixels_sample_shifts(total_channels, num_channels_per_adc, num_cycles_in_adc)
if self.get_num_channels() != total_channels:
# need slice because not all channel are saved
chans = pi.get_saved_channel_indices_from_openephys_settings(settings_file, oe_stream)
chans = probeinterface.get_saved_channel_indices_from_openephys_settings(settings_file, oe_stream)
# lets clip to 384 because this contains also the synchro channel
chans = chans[chans < total_channels]
sample_shifts = sample_shifts[chans]
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/extractors/neoextractors/spikeglx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path

import neo
import probeinterface as pi
import probeinterface

from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts

Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_
# Load probe geometry if available
if "lf" in self.stream_id:
meta_filename = meta_filename.replace(".lf", ".ap")
probe = pi.read_spikeglx(meta_filename)
probe = probeinterface.read_spikeglx(meta_filename)

if probe.shank_ids is not None:
self.set_probe(probe, in_place=True, group_mode="by_shank")
Expand All @@ -84,7 +84,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_
sample_shifts = get_neuropixels_sample_shifts(total_channels, num_channels_per_adc, num_cycles_in_adc)
if self.get_num_channels() != total_channels:
# need slice because not all channel are saved
chans = pi.get_saved_channel_indices_from_spikeglx_meta(meta_filename)
chans = probeinterface.get_saved_channel_indices_from_spikeglx_meta(meta_filename)
# lets clip to 384 because this contains also the synchro channel
chans = chans[chans < total_channels]
sample_shifts = sample_shifts[chans]
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/extractors/shybridextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from probeinterface import read_prb, write_prb
import probeinterface

from spikeinterface.core import BinaryRecordingExtractor, BaseRecordingSegment, BaseSorting, BaseSortingSegment
from spikeinterface.core.core_tools import write_binary_recording, define_function_from_class
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, file_path):
)

# load probe file
probegroup = read_prb(params["probe"])
probegroup = probeinterface.read_prb(params["probe"])
self.set_probegroup(probegroup, in_place=True)
self._kwargs = {"file_path": str(Path(file_path).absolute())}
self.extra_requirements.extend(["hybridizer", "pyyaml"])
Expand Down Expand Up @@ -119,7 +119,7 @@ def write_recording(recording, save_path, initial_sorting_fn, dtype="float32", *
# write probe file
probe_fn = (save_path / probe_name).absolute()
probegroup = recording.get_probegroup()
write_prb(probe_fn, probegroup, total_nb_channels=recording.get_num_channels())
probeinterface.write_prb(probe_fn, probegroup, total_nb_channels=recording.get_num_channels())

# create parameters file
parameters = dict(
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/extractors/tests/test_neoextractors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import unittest
import platform
import subprocess
import os
from packaging import version

import pytest
import numpy as np

from spikeinterface.core.testing import check_recordings_equal
from spikeinterface import get_global_dataset_folder
Expand All @@ -16,6 +16,7 @@
EventCommonTestSuite,
)

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
local_folder = get_global_dataset_folder() / "ephy_testing_data"


Expand Down Expand Up @@ -277,6 +278,7 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
]


@pytest.mark.skipif(ON_GITHUB, reason="Maxwell plugin not installed on GitHub")
class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
ExtractorClass = MaxwellRecordingExtractor
downloads = ["maxwell"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from pathlib import Path

import probeinterface as pi
import probeinterface
from spikeinterface import download_dataset, generate_recording, append_recordings, concatenate_recordings
from spikeinterface.extractors import read_mearec, read_spikeglx, read_openephys
from spikeinterface.preprocessing import depth_order, zscore
Expand All @@ -29,7 +29,7 @@
def recording_and_shape():
num_cols = 2
num_rows = 64
probe = pi.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows)
probe = probeinterface.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows)
probe.set_device_channel_indices(np.arange(num_cols * num_rows))
recording = generate_recording(num_channels=num_cols * num_rows, durations=[10.0], sampling_frequency=30000)
recording.set_probe(probe, in_place=True)
Expand Down

0 comments on commit 4bacfe3

Please sign in to comment.