Skip to content

Commit

Permalink
Add CI to test all kilosort4 versions.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jun 26, 2024
1 parent 9570c92 commit a8489a5
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .github/scripts/README.MD
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
This folder contains test scripts for running in the CI, that are not run as part of the usual
CI because they are too long / heavy. These are run on cron-jobs once per week.
20 changes: 20 additions & 0 deletions .github/scripts/check_kilosort4_releases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os
import re
from pathlib import Path
import requests
import json


def get_pypi_versions(package_name):
url = f"https://pypi.org/pypi/{package_name}/json"
response = requests.get(url)
response.raise_for_status()
data = response.json()
return list(sorted(data["releases"].keys()))


if __name__ == "__main__":
package_name = "kilosort"
versions = get_pypi_versions(package_name)
with open(Path(os.path.realpath(__file__)).parent / "kilosort4-latest-version.json", "w") as f:
json.dump(versions, f)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
from kilosort.parameters import DEFAULT_SETTINGS
from packaging.version import parse
from importlib.metadata import version
from inspect import signature
from kilosort.run_kilosort import (set_files, initialize_ops,
compute_preprocessing,
compute_drift_correction, detect_spikes,
cluster_spikes, save_sorting,
get_run_parameters, )
from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered
from kilosort.parameters import DEFAULT_SETTINGS

# TODO: save_preprocesed_copy is misspelled in KS4.
# TODO: duplicate_spike_bins to duplicate_spike_ms
Expand Down Expand Up @@ -190,6 +198,102 @@ def test_default_settings_all_represented(self):
if param_key not in ["n_chan_bin", "fs", "tmin", "tmax"]:
assert param_key in tested_keys, f"param: {param_key} in DEFAULT SETTINGS but not tested."

def test_set_files_arguments(self):
self._check_arguments(
set_files,
["settings", "filename", "probe", "probe_name", "data_dir", "results_dir"]
)

def test_initialize_ops_arguments(self):

expected_arguments = ["settings", "probe", "data_dtype", "do_CAR", "invert_sign", "device"]

if parse(version("kilosort")) >= parse("4.0.12"):
expected_arguments.append("save_preprocesed_copy")

self._check_arguments(
initialize_ops,
expected_arguments,
)

def test_compute_preprocessing_arguments(self):
self._check_arguments(
compute_preprocessing,
["ops", "device", "tic0", "file_object"]
)

def test_compute_drift_location_arguments(self):
self._check_arguments(
compute_drift_correction,
["ops", "device", "tic0", "progress_bar", "file_object"]
)

def test_detect_spikes_arguments(self):
self._check_arguments(
detect_spikes,
["ops", "device", "bfile", "tic0", "progress_bar"]
)


def test_cluster_spikes_arguments(self):
self._check_arguments(
cluster_spikes,
["st", "tF", "ops", "device", "bfile", "tic0", "progress_bar"]
)

def test_save_sorting_arguments(self):

expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"]

if parse(version("kilosort")) > parse("4.0.11"):
expected_arguments.append("save_preprocessed_copy")

self._check_arguments(
save_sorting,
expected_arguments
)

def test_get_run_parameters(self):
self._check_arguments(
get_run_parameters,
["ops"]
)

def test_load_probe_parameters(self):
self._check_arguments(
load_probe,
["probe_path"]
)

def test_recording_extractor_as_array_arguments(self):
self._check_arguments(
RecordingExtractorAsArray,
["recording_extractor"]
)

def test_binary_filtered_arguments(self):

expected_arguments = [
"filename", "n_chan_bin", "fs", "NT", "nt", "nt0min",
"chan_map", "hp_filter", "whiten_mat", "dshift",
"device", "do_CAR", "artifact_threshold", "invert_sign",
"dtype", "tmin", "tmax", "file_object"
]

if parse(version("kilosort")) >= parse("4.0.11"):
expected_arguments.pop(-1)
expected_arguments.extend(["shift", "scale", "file_object"])

self._check_arguments(
BinaryFiltered,
expected_arguments
)

def _check_arguments(self, object_, expected_arguments):
sig = signature(object_)
obj_arguments = list(sig.parameters.keys())
assert expected_arguments == obj_arguments

@pytest.mark.parametrize("parameter", PARAMS_TO_TEST)
def test_kilosort4(self, recording_and_paths, default_results, tmp_path, parameter):
""" """
Expand Down Expand Up @@ -381,7 +485,7 @@ def fake_fftshift(X, dim):
# Helpers ######
def _check_test_parameters_are_actually_changing_the_output(self, results, default_results, param_key):
""" """
if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling"]:
if param_key not in ["artifact_threshold", "ccg_threshold", "cluster_downsampling", "cluster_pcs"]:
num_clus = np.unique(results["si"]["clus"].iloc[:, 0]).size
num_clus_default = np.unique(default_results["ks"]["clus"].iloc[:, 0]).size

Expand Down
63 changes: 37 additions & 26 deletions .github/workflows/test_kilosort4.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,56 @@ on:
branches:
- main

# env:
# KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }}
# KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }}
jobs:
versions:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: Checkout repository
uses: actions/checkout@v2

# concurrency: # Cancel previous workflows on the same pull request
# group: ${{ github.workflow }}-${{ github.ref }}
# cancel-in-progress: true
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.12

jobs:
run:
name: ${{ matrix.os }} Python ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install requests
- name: Fetch package versions from PyPI
run: |
python .github/scripts/check_kilosort4_releases.py
shell: bash

- name: Set matrix data
id: set-matrix
run: |
echo "matrix=$(jq -c . < .github/scripts/kilosort4-latest-version.json)" >> $GITHUB_OUTPUT
test:
needs: versions
name: ${{ matrix.ks_version }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
python-version: ["3.12"] # TODO: "3.9", # Lower and higher versions we support
os: [ubuntu-latest] # TODO: macos-13, windows-latest,
ks_version: ["4.0.12"] # TODO: add / build from pypi based on Christians PR
python-version: ["3.12"]
os: [ubuntu-latest]
ks_version: ${{ fromJson(needs.versions.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
- name: Checkout repository
uses: actions/checkout@v2

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install packages
# TODO: maybe dont need full?
- name: Install SpikeInterface
run: |
pip install -e .[test]
# git config --global user.email "[email protected]"
# git config --global user.name "CI Almighty"
# pip install tabulate
shell: bash

- name: Install Kilosort
Expand All @@ -49,13 +67,6 @@ jobs:
shell: bash

- name: Run new kilosort4 tests
# run: chmod +x .github/test_kilosort4.sh
# TODO: figure out the paths to be able to run this by calling the file directly
run: |
pytest -k test_kilosort4_new --durations=0
pytest .github/scripts/test_kilosort4_ci.py
shell: bash

# TODO: pip install -e .[full,dev] is failing #
#The conflict is caused by:
# spikeinterface[docs] 0.101.0rc0 depends on datalad==0.16.2; extra == "docs"
# spikeinterface[test] 0.101.0rc0 depends on datalad>=1.0.2; extra == "test"
7 changes: 6 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def create_cache_folder(tmp_path_factory):
cache_folder = tmp_path_factory.mktemp("cache_folder")
return cache_folder


def pytest_collection_modifyitems(config, items):
"""
This function marks (in the pytest sense) the tests according to their name and file_path location
Expand All @@ -28,7 +29,11 @@ def pytest_collection_modifyitems(config, items):
rootdir = Path(config.rootdir)
modules_location = rootdir / "src" / "spikeinterface"
for item in items:
rel_path = Path(item.fspath).relative_to(modules_location)
try: # TODO: make a note on this, check with Herberto its okay.
rel_path = Path(item.fspath).relative_to(modules_location)
except:
continue

module = rel_path.parts[0]
if module == "sorters":
if "internal" in rel_path.parts:
Expand Down

0 comments on commit a8489a5

Please sign in to comment.