diff --git a/python/consolidate_datasets.py b/python/consolidate_datasets.py index 0710efb..54a30fa 100644 --- a/python/consolidate_datasets.py +++ b/python/consolidate_datasets.py @@ -1,16 +1,18 @@ +from pathlib import Path +from argparse import ArgumentParser + import boto3 import pandas as pd import zarr import numpy as np -from argparse import ArgumentParser +from tqdm.auto import tqdm from spikeinterface.core import Templates parser = ArgumentParser(description="Consolidate datasets from spikeinterface template database") parser.add_argument("--dry-run", action="store_true", help="Dry run (no upload)") -parser.add_argument("--bucket", type=str, help="S3 bucket name", default="spikeinterface-template-database") - +parser.add_argument("--verbose", action="store_true", help="Print additional information during processing") def list_zarr_directories(bucket_name, boto_client=None) -> list[str]: @@ -30,17 +32,18 @@ def list_zarr_directories(bucket_name, boto_client=None) -> list[str]: found in the bucket. """ - boto_client = boto_client or boto3.client('s3') - zarr_directories = set() + boto_client = boto_client or boto3.client("s3") + zarr_directories = set() + + paginator = boto_client.get_paginator("list_objects_v2") + for page in paginator.paginate(Bucket=bucket_name, Delimiter="/"): + for prefix in page.get("CommonPrefixes", []): + key = prefix["Prefix"] + if key.endswith(".zarr/"): + zarr_directories.add(key.rstrip("/")) - paginator = boto_client.get_paginator('list_objects_v2') - for page in paginator.paginate(Bucket=bucket_name, Delimiter='/'): - for prefix in page.get('CommonPrefixes', []): - key = prefix['Prefix'] - if key.endswith('.zarr/'): - zarr_directories.add(key.rstrip('/')) + return list(zarr_directories) - return list(zarr_directories) def consolidate_datasets(dry_run: bool = False, verbose: bool = False): """Consolidates data from Zarr datasets within an S3 bucket. @@ -62,8 +65,8 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False): FileNotFoundError If no Zarr datasets are found in the specified bucket. """ - - bucket="spikeinterface-template-database" + + bucket = "spikeinterface-template-database" boto_client = boto3.client("s3") # Get list of Zarr directories, excluding test datasets @@ -78,9 +81,10 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False): # Initialize list to collect DataFrames for each dataset all_dataframes = [] - - for dataset in zarr_datasets: - print(f"Processing dataset {dataset}") + desc = "Processing Zarr datasets" + for dataset in tqdm(zarr_datasets, desc=desc, unit=" datasets processed", disable=not verbose): + if verbose: + print(f"Processing dataset: {dataset}") zarr_path = f"s3://{bucket}/{dataset}" zarr_group = zarr.open_consolidated(zarr_path, storage_options=dict(anon=True)) templates = Templates.from_zarr_group(zarr_group) @@ -122,19 +126,29 @@ def consolidate_datasets(dry_run: bool = False, verbose: bool = False): # Concatenate all DataFrames into a single DataFrame templates_df = pd.concat(all_dataframes, ignore_index=True) - templates_df.to_csv("templates.csv", index=False) + templates_file_name = "templates.csv" + local_template_folder = Path("./build/") + local_template_info_file_path = local_template_folder / templates_file_name + templates_df.to_csv(local_template_info_file_path, index=False) # Upload to S3 - if not dry_run: - boto_client.upload_file("templates.csv", bucket, "templates.csv") + if dry_run: + print("Dry run: skipping upload to S3") + else: + boto_client.upload_file( + Filename=local_template_info_file_path, + Bucket=bucket, + Key=templates_file_name, + ) if verbose: - print("Dry run, not uploading") print(templates_df) return templates_df + if __name__ == "__main__": params = parser.parse_args() - DRY_RUN = params.dry_run - templates_df = consolidate_datasets(dry_run=DRY_RUN) + dry_run = params.dry_run + verbose = params.verbose + templates_df = consolidate_datasets(dry_run=dry_run, verbose=verbose) diff --git a/python/delete_templates.py b/python/delete_templates.py new file mode 100644 index 0000000..78c6a07 --- /dev/null +++ b/python/delete_templates.py @@ -0,0 +1,56 @@ +import boto3 +from consolidate_datasets import list_zarr_directories + + +def delete_template_from_s3(bucket_name: str, template_key: str, boto_client: boto3.client = None) -> None: + """Deletes a Zarr template (and its contents) from S3.""" + + boto_client = boto_client or boto3.client("s3") + + # Delete all objects within the template directory (including nested directories) + boto_client.delete_objects( + Bucket=bucket_name, + Delete={ + "Objects": [ + {"Key": obj["Key"]} + for obj in boto_client.list_objects_v2(Bucket=bucket_name, Prefix=template_key).get("Contents", []) + ] + }, + ) + print(f"Deleted template: {template_key}") + + +def delete_templates_from_s3( + bucket_name: str, + template_keys: list[str], + boto_client: boto3.client = None, +) -> None: + """Deletes multiple Zarr templates from S3.""" + boto_client = boto_client or boto3.client("s3") + for key in template_keys: + delete_template_from_s3(bucket_name, key, boto_client=boto_client) + + +if __name__ == "__main__": + bucket = "spikeinterface-template-database" + boto_client = boto3.client("s3") + verbose = True + + templates_to_erase_from_bucket = [ + "000409_sub-KS084_ses-1b715600-0cbc-442c-bd00-5b0ac2865de1_behavior+ecephys+image_bbe6ebc1-d32f-42dd-a89c-211226737deb.zarr", + "000409_sub-KS086_ses-e45481fa-be22-4365-972c-e7404ed8ab5a_behavior+ecephys+image_f2a098e7-a67e-4125-92d8-36fc6b606c45.zarr", + "000409_sub-KS091_ses-196a2adf-ff83-49b2-823a-33f990049c2e_behavior+ecephys+image_0259543e-1ca3-48e7-95c9-53f9e4c9bfcc.zarr", + "000409_sub-KS091_ses-78b4fff5-c5ec-44d9-b5f9-d59493063f00_behavior+ecephys+image_19c5b0d5-a255-47ff-9f8d-639e634a7b61.zarr", + "000409_sub-KS094_ses-6b0b5d24-bcda-4053-a59c-beaa1fe03b8f_behavior+ecephys+image_3282a590-8688-44fc-9811-cdf8b80d9a80.zarr", + "000409_sub-KS094_ses-752456f3-9f47-4fbf-bd44-9d131c0f41aa_behavior+ecephys+image_100433fa-2c59-4432-8295-aa27657fe3fb.zarr", + "000409_sub-KS094_ses-c8d46ee6-eb68-4535-8756-7c9aa32f10e4_behavior+ecephys+image_49a86b2e-3db4-42f2-8da8-7ebb7e482c70.zarr", + "000409_sub-KS096_ses-1b9e349e-93f2-41cc-a4b5-b212d7ddc8df_behavior+ecephys+image_1c4e2ebd-59da-4527-9700-b4b2dad6dfb2.zarr", + "000409_sub-KS096_ses-8928f98a-b411-497e-aa4b-aa752434686d_behavior+ecephys+image_d7ec0892-0a6c-4f4f-9d8f-72083692af5c.zarr", + "000409_sub-KS096_ses-a2701b93-d8e1-47e9-a819-f1063046f3e7_behavior+ecephys+image_f336f6a4-f693-4b88-b12c-c5cf0785b061.zarr", + "000409_sub-KS096_ses-f819d499-8bf7-4da0-a431-15377a8319d5_behavior+ecephys+image_4ea45238-55b1-4d54-ba92-efa47feb9f57.zarr", + ] + existing_templates = list_zarr_directories(bucket, boto_client=boto_client) + templates_to_erase_from_bucket = [template for template in templates_to_erase_from_bucket if template in existing_templates] + if verbose: + print(f"Erasing {len(templates_to_erase_from_bucket)} templates from bucket: {bucket}") + delete_templates_from_s3(bucket, templates_to_erase_from_bucket, boto_client=boto_client) diff --git a/python/upload_templates.py b/python/upload_templates.py index 82529be..66e2adf 100644 --- a/python/upload_templates.py +++ b/python/upload_templates.py @@ -207,9 +207,18 @@ def find_channels_with_max_peak_to_peak_vectorized(templates): "method": "all", } + # Correct for round mismatches in the number of temporal samples in conversion from seconds to samples + target_ms_before = 3.0 + target_ms_after = 5.0 + expected_fs = 30_000 + target_nbefore = int(target_ms_before / 1000 * expected_fs) + target_nafter = int(target_ms_after / 1000 * expected_fs) + ms_before_corrected = target_nbefore / recording.sampling_frequency * 1000 + ms_after_corrected = target_nafter / recording.sampling_frequency * 1000 + template_extension_parameters = { - "ms_before": 3.0, - "ms_after": 5.0, + "ms_before": ms_before_corrected, + "ms_after": ms_after_corrected, "operators": ["average"], } @@ -247,6 +256,13 @@ def find_channels_with_max_peak_to_peak_vectorized(templates): templates_extension = analyzer.get_extension("templates") templates_extension_data = templates_extension.get_data(outputs="Templates") + # Do a check for the expected shape of the templates + number_of_units = sorting.get_num_units() + number_of_temporal_samples = target_nbefore + target_nafter + number_of_channels = pre_processed_recording.get_number_of_channels() + expected_shape = (number_of_units, number_of_temporal_samples, number_of_channels) + assert templates_extension_data.templates_array.shape == expected_shape + templates_extension = analyzer.get_extension("templates") templates_object = templates_extension.get_data(outputs="Templates") unit_ids = templates_object.unit_ids