diff --git a/.github/scripts/README.MD b/.github/scripts/README.MD new file mode 100644 index 0000000000..1d3a622aae --- /dev/null +++ b/.github/scripts/README.MD @@ -0,0 +1,2 @@ +This folder contains test scripts for running in the CI, that are not run as part of the usual +CI because they are too long / heavy. These are run on cron-jobs once per week. diff --git a/.github/scripts/check_kilosort4_releases.py b/.github/scripts/check_kilosort4_releases.py new file mode 100644 index 0000000000..3d04d6948a --- /dev/null +++ b/.github/scripts/check_kilosort4_releases.py @@ -0,0 +1,20 @@ +import os +import re +from pathlib import Path +import requests +import json + + +def get_pypi_versions(package_name): + url = f"https://pypi.org/pypi/{package_name}/json" + response = requests.get(url) + response.raise_for_status() + data = response.json() + return list(sorted(data["releases"].keys())) + + +if __name__ == "__main__": + package_name = "kilosort" + versions = get_pypi_versions(package_name) + with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f: + json.dump(versions, f) diff --git a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py b/.github/scripts/test_kilosort4_ci.py similarity index 83% rename from src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py rename to .github/scripts/test_kilosort4_ci.py index e4d48a1344..4684038bd0 100644 --- a/src/spikeinterface/temp_test_file_dir/test_kilosort4_new.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -12,6 +12,14 @@ from kilosort.parameters import DEFAULT_SETTINGS from packaging.version import parse from importlib.metadata import version +from inspect import signature +from kilosort.run_kilosort import (set_files, initialize_ops, + compute_preprocessing, + compute_drift_correction, detect_spikes, + cluster_spikes, save_sorting, + get_run_parameters, ) +from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered +from kilosort.parameters import DEFAULT_SETTINGS # TODO: save_preprocesed_copy is misspelled in KS4. # TODO: duplicate_spike_bins to duplicate_spike_ms @@ -190,6 +198,102 @@ def test_default_settings_all_represented(self): if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]: assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested." + def test_set_files_arguments(self): + self._check_arguments( + set_files, + ["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"] + ) + + def test_initialize_ops_arguments(self): + + expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"] + + if parse(version("kilosort")) >= parse("4.0.12"): + expected_arguments.append("save_preprocesed_copy") + + self._check_arguments( + initialize_ops, + expected_arguments, + ) + + def test_compute_preprocessing_arguments(self): + self._check_arguments( + compute_preprocessing, + ["ops", "device", "tic0", "file_object"] + ) + + def test_compute_drift_location_arguments(self): + self._check_arguments( + compute_drift_correction, + ["ops", "device", "tic0", "progress_bar", "file_object"] + ) + + def test_detect_spikes_arguments(self): + self._check_arguments( + detect_spikes, + ["ops", "device", "bfile", "tic0", "progress_bar"] + ) + + + def test_cluster_spikes_arguments(self): + self._check_arguments( + cluster_spikes, + ["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"] + ) + + def test_save_sorting_arguments(self): + + expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] + + if parse(version("kilosort")) > parse("4.0.11"): + expected_arguments.append("save_preprocessed_copy") + + self._check_arguments( + save_sorting, + expected_arguments + ) + + def test_get_run_parameters(self): + self._check_arguments( + get_run_parameters, + ["ops"] + ) + + def test_load_probe_parameters(self): + self._check_arguments( + load_probe, + ["probe_path"] + ) + + def test_recording_extractor_as_array_arguments(self): + self._check_arguments( + RecordingExtractorAsArray, + ["recording_extractor"] + ) + + def test_binary_filtered_arguments(self): + + expected_arguments = [ + "filename", "n_chan_bin", "fs", "NT", "nt", "nt0min", + "chan_map", "hp_filter", "whiten_mat", "dshift", + "device", "do_CAR", "artifact_threshold", "invert_sign", + "dtype", "tmin", "tmax", "file_object" + ] + + if parse(version("kilosort")) >= parse("4.0.11"): + expected_arguments.pop(-1) + expected_arguments.extend(["shift", "scale", "file_object"]) + + self._check_arguments( + BinaryFiltered, + expected_arguments + ) + + def _check_arguments(self, object_, expected_arguments): + sig = signature(object_) + obj_arguments = list(sig.parameters.keys()) + assert expected_arguments == obj_arguments + @pytest.mark.parametrize("parameter", PARAMS_TO_TEST) def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter): """ """ @@ -381,7 +485,7 @@ def fake_fftshift(X, dim): # Helpers ###### def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key): """ """ - if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling"]: + if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling", "cluster_pcs"]: num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 8e57f79786..c216be20d0 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -9,38 +9,56 @@ on: branches: - main -# env: -# KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }} -# KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }} +jobs: + versions: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - name: Checkout repository + uses: actions/checkout@v2 -# concurrency: # Cancel previous workflows on the same pull request -# group: ${{ github.workflow }}-${{ github.ref }} -# cancel-in-progress: true + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.12 -jobs: - run: - name: ${{ matrix.os }} Python ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install requests + + - name: Fetch package versions from PyPI + run: | + python .github/scripts/check_kilosort4_releases.py + shell: bash + + - name: Set matrix data + id: set-matrix + run: | + echo "matrix=$(jq -c . < .github/scripts/kilosort4-latest-version.json)" >> $GITHUB_OUTPUT + + test: + needs: versions + name: ${{ matrix.ks_version }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - python-version: ["3.12"] # TODO: "3.9", # Lower and higher versions we support - os: [ubuntu-latest] # TODO: macos-13, windows-latest, - ks_version: ["4.0.12"] # TODO: add / build from pypi based on Christians PR + python-version: ["3.12"] + os: [ubuntu-latest] + ks_version: ${{ fromJson(needs.versions.outputs.matrix) }} steps: - - uses: actions/checkout@v4 + - name: Checkout repository + uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install packages - # TODO: maybe dont need full? + - name: Install SpikeInterface run: | pip install -e .[test] - # git config --global user.email "CI@example.com" - # git config --global user.name "CI Almighty" - # pip install tabulate shell: bash - name: Install Kilosort @@ -49,13 +67,6 @@ jobs: shell: bash - name: Run new kilosort4 tests - # run: chmod +x .github/test_kilosort4.sh - # TODO: figure out the paths to be able to run this by calling the file directly run: | - pytest -k test_kilosort4_new --durations=0 + pytest .github/scripts/test_kilosort4_ci.py shell: bash - -# TODO: pip install -e .[full,dev] is failing # -#The conflict is caused by: -# spikeinterface[docs] 0.101.0rc0 depends on datalad==0.16.2; extra == "docs" -# spikeinterface[test] 0.101.0rc0 depends on datalad>=1.0.2; extra == "test" diff --git a/conftest.py b/conftest.py index c4bac6628a..8c06830d25 100644 --- a/conftest.py +++ b/conftest.py @@ -19,6 +19,7 @@ def create_cache_folder(tmp_path_factory): cache_folder = tmp_path_factory.mktemp("cache_folder") return cache_folder + def pytest_collection_modifyitems(config, items): """ This function marks (in the pytest sense) the tests according to their name and file_path location @@ -28,7 +29,11 @@ def pytest_collection_modifyitems(config, items): rootdir = Path(config.rootdir) modules_location = rootdir / "src" / "spikeinterface" for item in items: - rel_path = Path(item.fspath).relative_to(modules_location) + try: # TODO: make a note on this, check with Herberto its okay. + rel_path = Path(item.fspath).relative_to(modules_location) + except: + continue + module = rel_path.parts[0] if module == "sorters": if "internal" in rel_path.parts: