Commit
Merge pull request #9 from SpikeInterface/check_if_templates_are_uploaded_already

Avoid uploading the same template again
h-mayorquin authored Jun 18, 2024
2 parents 2ef1a93 + 555262e commit 34072dd
Showing 4 changed files with 63 additions and 28 deletions.
17 changes: 17 additions & 0 deletions pyproject.toml
@@ -0,0 +1,17 @@
[tool.black]
line-length = 128
target-version = ['py37', 'py38', 'py39']
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''
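
Note (not part of the commit): the [tool.black] section above only sets formatting options. As a quick illustration, the same line length and target versions can be exercised through black's Python API; when black runs as a CLI from the repository root it picks these values up from pyproject.toml automatically.

# Illustration only: the configuration above expressed via black's Python API.
import black

mode = black.Mode(
    line_length=128,
    target_versions={black.TargetVersion.PY37, black.TargetVersion.PY38, black.TargetVersion.PY39},
)
print(black.format_str("x=( 1,2 ,3)", mode=mode))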
Empty file added python/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion python/consolidate_datasets.py
@@ -13,7 +13,7 @@



def list_zarr_directories(bucket_name, boto_client=None):
def list_zarr_directories(bucket_name, boto_client=None) -> list[str]:
"""Lists top-level Zarr directory keys in an S3 bucket.
Parameters
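
Only the signature change of list_zarr_directories is visible in this hunk; a minimal sketch of what such a helper can look like with boto3 (an assumption about the body, not the committed implementation) is shown below.

# Minimal sketch, assuming the bucket is public and Zarr datasets live as top-level
# "<name>.zarr/" prefixes; not the committed implementation.
import boto3
from botocore import UNSIGNED
from botocore.config import Config


def list_zarr_directories_sketch(bucket_name, boto_client=None) -> list[str]:
    client = boto_client or boto3.client("s3", config=Config(signature_version=UNSIGNED))
    response = client.list_objects_v2(Bucket=bucket_name, Delimiter="/")
    prefixes = response.get("CommonPrefixes", [])
    return [p["Prefix"].rstrip("/") for p in prefixes if p["Prefix"].endswith(".zarr/")]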
72 changes: 45 additions & 27 deletions python/upload_templates.py
@@ -21,8 +21,10 @@
import numcodecs

from one.api import ONE
import os
import time
import os

from consolidate_datasets import list_zarr_directories


def find_channels_with_max_peak_to_peak_vectorized(templates):
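
The body of find_channels_with_max_peak_to_peak_vectorized is collapsed in this hunk; as an assumption about the intent suggested by its name, a vectorized peak-to-peak channel lookup typically reduces to the sketch below.

# Assumed template layout (num_units, num_samples, num_channels); this is a sketch of
# the idea, not the committed function body.
import numpy as np


def best_channel_per_unit(templates_array: np.ndarray) -> np.ndarray:
    peak_to_peak = np.ptp(templates_array, axis=1)  # (num_units, num_channels)
    return np.argmax(peak_to_peak, axis=1)  # channel index with the largest amplitude per unit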
@@ -49,11 +51,12 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
bucket_name = "spikeinterface-template-database"
client_kwargs={"region_name": "us-east-2"}
client_kwargs = {"region_name": "us-east-2"}

# Parameters
minutes_by_the_end = 30  # How many minutes at the end of the recording to use for templates
upload_data = True
overwrite = False
verbose = True

# Test data
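
The DANDI and ONE client setup that produces dandiset, dandiset_paths and one_instance sits in a collapsed part of the script; a hedged sketch of that setup follows, with a placeholder dandiset ID and a stand-in do_testing_data flag (neither value comes from this commit).

# Hedged sketch of the collapsed setup; "PLACEHOLDER_DANDISET_ID" and do_testing_data
# are stand-ins, not values taken from this commit.
from dandi.dandiapi import DandiAPIClient
from one.api import ONE

do_testing_data = False
dandiset_id = "PLACEHOLDER_DANDISET_ID"

dandi_client = DandiAPIClient()
dandiset = dandi_client.get_dandiset(dandiset_id)
dandiset_paths = [asset.path for asset in dandiset.get_assets() if asset.path.endswith(".nwb")]

one_instance = ONE(base_url="https://openalyx.internationalbrainlab.org", silent=True)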
@@ -75,9 +78,18 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
if do_testing_data:
dandiset_paths = [test_path]

for asset_path in dandiset_paths[:]:
# Load already processed datasets
zarr_datasets = list_zarr_directories(bucket_name=bucket_name)
if verbose:
print(f"Found {len(zarr_datasets)} datasets already processed")

dandiset_paths = np.random.choice(dandiset_paths, size=len(dandiset_paths), replace=False)
for asset_path in dandiset_paths:
if verbose:
print("-------------------")
print("----------------------------------------------------------")
print("----------------------------------------------------------")
print("----------------------------------------------------------")

print(asset_path)

recording_asset = dandiset.get_asset_by_path(path=asset_path)
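
The lines that turn the DANDI asset into a streamable file_path and enumerate electrical_series_paths are collapsed; a sketch of that step is given below, assuming dandi's get_content_url helper and the fetch_available_electrical_series_paths classmethod of recent SpikeInterface (the second call is an assumption, not confirmed by this diff).

# Sketch of the collapsed step: resolve a streaming URL for the asset and enumerate
# the electrical series. The enumeration call is an assumption about the API used.
file_path = recording_asset.get_content_url(follow_redirects=True, strip_query=True)
electrical_series_paths = NwbRecordingExtractor.fetch_available_electrical_series_paths(
    file_path=file_path,
    stream_mode="remfile",
)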
@@ -89,14 +101,18 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
)
electrical_series_paths_ap = [path for path in electrical_series_paths if "Ap" in path.split("/")[-1]]
for electrical_series_path in electrical_series_paths_ap:
print(electrical_series_path)
print(f"{electrical_series_path=}")

recording = NwbRecordingExtractor(
file_path=file_path,
stream_mode="remfile",
electrical_series_path=electrical_series_path,
)


if verbose:
print("Recording")
print(recording)

session_id = recording._file["general"]["session_id"][()].decode()
eid = session_id.split("-chunking")[0] # eid : experiment id
pids, probes = one_instance.eid2pid(eid)
@@ -116,6 +132,17 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
sorting_pid = pids[0]
probe_number = "00"

if do_testing_data:
dataset_name = "test_templates.zarr"
else:
dandi_name = asset_path.split("/")[-1].split(".")[0]
dataset_name = f"{dandiset_id}_{dandi_name}_{sorting_pid}.zarr"

if dataset_name in zarr_datasets and not overwrite:
if verbose:
print(f"Dataset {dataset_name} already processed, skipping")
continue

sorting = IblSortingExtractor(
pid=sorting_pid,
one=one_instance,
@@ -146,21 +173,20 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
end_frame_sorting = num_samples

sorting_end = sorting.frame_slice(start_frame=start_frame_sorting, end_frame=end_frame_sorting)


# NWB streaming is not working well with parallel pre-processing, so we save a local copy first
folder_path = Path.cwd() / "local_copy"
folder_path = Path.cwd() / "build" / "local_copy"
folder_path.mkdir(exist_ok=True, parents=True)

if verbose:
print("Saving Recording")
print(recording)
start_time = time.time()

recording = recording.save_to_folder(
folder=folder_path,
overwrite=True,
n_jobs=6,
n_jobs=8,
chunk_memory="1Gi",
verbose=True,
progress_bar=True,
@@ -201,13 +227,15 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
if verbose:
print("Computing extensions")
start_time = time.time()

analyzer.compute_several_extensions(
extensions=extensions,
n_jobs=4,
n_jobs=8,
verbose=True,
progress_bar=True,
chunk_memory="500Mi",
chunk_memory="250Mi",
)

if verbose:
end_time = time.time()
execution_time = end_time - start_time
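
Both the analyzer object and the extensions mapping passed to compute_several_extensions above are defined in collapsed hunks; the sketch below shows what they plausibly look like with the current SpikeInterface API. All parameter values here are illustrative assumptions, not the committed configuration.

# Illustrative sketch only: build a SortingAnalyzer and an extensions mapping like the
# one passed to compute_several_extensions above.
from spikeinterface.core import create_sorting_analyzer

analyzer = create_sorting_analyzer(
    sorting=sorting_end,
    recording=recording,
    sparse=False,  # keep dense templates so every channel is available
)

extensions = {
    "random_spikes": {"method": "uniform", "max_spikes_per_unit": 500},
    "noise_levels": {},
    "templates": {"operators": ["average", "std"]},
}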
@@ -228,31 +256,21 @@ def find_channels_with_max_peak_to_peak_vectorized(templates):
templates_object.probe.manufacturer = probe_info[0]["manufacturer"]
templates_object.probe.serial_number = probe_info[0]["serial_number"]

if do_testing_data:
dataset_name = "test_templates.zarr"
else:
path = asset_path.split("/")[-1]
dataset_name = f"{dandiset_id}_{path}_{sorting_pid}.zarr"

if verbose:
print("Saving data to Zarr")
print(f"{dataset_name=}")

if upload_data:
# Create a S3 file system object with explicit credentials
s3_kwargs = dict(
anon=False,
key=aws_access_key_id,
secret=aws_secret_access_key,
client_kwargs=client_kwargs
)
s3_kwargs = dict(anon=False, key=aws_access_key_id, secret=aws_secret_access_key, client_kwargs=client_kwargs)
s3 = s3fs.S3FileSystem(**s3_kwargs)

# Specify the S3 bucket and path
s3_path = f"{bucket_name}/{dataset_name}"
store = s3fs.S3Map(root=s3_path, s3=s3)
else:
folder_path = Path.cwd() / f"{dataset_name}"
folder_path = Path.cwd() / "build" / f"{dataset_name}"
folder_path.mkdir(exist_ok=True, parents=True)
store = zarr.DirectoryStore(str(folder_path))

# Save results to Zarr
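
The actual write of templates_object is collapsed below the "Save results to Zarr" comment; a rough generic-zarr sketch of that final step follows. Dataset and attribute names are assumptions, and the same code works with either the S3Map or the local DirectoryStore created above.

# Rough sketch of the collapsed save step using plain zarr; names written here are
# assumptions, not necessarily what the committed code stores.
zarr_group = zarr.open_group(store=store, mode="w")
zarr_group.create_dataset(
    "templates_array",
    data=templates_object.templates_array,
    compressor=numcodecs.Blosc(cname="zstd", clevel=5),
)
zarr_group.attrs["sampling_frequency"] = float(templates_object.sampling_frequency)
zarr.consolidate_metadata(store)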
