Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to SI new API #2

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/mountainsort4_example1.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/usr/bin/env python3

import mountainsort4 as ms4
import spikeextractors as se
import spikeinterface.extractors as se

def main():
recording, sorting_true = se.example_datasets.toy_example()
recording, sorting_true = se.toy_example(num_segments=1)
sorting = ms4.mountainsort4(
recording=recording,
detect_sign=-1,
Expand Down
2 changes: 1 addition & 1 deletion jinjaroot.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
projectName: mountainsort4
projectVersion: 1.0.0
projectVersion: 1.1.0
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bump up to version 2.0.0!

projectAuthor: Jeremy Magland
projectAuthorEmail: [email protected]
projectDescription: Spike sorting using MountainSort4 algorithm
Expand Down
3 changes: 2 additions & 1 deletion mountainsort4/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .mountainsort4 import mountainsort4
from .mountainsort4 import mountainsort4
from .version import __version__
24 changes: 10 additions & 14 deletions mountainsort4/mountainsort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import numpy as np
import math
import multiprocessing
import spikeextractors as se
import spikeinterface as si


def mountainsort4(*, recording: se.RecordingExtractor, detect_sign: int, clip_size: int=50, adjacency_radius: float=-1, detect_threshold: float=3, detect_interval: int=10,
num_workers: Union[None, int]=None, verbose: bool=True) -> se.SortingExtractor:
def mountainsort4(*, recording: si.BaseRecording, detect_sign: int, clip_size: int=50, adjacency_radius: float=-1, detect_threshold: float=3, detect_interval: int=10,
num_workers: Union[None, int]=None, verbose: bool=True) -> si.BaseSorting:
if num_workers is None:
num_workers = math.floor((multiprocessing.cpu_count()+1)/2)

Expand Down Expand Up @@ -45,18 +45,14 @@ def mountainsort4(*, recording: se.RecordingExtractor, detect_sign: int, clip_si
print('Cleaning tmpdir::::: '+tmpdir)
shutil.rmtree(tmpdir)
times, labels, channels = MS4.eventTimesLabelsChannels()
output = se.NumpySortingExtractor()
output.set_times_labels(times=times, labels=labels)
output = si.NumpySorting.from_times_labels(times_list=times, labels_list=labels,
sampling_frequency=recording.get_sampling_frequency())
return output


def _get_geom_from_recording(recording: se.RecordingExtractor):
channel_ids = cast(np.ndarray, recording.get_channel_ids())
M = len(channel_ids)
location0 = recording.get_channel_property(channel_ids[0], 'location')
nd = len(location0)
geom = np.zeros((M, nd))
for i in range(M):
location_i = recording.get_channel_property(channel_ids[i], 'location')
geom[i, :] = location_i
def _get_geom_from_recording(recording: si.BaseRecording):
if 'location' in recording.get_property_keys():
geom = recording.get_channel_locations()
else:
raise AttributeError("mountainsort4 needs locations to be added to the recording object")
return geom
8 changes: 4 additions & 4 deletions mountainsort4/ms4alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def numChannels(self):
return len(self._recording.get_channel_ids())

def numTimepoints(self):
return self._recording.get_num_frames()
return self._recording.get_num_samples()

def getChunk(self, *, t1, t2, channels):
channel_ids = self._recording.get_channel_ids()
Expand All @@ -642,14 +642,14 @@ def getChunk(self, *, t1, t2, channels):
t2=t2a, channels=channels)
return ret
else:
return self._recording.get_traces(start_frame=t1, end_frame=t2, channel_ids=channels2)
return self._recording.get_traces(start_frame=t1, end_frame=t2, channel_ids=channels2).T


def prepare_timeseries_hdf5_from_recording(recording, timeseries_hdf5_fname, *, chunk_size, padding):
chunk_size_with_padding = chunk_size+2*padding
with h5py.File(timeseries_hdf5_fname, "w") as f:
M = len(recording.get_channel_ids()) # Number of channels
N = recording.get_num_frames() # Number of timepoints
N = recording.get_num_samples() # Number of timepoints
num_chunks = math.ceil(N/chunk_size)
f.create_dataset('chunk_size', data=[chunk_size])
f.create_dataset('num_chunks', data=[num_chunks])
Expand All @@ -672,7 +672,7 @@ def prepare_timeseries_hdf5_from_recording(recording, timeseries_hdf5_fname, *,
aa = padding-(t1-s1)
# Read the padded chunk
padded_chunk[:, aa:aa+s2 -
s1] = recording.get_traces(start_frame=s1, end_frame=s2)
s1] = recording.get_traces(start_frame=s1, end_frame=s2).T

for m in range(M):
f.create_dataset('part-{}-{}'.format(m, j),
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
'numpy',
'h5py',
'sklearn',
'spikeextractors>=0.9.5'
'spikeinterface>=0.90'
]
)