Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix open ephys probe loading and unify probeinterface import syntax #2136

Merged
merged 5 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_channelslicerecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import numpy as np

import probeinterface as pi
import probeinterface

from spikeinterface.core import ChannelSliceRecording, BinaryRecordingExtractor

Expand Down Expand Up @@ -58,7 +58,7 @@ def test_ChannelSliceRecording():
assert np.all(traces[:, 1] == 0)

# with probe and after save()
probe = pi.generate_linear_probe(num_elec=num_chan)
probe = probeinterface.generate_linear_probe(num_elec=num_chan)
probe.set_device_channel_indices(np.arange(num_chan))
rec_p = rec.set_probe(probe)
rec_sliced3 = ChannelSliceRecording(rec_p, channel_ids=[0, 2], renamed_channel_ids=[3, 4])
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/bids.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

import neo
from probeinterface import read_BIDS_probe
import probeinterface

from .nwbextractors import read_nwb
from .neoextractors import read_nix
Expand Down Expand Up @@ -60,7 +60,7 @@ def read_bids(folder_path):


def _read_probe_group(folder, bids_name, recording_channel_ids):
probegroup = read_BIDS_probe(folder)
probegroup = probeinterface.read_BIDS_probe(folder)

# make maps between : channel_id and contact_id using _channels.tsv
import pandas as pd
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/cbin_ibl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

import probeinterface as pi
import probeinterface

from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_name="ap"):
self.set_channel_offsets(offsets)

if not load_sync_channel:
probe = pi.read_spikeglx(meta_file)
probe = probeinterface.read_spikeglx(meta_file)

if probe.shank_ids is not None:
self.set_probe(probe, in_place=True, group_mode="by_shank")
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/iblstreamingrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path

import numpy as np
import probeinterface as pi
import probeinterface

from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.core.core_tools import define_function_from_class
Expand Down Expand Up @@ -165,7 +165,7 @@ def __init__(

# set probe
if not load_sync_channel:
probe = pi.read_spikeglx(meta_file)
probe = probeinterface.read_spikeglx(meta_file)

if probe.shank_ids is not None:
self.set_probe(probe, in_place=True, group_mode="by_shank")
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/neoextractors/biocam.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

import probeinterface as pi
import probeinterface

from spikeinterface.core.core_tools import define_function_from_class

Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
probe_kwargs["mea_pitch"] = mea_pitch
if electrode_width is not None:
probe_kwargs["electrode_width"] = electrode_width
probe = pi.read_3brain(file_path, **probe_kwargs)
probe = probeinterface.read_3brain(file_path, **probe_kwargs)
self.set_probe(probe, in_place=True)
self.set_property("row", self.get_property("contact_vector")["row"])
self.set_property("col", self.get_property("contact_vector")["col"])
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/neoextractors/maxwell.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from pathlib import Path

import probeinterface as pi
import probeinterface

from spikeinterface import BaseEvent, BaseEventSegment
from spikeinterface.core.core_tools import define_function_from_class
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
well_name = self.stream_id
# rec_name auto set by neo
rec_name = self.neo_reader.rec_name
probe = pi.read_maxwell(file_path, well_name=well_name, rec_name=rec_name)
probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name)
self.set_probe(probe, in_place=True)
self.set_property("electrode", self.get_property("contact_vector")["electrode"])
self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name))
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/extractors/neoextractors/mearec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

import probeinterface as pi
import probeinterface

from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor

Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self, file_path: Union[str, Path], all_annotations: bool = False):

self.extra_requirements.append("mearec")

probe = pi.read_mearec(file_path)
probe = probeinterface.read_mearec(file_path)
probe.annotations["mearec_name"] = str(probe.annotations["mearec_name"])
self.set_probe(probe, in_place=True)
self.annotate(is_filtered=True)
Expand Down
23 changes: 15 additions & 8 deletions src/spikeinterface/extractors/neoextractors/openephys.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
"""

There are two extractors for data saved by the Open Ephys GUI

* OpenEphysLegacyRecordingExtractor: reads the original "Open Ephys" data format
* OpenEphysBinaryRecordingExtractor: reads the new default "Binary" format

See https://open-ephys.github.io/gui-docs/User-Manual/Recording-data/index.html
for more info.

"""

from pathlib import Path

import numpy as np
import warnings

import probeinterface as pi
import probeinterface

from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor, NeoBaseEventExtractor

from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts


def drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs):
# Temporary function until neo version 0.13.0 is released
from packaging.version import Version
from importlib.metadata import version as lib_version

# Temporary function until neo version 0.13.0 is released
neo_version = lib_version("neo")
# The possibility of ignoring timestamps errors is not present in neo <= 0.12.0
if Version(neo_version) <= Version("0.12.0"):
Expand Down Expand Up @@ -178,7 +176,9 @@ def __init__(
settings_file = self.neo_reader.folder_structure[record_node]["experiments"][exp_id]["settings_file"]

if Path(settings_file).is_file():
probe = pi.read_openephys(settings_file=settings_file, stream_name=stream_name, raise_error=False)
probe = probeinterface.read_openephys(
settings_file=settings_file, stream_name=stream_name, raise_error=False
)
else:
probe = None

Expand All @@ -187,9 +187,16 @@ def __init__(
self.set_probe(probe, in_place=True, group_mode="by_shank")
else:
self.set_probe(probe, in_place=True)
probe_name = probe.annotations["probe_name"]

# this handles a breaking change in probeinterface after v0.2.18
# in the new version, the Neuropixels model name is stored in the "model_name" annotation,
# rather than in the "probe_name" annotation
model_name = probe.annotations.get("model_name", None)
if model_name is None:
model_name = probe.annotations["probe_name"]

# load num_channels_per_adc depending on probe type
if "2.0" in probe_name:
if "2.0" in model_name:
num_channels_per_adc = 16
num_cycles_in_adc = 16
total_channels = 384
Expand All @@ -203,7 +210,7 @@ def __init__(
sample_shifts = get_neuropixels_sample_shifts(total_channels, num_channels_per_adc, num_cycles_in_adc)
if self.get_num_channels() != total_channels:
# need slice because not all channel are saved
chans = pi.get_saved_channel_indices_from_openephys_settings(settings_file, oe_stream)
chans = probeinterface.get_saved_channel_indices_from_openephys_settings(settings_file, oe_stream)
# lets clip to 384 because this contains also the synchro channel
chans = chans[chans < total_channels]
sample_shifts = sample_shifts[chans]
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/extractors/neoextractors/spikeglx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path

import neo
import probeinterface as pi
import probeinterface

from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts

Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_
# Load probe geometry if available
if "lf" in self.stream_id:
meta_filename = meta_filename.replace(".lf", ".ap")
probe = pi.read_spikeglx(meta_filename)
probe = probeinterface.read_spikeglx(meta_filename)

if probe.shank_ids is not None:
self.set_probe(probe, in_place=True, group_mode="by_shank")
Expand All @@ -84,7 +84,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_
sample_shifts = get_neuropixels_sample_shifts(total_channels, num_channels_per_adc, num_cycles_in_adc)
if self.get_num_channels() != total_channels:
# need slice because not all channel are saved
chans = pi.get_saved_channel_indices_from_spikeglx_meta(meta_filename)
chans = probeinterface.get_saved_channel_indices_from_spikeglx_meta(meta_filename)
# lets clip to 384 because this contains also the synchro channel
chans = chans[chans < total_channels]
sample_shifts = sample_shifts[chans]
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/extractors/shybridextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from probeinterface import read_prb, write_prb
import probeinterface

from spikeinterface.core import BinaryRecordingExtractor, BaseRecordingSegment, BaseSorting, BaseSortingSegment
from spikeinterface.core.core_tools import write_binary_recording, define_function_from_class
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, file_path):
)

# load probe file
probegroup = read_prb(params["probe"])
probegroup = probeinterface.read_prb(params["probe"])
self.set_probegroup(probegroup, in_place=True)
self._kwargs = {"file_path": str(Path(file_path).absolute())}
self.extra_requirements.extend(["hybridizer", "pyyaml"])
Expand Down Expand Up @@ -119,7 +119,7 @@ def write_recording(recording, save_path, initial_sorting_fn, dtype="float32", *
# write probe file
probe_fn = (save_path / probe_name).absolute()
probegroup = recording.get_probegroup()
write_prb(probe_fn, probegroup, total_nb_channels=recording.get_num_channels())
probeinterface.write_prb(probe_fn, probegroup, total_nb_channels=recording.get_num_channels())

# create parameters file
parameters = dict(
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/extractors/tests/test_neoextractors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import unittest
import platform
import subprocess
import os
from packaging import version

import pytest
import numpy as np

from spikeinterface.core.testing import check_recordings_equal
from spikeinterface import get_global_dataset_folder
Expand All @@ -16,6 +16,7 @@
EventCommonTestSuite,
)

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
local_folder = get_global_dataset_folder() / "ephy_testing_data"


Expand Down Expand Up @@ -277,6 +278,7 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
]


@pytest.mark.skipif(ON_GITHUB, reason="Maxwell plugin not installed on GitHub")
class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
ExtractorClass = MaxwellRecordingExtractor
downloads = ["maxwell"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from pathlib import Path

import probeinterface as pi
import probeinterface
from spikeinterface import download_dataset, generate_recording, append_recordings, concatenate_recordings
from spikeinterface.extractors import read_mearec, read_spikeglx, read_openephys
from spikeinterface.preprocessing import depth_order, zscore
Expand All @@ -29,7 +29,7 @@
def recording_and_shape():
num_cols = 2
num_rows = 64
probe = pi.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows)
probe = probeinterface.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows)
probe.set_device_channel_indices(np.arange(num_cols * num_rows))
recording = generate_recording(num_channels=num_cols * num_rows, durations=[10.0], sampling_frequency=30000)
recording.set_probe(probe, in_place=True)
Expand Down