diff --git a/README.md b/README.md
index e51a05c..1a9b1db 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,49 @@
-# hybrid_template_library
-Library of templates to create hybrid data for spike sorting benchmarks
+# SpikeInterface Hybrid Template Library
 
-[click here for access to the library](https://spikeinterface.github.io/hybrid_template_library/)
+This repo contains a set of tools to construct and interact with a library of hybrid templates for spike sorting benchmarks.
 
-## Testing locally
+The library is made of several datasets, stored as `zarr` files, and can be accessed through the `spikeinterface` library.
+The library is also accessible through a web-app that allows users to browse the templates and download them for
+use in their spike sorting benchmarks.
+
+
+## Template sources
+
+The following datasets are available in the library:
+
+- [IBL](https://dandiarchive.org/dandiset/000409?search=IBL&pos=3): Neuropixels 1.0 templates from the IBL Brain Wide Map dataset
+- [Steinmetz and Ye, 2022](https://doi.org/10.6084/m9.figshare.19493588.v2): Neuropixels Ultra templates from Steinmetz and Ye, 2022
+
+The templates have been processed and stored with the `python` scripts in the `python/scripts` folder and are stored in `zarr`
+format in the `s3://spikeinterface-template-database` bucket hosted on `AWS S3` by [CatalystNeuro](https://www.catalystneuro.com/).
+
+
+## Accessing the data through `SpikeInterface`
+
+The library can be accessed through the `spikeinterface` library using the `generation` module.
+The following code shows how to access the library to fetch a dataframe with the available templates
+and download the templates corresponding to a specific user query:
+
+```python
+import spikeinterface.generation as sgen
+
+templates_info = sgen.fetch_templates_database_info()
+
+# select templates with amplitude between 200 and 250 uV
+templates_info_selected = templates_info.query('amplitude_uv > 200 and amplitude_uv < 250')
+templates_selected = sgen.query_templates_from_database(templates_info_selected)
+```
+
+For a more comprehensive example of how to construct hybrid recordings from the template library and run spike sorting
+benchmarks, please refer to the SpikeInterface tutorial on [Hybrid recordings](https://spikeinterface.readthedocs.io/en/latest/how_to/benchmark_with_hybrid_recordings.html).
+
+## Live Web-App
+
+The template library can be browsed through a web-app (source code included in this repo).
+The web-app is hosted on GitHub Pages and can be accessed through the following link: [https://spikeinterface.github.io/hybrid_template_library/](https://spikeinterface.github.io/hybrid_template_library/)
+
+
+### Testing locally
 
 How to run a python server for testing zarr access
 
@@ -38,10 +77,10 @@ python -c "from http.server import HTTPServer, SimpleHTTPRequestHandler; import
 ```
 
-Then you run the npm script to start the server and open the browser
+Then you run the `npm` script to start the server and open the browser
 
 ```bash
-export TEST_URL="http://localhost:8000/zarr_store.zarr"
+export TEST_URL="http://localhost:8000/test_zarr.zarr"
 npm run start
 ```
diff --git a/pyproject.toml b/pyproject.toml
index f127e31..4194441 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ classifiers = [
 dependencies = [
     "spikeinterface >= 0.101",
+    "MEArec",
     "tqdm",
     "pynwb>=2.8",
     "remfile==0.1",
diff --git a/python/consolidate_datasets.py b/python/consolidate_datasets.py
index 54a30fa..098bfb1 100644
--- a/python/consolidate_datasets.py
+++ b/python/consolidate_datasets.py
@@ -101,8 +101,12 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False):
         depth_best_channel = channel_depths[best_channel_indices]
         peak_to_peak_best_channel = zarr_group["peak_to_peak"][template_indices, best_channel_indices]
-        noise_best_channel = zarr_group["channel_noise_levels"][best_channel_indices]
-        signal_to_noise_ratio_best_channel = peak_to_peak_best_channel / noise_best_channel
+        if "channel_noise_levels" not in zarr_group:
+            noise_best_channel = np.nan * np.zeros(num_units)
+            signal_to_noise_ratio_best_channel = np.nan * np.zeros(num_units)
+        else:
+            noise_best_channel = zarr_group["channel_noise_levels"][best_channel_indices]
+            signal_to_noise_ratio_best_channel = peak_to_peak_best_channel / noise_best_channel
 
         new_entry = pd.DataFrame(
             {
diff --git a/python/delete_templates.py b/python/delete_templates.py
index 78c6a07..f6890dd 100644
--- a/python/delete_templates.py
+++ b/python/delete_templates.py
@@ -1,4 +1,7 @@
 import boto3
+import numpy as np
+import zarr
+
 from consolidate_datasets import list_zarr_directories
 
 
@@ -31,11 +34,118 @@ def delete_templates_from_s3(
         delete_template_from_s3(bucket_name, key, boto_client=boto_client)
 
 
-if __name__ == "__main__":
+def delete_templates_too_few_spikes(min_spikes=50, dry_run=False, verbose=True):
+    """
+    Delete templates associated with spike trains that have too few spikes.
+
+    The initial database was created without a minimum number of spikes per unit,
+    so some units have very few spikes and possibly a noisy template.
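+
+    Parameters
+    ----------
+    min_spikes : int, default: 50
+        Minimum number of spikes per unit; templates of units with fewer spikes are removed.
+    dry_run : bool, default: False
+        If True, open the Zarr datasets read-only and only report what would be removed.
+    verbose : bool, default: True
+        If True, print progress information.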
+    """
+    import spikeinterface.generation as sgen
+
+    templates_info = sgen.fetch_templates_database_info()
+    templates_to_remove = templates_info.query(f"spikes_per_unit < {min_spikes}")
+
+    if len(templates_to_remove) > 0:
+        if verbose:
+            print(
+                f"Removing {len(templates_to_remove)}/{len(templates_info)} templates with less than {min_spikes} spikes"
+            )
+        datasets = np.unique(templates_to_remove["dataset"])
+
+        for d_i, dataset in enumerate(datasets):
+            if verbose:
+                print(f"\tCleaning dataset {d_i + 1}/{len(datasets)}")
+            templates_in_dataset = templates_to_remove.query(f"dataset == '{dataset}'")
+            template_indices_to_remove = templates_in_dataset.template_index.values
+            s3_path = templates_in_dataset.dataset_path.values[0]
+
+            # datasets to filter from the zarr group:
+            datasets_to_filter = [
+                "templates_array",
+                "best_channel_index",
+                "spikes_per_unit",
+                "brain_area",
+                "peak_to_peak",
+                "unit_ids",
+            ]
+
+            # open zarr read-only for a dry run, otherwise in read/write mode
+            if dry_run:
+                mode = "r"
+            else:
+                mode = "r+"
+            zarr_root = zarr.open(s3_path, mode=mode)
+            all_unit_indices = np.arange(len(zarr_root["unit_ids"]))
+            n_original_units = len(all_unit_indices)
+            unit_indices_to_keep = np.delete(all_unit_indices, template_indices_to_remove)
+            n_units_to_keep = len(unit_indices_to_keep)
+            if verbose:
+                print(f"\tRemoving {n_original_units - n_units_to_keep} templates from {n_original_units}")
+            for dset in datasets_to_filter:
+                dataset_original = zarr_root[dset]
+                if len(dataset_original) == n_units_to_keep:
+                    print(f"\t\tDataset: {dset} - shape: {dataset_original.shape} - already updated")
+                    continue
+                dataset_filtered = dataset_original[unit_indices_to_keep]
+                if not dry_run:
+                    if verbose:
+                        print(f"\t\tUpdating: {dset} - shape: {dataset_filtered.shape}")
+                    if dataset_filtered.dtype.kind == "O":
+                        dataset_filtered = dataset_filtered.astype(str)
+                    zarr_root[dset] = dataset_filtered
+                else:
+                    if verbose:
+                        print(f"\t\tDry run: {dset} - shape: {dataset_filtered.shape}")
+            if not dry_run:
+                zarr.consolidate_metadata(zarr_root.store)
+
+
+def restore_noise_levels_ibl(datasets, one=None, dry_run=False, verbose=True):
+    """
+    Restore the channel noise levels for IBL datasets.
+    """
+    import spikeinterface as si
+    import spikeinterface.extractors as se
+    import spikeinterface.generation as sgen
+
+    for dataset in datasets:
+        if verbose:
+            print(f"Processing dataset: {dataset}")
+        pid = dataset.split("_")[-1][:-5]
+        dataset_path = f"s3://spikeinterface-template-database/{dataset}"
+        recording = se.read_ibl_recording(pid=pid, load_sync_channel=False, stream_type="ap", one=one)
+
+        default_params = si.get_default_analyzer_extension_params("noise_levels")
+        if verbose:
+            print(f"\tComputing noise levels")
+        noise_levels = si.get_noise_levels(recording, return_scaled=True, **default_params)
+
+        if dry_run:
+            mode = "r"
+        else:
+            mode = "r+"
+        zarr_root = zarr.open(dataset_path, mode=mode)
+        if not dry_run:
+            if verbose:
+                print(f"\tRestoring noise levels")
+            zarr_root["channel_noise_levels"] = noise_levels
+            zarr.consolidate_metadata(zarr_root.store)
+        else:
+            if verbose:
+                print(f"\tCurrent shape: {zarr_root['channel_noise_levels'].shape}")
+                print(f"\tDry run: would restore noise levels: {noise_levels.shape}")
+
+
+def delete_templates_with_num_samples(dry_run=False):
+    """
+    Delete template datasets whose number of samples was not corrected for
+    in the initial database.
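+
+    Parameters
+    ----------
+    dry_run : bool, default: False
+        If True, only report which datasets would be erased from the bucket.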
+    """
     bucket = "spikeinterface-template-database"
     boto_client = boto3.client("s3")
-    verbose = True
-
+    verbose = True
+
     templates_to_erase_from_bucket = [
         "000409_sub-KS084_ses-1b715600-0cbc-442c-bd00-5b0ac2865de1_behavior+ecephys+image_bbe6ebc1-d32f-42dd-a89c-211226737deb.zarr",
         "000409_sub-KS086_ses-e45481fa-be22-4365-972c-e7404ed8ab5a_behavior+ecephys+image_f2a098e7-a67e-4125-92d8-36fc6b606c45.zarr",
@@ -50,7 +160,13 @@ def delete_templates_from_s3(
         "000409_sub-KS096_ses-f819d499-8bf7-4da0-a431-15377a8319d5_behavior+ecephys+image_4ea45238-55b1-4d54-ba92-efa47feb9f57.zarr",
     ]
     existing_templates = list_zarr_directories(bucket, boto_client=boto_client)
-    templates_to_erase_from_bucket = [template for template in templates_to_erase_from_bucket if template in existing_templates]
-    if verbose:
-        print(f"Erasing {len(templates_to_erase_from_bucket)} templates from bucket: {bucket}")
-    delete_templates_from_s3(bucket, templates_to_erase_from_bucket, boto_client=boto_client)
+    templates_to_erase_from_bucket = [
+        template for template in templates_to_erase_from_bucket if template in existing_templates
+    ]
+    if dry_run:
+        if verbose:
+            print(f"Would erase {len(templates_to_erase_from_bucket)} templates from bucket: {bucket}")
+    else:
+        if verbose:
+            print(f"Erasing {len(templates_to_erase_from_bucket)} templates from bucket: {bucket}")
+        delete_templates_from_s3(bucket, templates_to_erase_from_bucket, boto_client=boto_client)
diff --git a/python/upload_templates.py b/python/upload_ibl_templates.py
similarity index 89%
rename from python/upload_templates.py
rename to python/upload_ibl_templates.py
index 5f4f5d7..a3477a4 100644
--- a/python/upload_templates.py
+++ b/python/upload_ibl_templates.py
@@ -1,6 +1,26 @@
+"""
+This script constructs and uploads the templates from the International Brain Laboratory (IBL) datasets
+available from DANDI (https://dandiarchive.org/dandiset/000409?search=IBL&pos=3).
+
+Templates are extracted by combining the raw data from the NWB files on DANDI with the spike trains from
+the Alyx ONE database. Only the units that passed the IBL quality control are used.
+To minimize the amount of drift in the templates, only the last 30 minutes of the recording are used.
+The raw recordings are pre-processed with a high-pass filter and a common median reference prior to
+template extraction. Units with fewer than 50 spikes are excluded from the template database.
+
+Once the templates are constructed, they are saved to a Zarr file, which is then uploaded to the
+"spikeinterface-template-database" bucket (hosted by CatalystNeuro).
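+
+Once uploaded, the templates can be retrieved with `spikeinterface.generation.fetch_templates_database_info()`
+and `query_templates_from_database()` (see the README for an example).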
+""" + from pathlib import Path import numpy as np +import s3fs +import zarr +import time +import os +import numcodecs + from dandi.dandiapi import DandiAPIClient from spikeinterface.extractors import ( @@ -16,13 +36,7 @@ highpass_filter, ) -import s3fs -import zarr -import numcodecs - from one.api import ONE -import time -import os from consolidate_datasets import list_zarr_directories @@ -55,6 +69,7 @@ def find_channels_with_max_peak_to_peak_vectorized(templates): # Parameters minutes_by_the_end = 30 # How many minutes in the end of the recording to use for templates +min_spikes_per_unit = 50 upload_data = True overwite = False verbose = True @@ -174,6 +189,10 @@ def find_channels_with_max_peak_to_peak_vectorized(templates): sorting_end = sorting.frame_slice(start_frame=start_frame_sorting, end_frame=end_frame_sorting) + spikes_per_unit = sorting_end.count_num_spikes_per_unit(outputs="array") + unit_indices_to_keep = np.where(spikes_per_unit >= min_spikes_per_unit)[0] + sorting_end = sorting_end.select_units(sorting_end.unit_ids[unit_indices_to_keep]) + # NWB Streaming is not working well with parallel pre=processing so we ave folder_path = Path.cwd() / "build" / "local_copy" folder_path.mkdir(exist_ok=True, parents=True) @@ -263,6 +282,8 @@ def find_channels_with_max_peak_to_peak_vectorized(templates): expected_shape = (number_of_units, number_of_temporal_samples, number_of_channels) assert templates_extension_data.templates_array.shape == expected_shape + # TODO: skip templates with 0 amplitude! + # TODO: check for weird shapes templates_extension = analyzer.get_extension("templates") templates_object = templates_extension.get_data(outputs="Templates") unit_ids = templates_object.unit_ids @@ -278,7 +299,9 @@ def find_channels_with_max_peak_to_peak_vectorized(templates): if upload_data: # Create a S3 file system object with explicit credentials - s3_kwargs = dict(anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs) + s3_kwargs = dict( + anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs + ) s3 = s3fs.S3FileSystem(**s3_kwargs) # Specify the S3 bucket and path diff --git a/python/upload_npultra_templates.py b/python/upload_npultra_templates.py new file mode 100644 index 0000000..e642e09 --- /dev/null +++ b/python/upload_npultra_templates.py @@ -0,0 +1,158 @@ +""" +This script constructs and uploads the templates from the Neuropixels Ultra dataset +form Steinmetz and Ye, 2022. The dataset is hosted on Figshare at https://doi.org/10.6084/m9.figshare.19493588.v2 + +Since the templates in the dataset have rather short cut outs, which might negatively interfere with +hybrid spike injections, the templates are padded and smoothed using the `MEArec` package so that they +end up having 240 samples (90 before, 150 after the peak). + +Once the templates are constructed they are saved to a Zarr file which is then uploaded to +"spikeinterface-template-database" bucket (hosted by CatalystNeuro). 
+""" + +from pathlib import Path + +import numpy as np +import s3fs +import zarr +import pandas as pd +import os +import numcodecs +import probeinterface as pi +import spikeinterface as si + +from MEArec.tools import pad_templates, sigmoid + + +def smooth_edges(templates, pad_samples, smooth_percent=0.5, smooth_strength=1): + # smooth edges + sigmoid_samples = int(smooth_percent * pad_samples[0]) // 2 * 2 + sigmoid_x = np.arange(-sigmoid_samples // 2, sigmoid_samples // 2) + b = smooth_strength + sig = sigmoid(sigmoid_x, b) + 0.5 + window = np.ones(templates.shape[-1]) + window[:sigmoid_samples] = sig + window[-sigmoid_samples:] = sig[::-1] + + templates *= window + return templates + + +# parameters +min_spikes_per_unit = 50 +target_nbefore = 90 +target_nafter = 150 +upload_data = False + +npultra_templates_path = Path("/home/alessio/Documents/Data/Templates/NPUltraWaveforms/") +dataset_name = "steinmetz_ye_np_ultra_2022_figshare19493588v2.zarr" + +# AWS credentials +aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID") +aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY") +bucket_name = "spikeinterface-template-database" +client_kwargs = {"region_name": "us-east-2"} + +# Load the templates and the required metadata +xpos = np.load(npultra_templates_path / "channels.xcoords.npy") +ypos = np.load(npultra_templates_path / "channels.ycoords.npy") + +channel_locations = np.squeeze([xpos, ypos]).T + +templates_array = np.load(npultra_templates_path / "clusters.waveforms.npy") +spike_clusters = np.load(npultra_templates_path / "spikes.clusters.npy") + +brain_area = pd.read_csv(npultra_templates_path / "clusters.acronym.tsv", sep="\t") +brain_area_acronym = brain_area["acronym"].values + +# Instantiate Probe +probe = pi.Probe(ndim=2) +probe.set_contacts(positions=channel_locations, shapes="square", shape_params={"width": 5}) +probe.model_name = "Neuropixels Ultra" +probe.manufacturer = "IMEC" + +# Unit ids and properties +unit_ids, spikes_per_unit = np.unique(spike_clusters, return_counts=True) +unit_ids_enough_spikes = spikes_per_unit >= min_spikes_per_unit +unit_ids = unit_ids[unit_ids_enough_spikes] +spikes_per_unit = spikes_per_unit[unit_ids_enough_spikes] + +# Sort the units by unit_id +sort_unit_indices = np.argsort(unit_ids) +unit_ids = unit_ids[sort_unit_indices].astype(int) +spikes_per_unit = spikes_per_unit[sort_unit_indices] +brain_area_acronym = brain_area_acronym[sort_unit_indices] + +# Process the templates to make them smooth +nbefore = 40 +sampling_frequency = 30000 +num_samples = templates_array.shape[1] +nafter = num_samples - nbefore + +pad_samples = [target_nbefore - nbefore, target_nafter - nafter] + +# MEArec needs swap axes +print("Padding and smoothing templates") +templates_array_swap = templates_array.swapaxes(1, 2) +tmp_templates_file = "templates_padded.raw" +templates_padded_swap = pad_templates( + templates_array_swap, + pad_samples, + drifting=False, + dtype="float", + verbose=False, + n_jobs=-1, + tmp_file=tmp_templates_file, + parallel=True, +) +templates_padded = templates_padded_swap.swapaxes(1, 2) +Path(tmp_templates_file).unlink() + +# smooth edges +templates_smoothed_swap = smooth_edges(templates_padded_swap, pad_samples) +templates_smoothed = templates_smoothed_swap.swapaxes(1, 2) + +# Create Templates object +print("Creating Templates object") +templates_ultra = si.Templates( + templates_array=templates_smoothed, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + unit_ids=unit_ids, + probe=probe, + is_scaled=True, +) + 
+best_channel_index = si.get_template_extremum_channel(templates_ultra, mode="peak_to_peak", outputs="index")
+best_channel_index = list(best_channel_index.values())
+
+if upload_data:
+    # Create an S3 file system object with explicit credentials
+    s3_kwargs = dict(anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs)
+    s3 = s3fs.S3FileSystem(**s3_kwargs)
+
+    # Specify the S3 bucket and path
+    s3_path = f"{bucket_name}/{dataset_name}"
+    store = s3fs.S3Map(root=s3_path, s3=s3)
+else:
+    folder_path = Path.cwd() / "build" / f"{dataset_name}"
+    folder_path.mkdir(exist_ok=True, parents=True)
+    store = zarr.DirectoryStore(str(folder_path))
+
+# Save results to Zarr
+zarr_group = zarr.group(store=store, overwrite=True)
+zarr_group.create_dataset(name="brain_area", data=brain_area_acronym, object_codec=numcodecs.VLenUTF8())
+zarr_group.create_dataset(name="spikes_per_unit", data=spikes_per_unit, chunks=None, dtype="uint32")
+zarr_group.create_dataset(
+    name="best_channel_index",
+    data=best_channel_index,
+    chunks=None,
+    dtype="uint32",
+)
+peak_to_peak = np.ptp(templates_array, axis=1)
+zarr_group.create_dataset(name="peak_to_peak", data=peak_to_peak)
+
+# Add the templates themselves to the Zarr group and consolidate the metadata
+templates_ultra.add_templates_to_zarr_group(zarr_group=zarr_group)
+zarr.consolidate_metadata(zarr_group.store)
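+
+# Report where the dataset was written (S3 bucket or local build folder)
+print(f"Done. Templates written to {s3_path if upload_data else folder_path}")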