diff --git a/examples/mountainsort4_example1.py b/examples/mountainsort4_example1.py index 66a41b4..3635abc 100644 --- a/examples/mountainsort4_example1.py +++ b/examples/mountainsort4_example1.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 import mountainsort4 as ms4 -import spikeextractors as se +import spikeinterface.extractors as se def main(): - recording, sorting_true = se.example_datasets.toy_example() + recording, sorting_true = se.toy_example(num_segments=1) sorting = ms4.mountainsort4( recording=recording, detect_sign=-1, diff --git a/jinjaroot.yaml b/jinjaroot.yaml index 4af38b0..62c984d 100644 --- a/jinjaroot.yaml +++ b/jinjaroot.yaml @@ -1,5 +1,5 @@ projectName: mountainsort4 -projectVersion: 1.0.0 +projectVersion: 1.1.0 projectAuthor: Jeremy Magland projectAuthorEmail: jmagland@flatironinstitute.org projectDescription: Spike sorting using MountainSort4 algorithm diff --git a/mountainsort4/__init__.py b/mountainsort4/__init__.py index 7dd55a9..2e5a520 100644 --- a/mountainsort4/__init__.py +++ b/mountainsort4/__init__.py @@ -1 +1,2 @@ -from .mountainsort4 import mountainsort4 \ No newline at end of file +from .mountainsort4 import mountainsort4 +from .version import __version__ diff --git a/mountainsort4/mountainsort4.py b/mountainsort4/mountainsort4.py index e9a2dc3..7af3cd7 100644 --- a/mountainsort4/mountainsort4.py +++ b/mountainsort4/mountainsort4.py @@ -6,11 +6,11 @@ import numpy as np import math import multiprocessing -import spikeextractors as se +import spikeinterface as si -def mountainsort4(*, recording: se.RecordingExtractor, detect_sign: int, clip_size: int=50, adjacency_radius: float=-1, detect_threshold: float=3, detect_interval: int=10, - num_workers: Union[None, int]=None, verbose: bool=True) -> se.SortingExtractor: +def mountainsort4(*, recording: si.BaseRecording, detect_sign: int, clip_size: int=50, adjacency_radius: float=-1, detect_threshold: float=3, detect_interval: int=10, + num_workers: Union[None, int]=None, verbose: bool=True) -> si.BaseSorting: if num_workers is None: num_workers = math.floor((multiprocessing.cpu_count()+1)/2) @@ -45,18 +45,14 @@ def mountainsort4(*, recording: se.RecordingExtractor, detect_sign: int, clip_si print('Cleaning tmpdir::::: '+tmpdir) shutil.rmtree(tmpdir) times, labels, channels = MS4.eventTimesLabelsChannels() - output = se.NumpySortingExtractor() - output.set_times_labels(times=times, labels=labels) + output = si.NumpySorting.from_times_labels(times_list=times, labels_list=labels, + sampling_frequency=recording.get_sampling_frequency()) return output -def _get_geom_from_recording(recording: se.RecordingExtractor): - channel_ids = cast(np.ndarray, recording.get_channel_ids()) - M = len(channel_ids) - location0 = recording.get_channel_property(channel_ids[0], 'location') - nd = len(location0) - geom = np.zeros((M, nd)) - for i in range(M): - location_i = recording.get_channel_property(channel_ids[i], 'location') - geom[i, :] = location_i +def _get_geom_from_recording(recording: si.BaseRecording): + if 'location' in recording.get_property_keys(): + geom = recording.get_channel_locations() + else: + raise AttributeError("mountainsort4 needs locations to be added to the recording object") return geom diff --git a/mountainsort4/ms4alg.py b/mountainsort4/ms4alg.py index 28cd00f..1a1925e 100644 --- a/mountainsort4/ms4alg.py +++ b/mountainsort4/ms4alg.py @@ -627,7 +627,7 @@ def numChannels(self): return len(self._recording.get_channel_ids()) def numTimepoints(self): - return self._recording.get_num_frames() + return self._recording.get_num_samples() def getChunk(self, *, t1, t2, channels): channel_ids = self._recording.get_channel_ids() @@ -642,14 +642,14 @@ def getChunk(self, *, t1, t2, channels): t2=t2a, channels=channels) return ret else: - return self._recording.get_traces(start_frame=t1, end_frame=t2, channel_ids=channels2) + return self._recording.get_traces(start_frame=t1, end_frame=t2, channel_ids=channels2).T def prepare_timeseries_hdf5_from_recording(recording, timeseries_hdf5_fname, *, chunk_size, padding): chunk_size_with_padding = chunk_size+2*padding with h5py.File(timeseries_hdf5_fname, "w") as f: M = len(recording.get_channel_ids()) # Number of channels - N = recording.get_num_frames() # Number of timepoints + N = recording.get_num_samples() # Number of timepoints num_chunks = math.ceil(N/chunk_size) f.create_dataset('chunk_size', data=[chunk_size]) f.create_dataset('num_chunks', data=[num_chunks]) @@ -672,7 +672,7 @@ def prepare_timeseries_hdf5_from_recording(recording, timeseries_hdf5_fname, *, aa = padding-(t1-s1) # Read the padded chunk padded_chunk[:, aa:aa+s2 - - s1] = recording.get_traces(start_frame=s1, end_frame=s2) + s1] = recording.get_traces(start_frame=s1, end_frame=s2).T for m in range(M): f.create_dataset('part-{}-{}'.format(m, j), diff --git a/setup.py b/setup.py index f17b76e..7aa282a 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,6 @@ 'numpy', 'h5py', 'sklearn', - 'spikeextractors>=0.9.5' + 'spikeinterface>=0.90' ] ) \ No newline at end of file