Merge pull request #14 from SpikeInterface/np-ultra
Add script to upload NP-Ultra dataset
alejoe91 authored Sep 28, 2024
2 parents 1b4080a + 0b2a456 commit 40f0d7b
Showing 6 changed files with 363 additions and 22 deletions.
51 changes: 45 additions & 6 deletions README.md
@@ -1,10 +1,49 @@
# hybrid_template_library
Library of templates to create hybrid data for spike sorting benchmarks
# SpikeInterface Hybrid Template Library

[click here for access to the library](https://spikeinterface.github.io/hybrid_template_library/)

This repo contains a set of tools to construct and interact with a library of hybrid templates for spike sorting benchmarks.

## Testing locally
The library is made of several datasets stored as zarr files and can be accessed through the spikeinterface library.
The library is also accessible through a web-app that allows users to browse the templates and download them for
use in their spike sorting benchmarks.


## Template sources

The following datasets are available in the library:

- [IBL](https://dandiarchive.org/dandiset/000409?search=IBL&pos=3): Neuropixels 1.0 templates from the IBL Brain Wide Map dataset
- [Steinmetz and Ye, 2022](https://doi.org/10.6084/m9.figshare.19493588.v2): Neuropixels Ultra templates from Steinmetz and Ye, 2022

The templates have been processed with the `python` scripts in the `python/scripts` folder and are stored in `zarr`
format in the `s3://spikeinterface-template-database` bucket hosted on `AWS S3` by [CatalystNeuro](https://www.catalystneuro.com/).
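
Each dataset can also be opened directly as a zarr group, for instance to inspect the stored arrays (`templates_array`, `unit_ids`, `brain_area`, `peak_to_peak`, ...). Below is a minimal sketch, assuming the bucket allows anonymous reads and using a placeholder dataset name:

```python
import s3fs
import zarr

# Anonymous read-only access to the public bucket (assumption: anonymous reads are enabled).
# Replace <dataset_name> with one of the zarr datasets listed by the library.
s3 = s3fs.S3FileSystem(anon=True)
store = s3fs.S3Map(root="spikeinterface-template-database/<dataset_name>.zarr", s3=s3)
zarr_root = zarr.open(store, mode="r")

print(zarr_root["unit_ids"][:])            # unit ids in this dataset
print(zarr_root["templates_array"].shape)  # (num_units, num_samples, num_channels)
print(zarr_root["brain_area"][:10])        # brain area of the first 10 units
```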


## Accessing the data through `SpikeInterface`

The library can be accessed through the `spikeinterface` library using the `generation` module.
The following code shows how to access the library to fetch a dataframe with the available templates
and download the templates corresponding to a specific user query:

```python
import spikeinterface.generation as sgen

templates_info = sgen.fetch_templates_database_info()

# select templates with amplitude between 200 and 250uV
templates_info_selected = templates_info.query('amplitude_uv > 200 and amplitude_uv < 250')
templates_selected = sgen.query_templates_from_database(templates_info_selected)
```

For a more comprehensive example on how to construct hybrid recordings from the template library and run spike sorting
benchmarks, please refer to the SpikeInterface tutorial on [Hybrid recordings](https://spikeinterface.readthedocs.io/en/latest/how_to/benchmark_with_hybrid_recordings.html).
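
As a quick orientation, the hybrid workflow boils down to injecting the downloaded templates into an existing (preprocessed) recording. The following is a minimal sketch rather than the full tutorial: the recording path is a placeholder and the exact signature of `sgen.generate_hybrid_recording` may differ between `spikeinterface` versions.

```python
import spikeinterface.full as si
import spikeinterface.generation as sgen

# Load and preprocess your own Neuropixels recording (placeholder path/stream)
recording = si.read_spikeglx("/path/to/recording", stream_id="imec0.ap")
recording = si.highpass_filter(recording)
recording = si.common_reference(recording, operator="median")

# Select and download templates from the database (as in the snippet above)
templates_info = sgen.fetch_templates_database_info()
templates = sgen.query_templates_from_database(
    templates_info.query("amplitude_uv > 200 and amplitude_uv < 250")
)

# Inject the templates to obtain a hybrid recording plus its ground-truth sorting
hybrid_recording, hybrid_sorting = sgen.generate_hybrid_recording(
    recording=recording,
    templates=templates,
    seed=0,
)
```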

## Live Web-App

The template library can be browsed through a web-app (source code included in this repo). The web-app is hosted on GitHub Pages and can be accessed through the following link: [https://spikeinterface.github.io/hybrid_template_library/](https://spikeinterface.github.io/hybrid_template_library/)


### Testing locally

To test zarr access locally, start a Python HTTP server:

@@ -38,10 +77,10 @@ python -c "from http.server import HTTPServer, SimpleHTTPRequestHandler; import
```


Then you run the npm script to start the server and open the browser
Then you run the `npm` script to start the server and open the browser

```bash
export TEST_URL="http://localhost:8000/zarr_store.zarr"
export TEST_URL="http://localhost:8000/test_zarr.zarr"
npm run start
```

1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ classifiers = [

dependencies = [
"spikeinterface >= 0.101",
"MEArec",
"tqdm",
"pynwb>=2.8",
"remfile==0.1",
8 changes: 6 additions & 2 deletions python/consolidate_datasets.py
@@ -101,8 +101,12 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False):

depth_best_channel = channel_depths[best_channel_indices]
peak_to_peak_best_channel = zarr_group["peak_to_peak"][template_indices, best_channel_indices]
noise_best_channel = zarr_group["channel_noise_levels"][best_channel_indices]
signal_to_noise_ratio_best_channel = peak_to_peak_best_channel / noise_best_channel
if "channel_noise_levels" not in zarr_group:
noise_best_channel = np.nan * np.zeros(num_units)
signal_to_noise_ratio_best_channel = np.nan * np.zeros(num_units)
else:
noise_best_channel = zarr_group["channel_noise_levels"][best_channel_indices]
signal_to_noise_ratio_best_channel = peak_to_peak_best_channel / noise_best_channel

new_entry = pd.DataFrame(
{
130 changes: 123 additions & 7 deletions python/delete_templates.py
@@ -1,4 +1,7 @@
import boto3
import numpy as np
import zarr

from consolidate_datasets import list_zarr_directories


@@ -31,11 +34,118 @@ def delete_templates_from_s3(
delete_template_from_s3(bucket_name, key, boto_client=boto_client)


if __name__ == "__main__":
def delete_templates_too_few_spikes(min_spikes=50, dry_run=False, verbose=True):
"""
This function will delete templates associated with spike trains that have too few spikes.
The initial database was in fact created without a minimum number of spikes per unit,
so some units have very few spikes and possibly a noisy template.
"""
import spikeinterface.generation as sgen

templates_info = sgen.fetch_templates_database_info()
templates_to_remove = templates_info.query(f"spikes_per_unit < {min_spikes}")

if len(templates_to_remove) > 0:
if verbose:
print(
f"Removing {len(templates_to_remove)}/{len(templates_info)} templates with less than {min_spikes} spikes"
)
datasets = np.unique(templates_to_remove["dataset"])

for d_i, dataset in enumerate(datasets):
if verbose:
print(f"\tCleaning dataset {d_i + 1}/{len(datasets)}")
templates_in_dataset = templates_to_remove.query(f"dataset == '{dataset}'")
template_indices_to_remove = templates_in_dataset.template_index.values
s3_path = templates_in_dataset.dataset_path.values[0]

# to filter from the zarr dataset:
datasets_to_filter = [
"templates_array",
"best_channel_index",
"spikes_per_unit",
"brain_area",
"peak_to_peak",
"unit_ids",
]

# open zarr read-only for dry runs, read/write otherwise
if dry_run:
mode = "r"
else:
mode = "r+"
zarr_root = zarr.open(s3_path, mode=mode)
all_unit_indices = np.arange(len(zarr_root["unit_ids"]))
n_original_units = len(all_unit_indices)
unit_indices_to_keep = np.delete(all_unit_indices, template_indices_to_remove)
n_units_to_keep = len(unit_indices_to_keep)
if verbose:
print(f"\tRemoving {n_original_units - n_units_to_keep} templates from {n_original_units}")
for dset in datasets_to_filter:
dataset_original = zarr_root[dset]
if len(dataset_original) == n_units_to_keep:
print(f"\t\tDataset: {dset} - shape: {dataset_original.shape} - already updated")
continue
dataset_filtered = dataset_original[unit_indices_to_keep]
if not dry_run:
if verbose:
print(f"\t\tUpdating: {dset} - shape: {dataset_filtered.shape}")
if dataset_filtered.dtype.kind == "O":
dataset_filtered = dataset_filtered.astype(str)
zarr_root[dset] = dataset_filtered
else:
if verbose:
print(f"\t\tDry run: {dset} - shape: {dataset_filtered.shape}")
if not dry_run:
zarr.consolidate_metadata(zarr_root.store)


def restore_noise_levels_ibl(datasets, one=None, dry_run=False, verbose=True):
"""
This function will restore noise levels for IBL datasets.
"""
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.generation as sgen

for dataset in datasets:
if verbose:
print(f"Processing dataset: {dataset}")
pid = dataset.split("_")[-1][:-5]
dataset_path = f"s3://spikeinterface-template-database/{dataset}"
recording = se.read_ibl_recording(pid=pid, load_sync_channel=False, stream_type="ap", one=one)

default_params = si.get_default_analyzer_extension_params("noise_levels")
if verbose:
print(f"\tComputing noise levels")
noise_levels = si.get_noise_levels(recording, return_scaled=True, **default_params)

if dry_run:
mode = "r"
else:
mode = "r+"
zarr_root = zarr.open(dataset_path, mode=mode)
if not dry_run:
if verbose:
print(f"\tRestoring noise levels")
zarr_root["channel_noise_levels"] = noise_levels
zarr.consolidate_metadata(zarr_root.store)
else:
if verbose:
print(f"\tCurrent shape: {zarr_root['channel_noise_levels'].shape}")
print(f"\tDry run: would restore noise levels: {noise_levels.shape}")


def delete_templates_with_num_samples(dry_run=False):
"""
This function will delete templates with an incorrect number of samples,
which was not corrected for in the initial database.
"""
bucket = "spikeinterface-template-database"
boto_client = boto3.client("s3")
verbose = True
verbose = True

templates_to_erase_from_bucket = [
"000409_sub-KS084_ses-1b715600-0cbc-442c-bd00-5b0ac2865de1_behavior+ecephys+image_bbe6ebc1-d32f-42dd-a89c-211226737deb.zarr",
"000409_sub-KS086_ses-e45481fa-be22-4365-972c-e7404ed8ab5a_behavior+ecephys+image_f2a098e7-a67e-4125-92d8-36fc6b606c45.zarr",
@@ -50,7 +160,13 @@ def delete_templates_from_s3(
"000409_sub-KS096_ses-f819d499-8bf7-4da0-a431-15377a8319d5_behavior+ecephys+image_4ea45238-55b1-4d54-ba92-efa47feb9f57.zarr",
]
existing_templates = list_zarr_directories(bucket, boto_client=boto_client)
templates_to_erase_from_bucket = [template for template in templates_to_erase_from_bucket if template in existing_templates]
if verbose:
print(f"Erasing {len(templates_to_erase_from_bucket)} templates from bucket: {bucket}")
delete_templates_from_s3(bucket, templates_to_erase_from_bucket, boto_client=boto_client)
templates_to_erase_from_bucket = [
template for template in templates_to_erase_from_bucket if template in existing_templates
]
if dry_run:
if verbose:
print(f"Would erase {len(templates_to_erase_from_bucket)} templates from bucket: {bucket}")
else:
if verbose:
print(f"Erasing {len(templates_to_erase_from_bucket)} templates from bucket: {bucket}")
delete_templates_from_s3(bucket, templates_to_erase_from_bucket, boto_client=boto_client)
37 changes: 30 additions & 7 deletions python/upload_templates.py → python/upload_ibl_templates.py
@@ -1,6 +1,26 @@
"""
This script constructs and uploads the templates from the International Brain Laboratory (IBL) datasets
available from DANDI (https://dandiarchive.org/dandiset/000409?search=IBL&pos=3).
Templates are extracted by combining the raw data from the NWB files on DANDI with the spike trains from
the Alyx ONE database. Only the units that passed the IBL quality control are used.
To minimize the amount of drift in the templates, only the last 30 minutes of the recording are used.
The raw recordings are pre-processed with a high-pass filter and a common median reference prior to
template extraction. Units with fewer than 50 spikes are excluded from the template database.
Once the templates are constructed, they are saved to a Zarr file which is then uploaded to the
"spikeinterface-template-database" bucket (hosted by CatalystNeuro).
"""

from pathlib import Path

import numpy as np
import s3fs
import zarr
import time
import os
import numcodecs

from dandi.dandiapi import DandiAPIClient

from spikeinterface.extractors import (
@@ -16,13 +36,7 @@
highpass_filter,
)

import s3fs
import zarr
import numcodecs

from one.api import ONE
import time
import os

from consolidate_datasets import list_zarr_directories

@@ -55,6 +69,7 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):

# Parameters
minutes_by_the_end = 30 # How many minutes in the end of the recording to use for templates
min_spikes_per_unit = 50
upload_data = True
overwite = False
verbose = True
@@ -174,6 +189,10 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):

sorting_end = sorting.frame_slice(start_frame=start_frame_sorting, end_frame=end_frame_sorting)

spikes_per_unit = sorting_end.count_num_spikes_per_unit(outputs="array")
unit_indices_to_keep = np.where(spikes_per_unit >= min_spikes_per_unit)[0]
sorting_end = sorting_end.select_units(sorting_end.unit_ids[unit_indices_to_keep])

# NWB streaming is not working well with parallel pre-processing, so we save a local copy
folder_path = Path.cwd() / "build" / "local_copy"
folder_path.mkdir(exist_ok=True, parents=True)
@@ -263,6 +282,8 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
expected_shape = (number_of_units, number_of_temporal_samples, number_of_channels)
assert templates_extension_data.templates_array.shape == expected_shape

# TODO: skip templates with 0 amplitude!
# TODO: check for weird shapes
templates_extension = analyzer.get_extension("templates")
templates_object = templates_extension.get_data(outputs="Templates")
unit_ids = templates_object.unit_ids
@@ -278,7 +299,9 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):

if upload_data:
# Create a S3 file system object with explicit credentials
s3_kwargs = dict(anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs)
s3_kwargs = dict(
anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs
)
s3 = s3fs.S3FileSystem(**s3_kwargs)

# Specify the S3 bucket and path