Merge branch 'main' into add_colon_spaces
chrishalcrow authored Jun 3, 2024
2 parents a2fa1fd + dd8bac5 commit 2823979
Showing 55 changed files with 414 additions and 461 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/core-test.yml
@@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
-python-version: '3.10'
+python-version: '3.11'
- name: Install dependencies
run: |
git config --global user.email "[email protected]"
@@ -31,7 +31,7 @@
pip install -e .[test_core]
- name: Test core with pytest
run: |
-pytest -vv -ra --durations=0 --durations-min=0.001 src/spikeinterface/core | tee report.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1
+pytest -m "core" -vv -ra --durations=0 --durations-min=0.001 | tee report.txt; test $? -eq 0 || exit 1
shell: bash # Necessary for pipeline to work on windows
- name: Build test summary
run: |
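The test step above now selects tests by pytest marker rather than by a hard-coded path; the markers come from the conftest.py hook changed below. A minimal local equivalent of the CI invocation, as a sketch (pytest.main mirrors the command line):

```python
# Hypothetical local reproduction of the CI step: run only the tests
# that carry the "core" marker assigned during collection.
import pytest

# pytest.main accepts the same arguments as the CLI and returns the
# exit code that the workflow step checks.
exit_code = pytest.main(["-m", "core", "-vv", "-ra"])
raise SystemExit(exit_code)
```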
20 changes: 8 additions & 12 deletions conftest.py
@@ -25,31 +25,27 @@ def pytest_sessionstart(session):
    for mark_name in mark_names:
        (pytest.global_test_folder / mark_name).mkdir()


def pytest_collection_modifyitems(config, items):
    """
    This function marks (in the pytest sense) the tests according to their name and file_path location
    Marking them in turn allows the tests to be run by using the pytest -m marker_name option.
    """

-    # python 3.4/3.5 compat: rootdir = pathlib.Path(str(config.rootdir))
    rootdir = Path(config.rootdir)

+    modules_location = rootdir / "src" / "spikeinterface"
    for item in items:
-        rel_path = Path(item.fspath).relative_to(rootdir)
-        if "sorters" in str(rel_path):
-            if "/internal/" in str(rel_path):
+        rel_path = Path(item.fspath).relative_to(modules_location)
+        module = rel_path.parts[0]
+        if module == "sorters":
+            if "internal" in rel_path.parts:
                item.add_marker("sorters_internal")
-            elif "/external/" in str(rel_path):
+            elif "external" in rel_path.parts:
                item.add_marker("sorters_external")
            else:
                item.add_marker("sorters")
        else:
-            for mark_name in mark_names:
-                if f"/{mark_name}/" in str(rel_path):
-                    mark = getattr(pytest.mark, mark_name)
-                    item.add_marker(mark)
+            item.add_marker(module)



def pytest_sessionfinish(session, exitstatus):
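The rewritten hook derives each test's module from its path relative to src/spikeinterface and applies that name directly as a marker, replacing the fragile substring matching on "/marker_name/". A small sketch of the derivation (paths illustrative):

```python
from pathlib import Path

# Illustration of the new marker logic: the first path component under
# src/spikeinterface names the module, and that name doubles as the marker.
modules_location = Path("src/spikeinterface")
test_file = Path("src/spikeinterface/curation/tests/test_sortingview_curation.py")

module = test_file.relative_to(modules_location).parts[0]
print(module)  # -> "curation"; the hook then calls item.add_marker(module)
```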
@@ -1,7 +1,6 @@
import numpy as np
from numpy.testing import assert_array_equal

-import pandas as pd

from spikeinterface.extractors import NumpySorting
from spikeinterface.comparison import compare_sorter_to_ground_truth
@@ -57,6 +56,8 @@ def test_compare_sorter_to_ground_truth():
"pooled_with_average",
]
    for method in methods:
+        import pandas as pd
+
        perf_df = sc.get_performance(method=method, output="pandas")
        assert isinstance(perf_df, (pd.Series, pd.DataFrame))
        perf_dict = sc.get_performance(method=method, output="dict")
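Moving `import pandas as pd` from module scope into the test body defers the dependency until the test actually runs, so collecting this module no longer requires pandas; the same deferred-import pattern recurs in the extractor and preprocessing test changes below. A minimal sketch of the pattern (function name hypothetical):

```python
# Deferred import: the module can be imported (and collected by pytest)
# even when the optional dependency is missing; the cost is paid only
# when the function runs. Repeated imports are cheap, since Python
# caches loaded modules in sys.modules.
def summarize(performance_dict):
    import pandas as pd

    return pd.Series(performance_dict).describe()
```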
2 changes: 1 addition & 1 deletion src/spikeinterface/curation/tests/sv-sorting-curation.json
@@ -1 +1 @@
-{"labelsByUnit":{"#2":["mua"],"#3":["mua"],"#4":["mua"],"#5":["accept"],"#6":["accept"],"#7":["accept"],"#8":["artifact"],"#9":["artifact"]},"mergeGroups":[["#8","#9"]]}
+{"labelsByUnit":{"2":["mua"],"3":["mua"],"4":["mua"],"5":["accept"],"6":["accept"],"7":["accept"],"8":["artifact"],"9":["artifact"]},"mergeGroups":[[8,9]]}
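The curation file drops the "#" prefix from the unit ids and stores the merge group as plain integers, matching the integer unit ids produced by `generate_sorting` in the updated tests below. A hedged sketch of how the file is consumed, mirroring the test calls (the path is illustrative):

```python
# Sketch: apply the curation JSON to a synthetic sorting, as the updated
# tests below do. With 10 units and one merge group ([8, 9]), 9 units
# remain, each labeled via the "labelsByUnit" entries.
from spikeinterface.core import generate_sorting
from spikeinterface.curation import apply_sortingview_curation

sorting = generate_sorting(num_units=10)
curated = apply_sortingview_curation(sorting, uri_or_json="sv-sorting-curation.json")
print(len(curated.unit_ids))  # expected: 9
```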
27 changes: 11 additions & 16 deletions src/spikeinterface/curation/tests/test_sortingview_curation.py
@@ -5,8 +5,8 @@
import numpy as np

import spikeinterface as si
+from spikeinterface.core import generate_sorting
import spikeinterface.extractors as se
-from spikeinterface.extractors import read_mearec
from spikeinterface import set_global_tmp_folder
from spikeinterface.postprocessing import (
    compute_correlograms,
@@ -34,8 +34,6 @@
# def generate_sortingview_curation_dataset():
#     import spikeinterface.widgets as sw

-#     local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
-#     recording, sorting = read_mearec(local_path)

#     sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="memory")
#     sorting_analyzer.compute("random_spikes")
@@ -50,23 +48,22 @@
#     w = sw.plot_sorting_summary(sorting_analyzer, curation=True, backend="sortingview")

#     # curation_link:
-#     # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary
+#     # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5


@pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available")
def test_gh_curation():
    """
    Test curation using GitHub URI.
    """
-    local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
-    _, sorting = read_mearec(local_path)
+    sorting = generate_sorting(num_units=10)
    # curated link:
-    # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22gh://alejoe91/spikeinterface/fix-codecov/spikeinterface/curation/tests/sv-sorting-curation.json%22}
+    # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5
    gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json"
    sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True)

    assert len(sorting_curated_gh.unit_ids) == 9
-    assert "#8-#9" in sorting_curated_gh.unit_ids
+    assert 1, 2 in sorting_curated_gh.unit_ids
    assert "accept" in sorting_curated_gh.get_property_keys()
    assert "mua" in sorting_curated_gh.get_property_keys()
    assert "artifact" in sorting_curated_gh.get_property_keys()
@@ -86,18 +83,17 @@ def test_sha1_curation():
    """
    Test curation using SHA1 URI.
    """
-    local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
-    _, sorting = read_mearec(local_path)
+    sorting = generate_sorting(num_units=10)

    # from SHA1
    # curated link:
-    # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22%22}
-    sha1_uri = "sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22"
+    # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5
+    sha1_uri = "sha1://449a428e8824eef9ad9bcc3241e45a2cee02d381"
    sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, verbose=True)
    # print(f"From SHA: {sorting_curated_sha1}")

    assert len(sorting_curated_sha1.unit_ids) == 9
-    assert "#8-#9" in sorting_curated_sha1.unit_ids
+    assert 1, 2 in sorting_curated_sha1.unit_ids
    assert "accept" in sorting_curated_sha1.get_property_keys()
    assert "mua" in sorting_curated_sha1.get_property_keys()
    assert "artifact" in sorting_curated_sha1.get_property_keys()
@@ -116,16 +112,15 @@ def test_json_curation():
    """
    Test curation using a JSON file.
    """
-    local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
-    _, sorting = read_mearec(local_path)
+    sorting = generate_sorting(num_units=10)

    # from curation.json
    json_file = parent_folder / "sv-sorting-curation.json"
    # print(f"Sorting: {sorting.get_unit_ids()}")
    sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True)

    assert len(sorting_curated_json.unit_ids) == 9
-    assert "#8-#9" in sorting_curated_json.unit_ids
+    assert 1, 2 in sorting_curated_json.unit_ids
    assert "accept" in sorting_curated_json.get_property_keys()
    assert "mua" in sorting_curated_json.get_property_keys()
    assert "artifact" in sorting_curated_json.get_property_keys()
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/tests/test_iblextractors.py
@@ -4,7 +4,6 @@
import numpy as np
from numpy.testing import assert_array_equal
import pytest
-import requests

from spikeinterface.extractors import read_ibl_recording, read_ibl_sorting, IblRecordingExtractor
@@ -16,6 +15,7 @@
class TestDefaultIblRecordingExtractorApBand(TestCase):
    @classmethod
    def setUpClass(cls):
+        import requests
        from one.api import ONE

        cls.eid = EID
37 changes: 25 additions & 12 deletions src/spikeinterface/extractors/tests/test_nwbextractors.py
@@ -1,17 +1,10 @@
import unittest
from pathlib import Path
from tempfile import mkdtemp
from datetime import datetime


import pytest
import numpy as np
-from pynwb import NWBHDF5IO
-from hdmf_zarr import NWBZarrIO
-from pynwb.ecephys import ElectricalSeries, LFP, FilteredEphys
-from pynwb.testing.mock.file import mock_NWBFile
-from pynwb.testing.mock.device import mock_Device
-from pynwb.testing.mock.ecephys import mock_ElectricalSeries, mock_ElectrodeGroup, mock_electrodes

from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor

from spikeinterface.extractors.tests.common_tests import RecordingCommonTestSuite, SortingCommonTestSuite
@@ -30,10 +23,12 @@ class NwbSortingTest(SortingCommonTestSuite, unittest.TestCase):
    entities = []


+from pynwb.testing.mock.ecephys import mock_ElectrodeGroup
+
+
def nwbfile_with_ecephys_content():
+    from pynwb.ecephys import ElectricalSeries, LFP, FilteredEphys
+    from pynwb.testing.mock.file import mock_NWBFile
+    from pynwb.testing.mock.device import mock_Device
+    from pynwb.testing.mock.ecephys import mock_ElectricalSeries, mock_ElectrodeGroup
+
    to_micro_volts = 1e6

    nwbfile = mock_NWBFile()
@@ -160,6 +155,9 @@ def nwbfile_with_ecephys_content():


def _generate_nwbfile(backend, file_path):
+    from pynwb import NWBHDF5IO
+    from hdmf_zarr import NWBZarrIO
+
    nwbfile = nwbfile_with_ecephys_content()
    if backend == "hdf5":
        io_class = NWBHDF5IO
@@ -367,6 +365,9 @@ def test_failure_with_wrong_electrical_series_path(generate_nwbfile, use_pynwb):

@pytest.mark.parametrize("use_pynwb", [True, False])
def test_sorting_extraction_of_ragged_arrays(tmp_path, use_pynwb):
+    from pynwb import NWBHDF5IO
+    from pynwb.testing.mock.file import mock_NWBFile
+
    nwbfile = mock_NWBFile()

    # Add the spikes
@@ -433,6 +434,10 @@ def test_sorting_extraction_of_ragged_arrays(tmp_path, use_pynwb):

@pytest.mark.parametrize("use_pynwb", [True, False])
def test_sorting_extraction_start_time(tmp_path, use_pynwb):
+
+    from pynwb import NWBHDF5IO
+    from pynwb.testing.mock.file import mock_NWBFile
+
    nwbfile = mock_NWBFile()

    # Add the spikes
@@ -477,6 +482,12 @@ def test_sorting_extraction_start_time(tmp_path, use_pynwb):

@pytest.mark.parametrize("use_pynwb", [True, False])
def test_sorting_extraction_start_time_from_series(tmp_path, use_pynwb):
+    from pynwb import NWBHDF5IO
+    from pynwb.testing.mock.file import mock_NWBFile
+    from pynwb.ecephys import ElectricalSeries, LFP, FilteredEphys
+
+    from pynwb.testing.mock.ecephys import mock_electrodes
+
    nwbfile = mock_NWBFile()
    electrical_series_name = "ElectricalSeries"
    t_start = 10.0
@@ -530,6 +541,8 @@ def test_sorting_extraction_start_time_from_series(tmp_path, use_pynwb):
@pytest.mark.parametrize("use_pynwb", [True, False])
def test_multiple_unit_tables(tmp_path, use_pynwb):
    from pynwb.misc import Units
+    from pynwb import NWBHDF5IO
+    from pynwb.testing.mock.file import mock_NWBFile

    nwbfile = mock_NWBFile()

@@ -1,54 +1,15 @@
from pathlib import Path
import pickle
-from tabnanny import check

import pytest
import numpy as np
import h5py

from spikeinterface import load_extractor
-from spikeinterface.core.testing import check_recordings_equal
+from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal
from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor
-
-
-@pytest.mark.streaming_extractors
-@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
-def test_recording_s3_nwb_ros3(tmp_path):
-    file_path = (
-        "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
-    )
-    rec = NwbRecordingExtractor(file_path, stream_mode="ros3")
-
-    start_frame = 0
-    end_frame = 300
-    num_frames = end_frame - start_frame
-
-    num_seg = rec.get_num_segments()
-    num_chans = rec.get_num_channels()
-    dtype = rec.get_dtype()
-
-    for segment_index in range(num_seg):
-        num_samples = rec.get_num_samples(segment_index=segment_index)
-
-        full_traces = rec.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame)
-        assert full_traces.shape == (num_frames, num_chans)
-        assert full_traces.dtype == dtype
-
-        if rec.has_scaleable_traces():
-            trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2)
-            assert trace_scaled.dtype == "float32"
-
-    tmp_file = tmp_path / "test_ros3_recording.pkl"
-    with open(tmp_file, "wb") as f:
-        pickle.dump(rec, f)
-
-    with open(tmp_file, "rb") as f:
-        reloaded_recording = pickle.load(f)
-
-    check_recordings_equal(rec, reloaded_recording)


@pytest.mark.streaming_extractors
@pytest.mark.parametrize("cache", [True, False]) # Test with and without cache
def test_recording_s3_nwb_fsspec(tmp_path, cache):
@@ -154,37 +115,6 @@ def test_recording_s3_nwb_remfile_file_like(tmp_path):
check_recordings_equal(rec, rec2)


-@pytest.mark.streaming_extractors
-@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
-def test_sorting_s3_nwb_ros3(tmp_path):
-    file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
-    # we provide the 'sampling_frequency' because the NWB file does not the electrical series
-    sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3", t_start=0)
-
-    start_frame = 0
-    end_frame = 300
-    num_frames = end_frame - start_frame
-
-    num_seg = sort.get_num_segments()
-    num_units = len(sort.unit_ids)
-
-    for segment_index in range(num_seg):
-        for unit in sort.unit_ids:
-            spike_train = sort.get_unit_spike_train(unit_id=unit, segment_index=segment_index)
-            assert len(spike_train) > 0
-            assert spike_train.dtype == "int64"
-            assert np.all(spike_train >= 0)
-
-    tmp_file = tmp_path / "test_ros3_sorting.pkl"
-    with open(tmp_file, "wb") as f:
-        pickle.dump(sort, f)
-
-    with open(tmp_file, "rb") as f:
-        reloaded_sorting = pickle.load(f)
-
-    check_sortings_equal(reloaded_sorting, sort)
-
-
@pytest.mark.streaming_extractors
@pytest.mark.parametrize("cache", [True, False]) # Test with and without cache
def test_sorting_s3_nwb_fsspec(tmp_path, cache):
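With the ros3 tests gone, S3 streaming coverage rests on the fsspec and remfile paths exercised above. A hedged sketch of the fsspec route, reusing the DANDI URL from the removed test:

```python
# Sketch (not the deleted test): stream an NWB file over HTTP via fsspec
# instead of the ros3 HDF5 driver, which required a special h5py build.
from spikeinterface.extractors import NwbRecordingExtractor

file_path = "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
rec = NwbRecordingExtractor(file_path, stream_mode="fsspec", cache=False)
traces = rec.get_traces(segment_index=0, start_frame=0, end_frame=300)
print(traces.shape)  # (300, num_channels)
```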
@@ -2,9 +2,7 @@

import pytest
import numpy as np
import pandas as pd
import shutil
import platform
from pathlib import Path

from spikeinterface.core import generate_ground_truth_recording
1 change: 0 additions & 1 deletion src/spikeinterface/preprocessing/motion.py
@@ -5,7 +5,6 @@

import numpy as np
import json
-import copy

from spikeinterface.core import get_noise_levels, fix_job_kwargs
from spikeinterface.core.job_tools import _shared_job_kwargs_doc
@@ -1,6 +1,5 @@
import pytest
import numpy as np
-import scipy.stats

from spikeinterface import NumpyRecording, get_random_data_chunks
from probeinterface import generate_linear_probe
@@ -167,6 +166,8 @@ def test_detect_bad_channels_ibl(num_channels):
        channel_flags_ibl[:, i] = channel_flags

    # Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output.
+    import scipy.stats
+
    bad_channel_labels_ibl, _ = scipy.stats.mode(channel_flags_ibl, axis=1, keepdims=False)

    # Compare
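For reference, the deferred `scipy.stats.mode` call reduces the per-chunk channel labels to one label per channel; `keepdims=False` (scipy >= 1.9) returns flat arrays. A small self-contained sketch with made-up flags:

```python
import numpy as np
import scipy.stats

# Rows are channels, columns are per-chunk estimates (0 = good, 1/2 = bad kinds).
channel_flags = np.array([[0, 0, 1],
                          [1, 1, 0],
                          [2, 2, 2]])
labels, counts = scipy.stats.mode(channel_flags, axis=1, keepdims=False)
print(labels)  # -> [0 1 2]: the most frequent flag per channel
```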