Commit 010fd76

Merge pull request #11 from SpikeInterface/add_template_check_consistency

Add templates and consistency checks

alejoe91 authored Aug 2, 2024
2 parents 7fd3a5f + a20dc49 commit 010fd76

Showing 3 changed files with 111 additions and 25 deletions.
60 changes: 37 additions & 23 deletions python/consolidate_datasets.py
@@ -1,16 +1,18 @@
from pathlib import Path
from argparse import ArgumentParser

import boto3
import pandas as pd
import zarr
import numpy as np
from tqdm.auto import tqdm

from spikeinterface.core import Templates

parser = ArgumentParser(description="Consolidate datasets from spikeinterface template database")

parser.add_argument("--dry-run", action="store_true", help="Dry run (no upload)")
parser.add_argument("--bucket", type=str, help="S3 bucket name", default="spikeinterface-template-database")
parser.add_argument("--verbose", action="store_true", help="Print additional information during processing")

def list_zarr_directories(bucket_name, boto_client=None) -> list[str]:
@@ -30,17 +32,18 @@ def list_zarr_directories(bucket_name, boto_client=None) -> list[str]:
        found in the bucket.
    """

    boto_client = boto_client or boto3.client("s3")
    zarr_directories = set()

    paginator = boto_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket_name, Delimiter="/"):
        for prefix in page.get("CommonPrefixes", []):
            key = prefix["Prefix"]
            if key.endswith(".zarr/"):
                zarr_directories.add(key.rstrip("/"))

    return list(zarr_directories)

def consolidate_datasets(dry_run: bool = False, verbose: bool = False):
    """Consolidates data from Zarr datasets within an S3 bucket.
@@ -62,8 +65,8 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False):
    FileNotFoundError
        If no Zarr datasets are found in the specified bucket.
    """
    bucket = "spikeinterface-template-database"
    boto_client = boto3.client("s3")
    # Get list of Zarr directories, excluding test datasets
@@ -78,9 +81,10 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False):

    # Initialize list to collect DataFrames for each dataset
    all_dataframes = []

    desc = "Processing Zarr datasets"
    for dataset in tqdm(zarr_datasets, desc=desc, unit=" datasets processed", disable=not verbose):
        if verbose:
            print(f"Processing dataset: {dataset}")
        zarr_path = f"s3://{bucket}/{dataset}"
        zarr_group = zarr.open_consolidated(zarr_path, storage_options=dict(anon=True))
        templates = Templates.from_zarr_group(zarr_group)
@@ -122,19 +126,29 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False):
    # Concatenate all DataFrames into a single DataFrame
    templates_df = pd.concat(all_dataframes, ignore_index=True)

    templates_file_name = "templates.csv"
    local_template_folder = Path("./build/")
    local_template_folder.mkdir(parents=True, exist_ok=True)  # ensure the output folder exists before writing
    local_template_info_file_path = local_template_folder / templates_file_name
    templates_df.to_csv(local_template_info_file_path, index=False)

    # Upload to S3
    if dry_run:
        print("Dry run: skipping upload to S3")
    else:
        boto_client.upload_file(
            Filename=str(local_template_info_file_path),  # boto3 expects a string path
            Bucket=bucket,
            Key=templates_file_name,
        )

    if verbose:
        print(templates_df)

    return templates_df


if __name__ == "__main__":
    params = parser.parse_args()
    dry_run = params.dry_run
    verbose = params.verbose
    templates_df = consolidate_datasets(dry_run=dry_run, verbose=verbose)
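
For reference, the consolidation can be exercised end to end without modifying the bucket by taking the dry-run path. A minimal sketch, assuming boto3 can reach the bucket (e.g. via default AWS credentials):

    from consolidate_datasets import consolidate_datasets, list_zarr_directories

    datasets = list_zarr_directories("spikeinterface-template-database")
    print(f"Found {len(datasets)} Zarr datasets")

    # dry_run=True still writes ./build/templates.csv but skips the S3 upload
    templates_df = consolidate_datasets(dry_run=True, verbose=True)
    print(templates_df.head())
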
56 changes: 56 additions & 0 deletions python/delete_templates.py
@@ -0,0 +1,56 @@
import boto3
from consolidate_datasets import list_zarr_directories


def delete_template_from_s3(bucket_name: str, template_key: str, boto_client: boto3.client = None) -> None:
    """Deletes a Zarr template (and its contents) from S3."""
    boto_client = boto_client or boto3.client("s3")

    # Delete all objects within the template directory (including nested directories)
    boto_client.delete_objects(
        Bucket=bucket_name,
        Delete={
            "Objects": [
                {"Key": obj["Key"]}
                for obj in boto_client.list_objects_v2(Bucket=bucket_name, Prefix=template_key).get("Contents", [])
            ]
        },
    )
    print(f"Deleted template: {template_key}")


def delete_templates_from_s3(
    bucket_name: str,
    template_keys: list[str],
    boto_client: boto3.client = None,
) -> None:
    """Deletes multiple Zarr templates from S3."""
    boto_client = boto_client or boto3.client("s3")
    for key in template_keys:
        delete_template_from_s3(bucket_name, key, boto_client=boto_client)


if __name__ == "__main__":
    bucket = "spikeinterface-template-database"
    boto_client = boto3.client("s3")
    verbose = True

    templates_to_erase_from_bucket = [
        "000409_sub-KS084_ses-1b715600-0cbc-442c-bd00-5b0ac2865de1_behavior+ecephys+image_bbe6ebc1-d32f-42dd-a89c-211226737deb.zarr",
        "000409_sub-KS086_ses-e45481fa-be22-4365-972c-e7404ed8ab5a_behavior+ecephys+image_f2a098e7-a67e-4125-92d8-36fc6b606c45.zarr",
        "000409_sub-KS091_ses-196a2adf-ff83-49b2-823a-33f990049c2e_behavior+ecephys+image_0259543e-1ca3-48e7-95c9-53f9e4c9bfcc.zarr",
        "000409_sub-KS091_ses-78b4fff5-c5ec-44d9-b5f9-d59493063f00_behavior+ecephys+image_19c5b0d5-a255-47ff-9f8d-639e634a7b61.zarr",
        "000409_sub-KS094_ses-6b0b5d24-bcda-4053-a59c-beaa1fe03b8f_behavior+ecephys+image_3282a590-8688-44fc-9811-cdf8b80d9a80.zarr",
        "000409_sub-KS094_ses-752456f3-9f47-4fbf-bd44-9d131c0f41aa_behavior+ecephys+image_100433fa-2c59-4432-8295-aa27657fe3fb.zarr",
        "000409_sub-KS094_ses-c8d46ee6-eb68-4535-8756-7c9aa32f10e4_behavior+ecephys+image_49a86b2e-3db4-42f2-8da8-7ebb7e482c70.zarr",
        "000409_sub-KS096_ses-1b9e349e-93f2-41cc-a4b5-b212d7ddc8df_behavior+ecephys+image_1c4e2ebd-59da-4527-9700-b4b2dad6dfb2.zarr",
        "000409_sub-KS096_ses-8928f98a-b411-497e-aa4b-aa752434686d_behavior+ecephys+image_d7ec0892-0a6c-4f4f-9d8f-72083692af5c.zarr",
        "000409_sub-KS096_ses-a2701b93-d8e1-47e9-a819-f1063046f3e7_behavior+ecephys+image_f336f6a4-f693-4b88-b12c-c5cf0785b061.zarr",
        "000409_sub-KS096_ses-f819d499-8bf7-4da0-a431-15377a8319d5_behavior+ecephys+image_4ea45238-55b1-4d54-ba92-efa47feb9f57.zarr",
    ]
    existing_templates = list_zarr_directories(bucket, boto_client=boto_client)
    templates_to_erase_from_bucket = [
        template for template in templates_to_erase_from_bucket if template in existing_templates
    ]
    if verbose:
        print(f"Erasing {len(templates_to_erase_from_bucket)} templates from bucket: {bucket}")
    delete_templates_from_s3(bucket, templates_to_erase_from_bucket, boto_client=boto_client)
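
A caveat on the deletion helper: list_objects_v2 returns at most 1,000 keys per call, and a Zarr template directory can contain more objects than that, so a single listing could leave orphaned chunk files behind. A paginated variant is sketched below under that assumption; the function name is hypothetical, and delete_objects itself accepts at most 1,000 keys per request, which lines up with the page size:

    def delete_template_from_s3_paginated(bucket_name: str, template_key: str, boto_client=None) -> None:
        """Delete every object under a template prefix, paginating past the 1,000-key page limit."""
        boto_client = boto_client or boto3.client("s3")
        paginator = boto_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket_name, Prefix=template_key):
            contents = page.get("Contents", [])
            if not contents:
                continue
            # delete_objects takes at most 1,000 keys per request, matching the page size
            boto_client.delete_objects(
                Bucket=bucket_name,
                Delete={"Objects": [{"Key": obj["Key"]} for obj in contents]},
            )
        print(f"Deleted template: {template_key}")
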
20 changes: 18 additions & 2 deletions python/upload_templates.py
@@ -207,9 +207,18 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
"method": "all",
}

# Correct for rounding mismatches in the number of temporal samples when converting from seconds to samples
target_ms_before = 3.0
target_ms_after = 5.0
expected_fs = 30_000
target_nbefore = int(target_ms_before / 1000 * expected_fs)
target_nafter = int(target_ms_after / 1000 * expected_fs)
ms_before_corrected = target_nbefore / recording.sampling_frequency * 1000
ms_after_corrected = target_nafter / recording.sampling_frequency * 1000

template_extension_parameters = {
    "ms_before": ms_before_corrected,
    "ms_after": ms_after_corrected,
    "operators": ["average"],
}
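
# Worked example of the correction above (hypothetical recording at 32 kHz,
# assuming the sample counts are derived as int(ms / 1000 * fs)):
#   target_nbefore = int(3.0 / 1000 * 30_000) = 90 samples
#   target_nafter  = int(5.0 / 1000 * 30_000) = 150 samples
#   ms_before_corrected = 90 / 32_000 * 1000 = 2.8125 ms
#   ms_after_corrected  = 150 / 32_000 * 1000 = 4.6875 ms
# Requesting 2.8125 ms at 32 kHz gives back int(2.8125 / 1000 * 32_000) = 90
# samples, so every dataset ends up with 90 + 150 = 240 temporal samples
# regardless of its native sampling frequency.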

@@ -247,6 +256,13 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
templates_extension = analyzer.get_extension("templates")
templates_extension_data = templates_extension.get_data(outputs="Templates")

# Check that the templates array has the expected shape
number_of_units = sorting.get_num_units()
number_of_temporal_samples = target_nbefore + target_nafter
number_of_channels = pre_processed_recording.get_number_of_channels()
expected_shape = (number_of_units, number_of_temporal_samples, number_of_channels)
assert (
    templates_extension_data.templates_array.shape == expected_shape
), f"Unexpected templates shape {templates_extension_data.templates_array.shape}, expected {expected_shape}"

templates_extension = analyzer.get_extension("templates")
templates_object = templates_extension.get_data(outputs="Templates")
unit_ids = templates_object.unit_ids
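
Once the consolidated templates.csv index is in the bucket, downstream consumers can load it directly. A minimal sketch, assuming the s3fs package is installed and the bucket allows anonymous reads:

    import pandas as pd

    # Read the consolidated index straight from the public bucket
    templates_df = pd.read_csv(
        "s3://spikeinterface-template-database/templates.csv",
        storage_options={"anon": True},
    )
    print(templates_df.head())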