From 12fd197859a3bb91099e9f5fb73fc5f74f923847 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 19 Sep 2023 12:56:55 +0200 Subject: [PATCH 01/18] Use sparsity mask and handle right border correctly --- .../postprocessing/amplitude_scalings.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..4dab68fdf8 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -90,10 +90,7 @@ def _run(self, **job_kwargs): if self._params["max_dense_channels"] is not None: assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) - sparsity_inds = sparsity.unit_id_to_channel_indices - - # easier to use in chunk function as spikes use unit_index instead o id - unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} + sparsity_mask = sparsity.mask all_templates = we.get_all_templates() # precompute segment slice @@ -113,7 +110,7 @@ def _run(self, **job_kwargs): self.spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -262,7 +259,7 @@ def _init_worker_amplitude_scalings( spikes, all_templates, segment_slices, - unit_inds_to_channel_indices, + sparsity_mask, nbefore, nafter, cut_out_before, @@ -282,7 +279,7 @@ def _init_worker_amplitude_scalings( worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after worker_ctx["return_scaled"] = return_scaled - worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["sparsity_mask"] = sparsity_mask worker_ctx["handle_collisions"] = handle_collisions worker_ctx["delta_collision_samples"] = delta_collision_samples @@ -306,7 +303,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) recording = worker_ctx["recording"] all_templates = worker_ctx["all_templates"] segment_slices = worker_ctx["segment_slices"] - unit_inds_to_channel_indices = worker_ctx["unit_inds_to_channel_indices"] + sparsity_mask = worker_ctx["sparsity_mask"] nbefore = worker_ctx["nbefore"] cut_out_before = worker_ctx["cut_out_before"] cut_out_after = worker_ctx["cut_out_after"] @@ -339,7 +336,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( - local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices + local_spikes, local_spikes_w_margin, delta_collision_samples, sparsity_mask ) else: collisions_local = {} @@ -354,7 +351,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] - sparse_indices = unit_inds_to_channel_indices[unit_index] + sparse_indices = sparsity_mask[unit_index] template = all_templates[unit_index][:, sparse_indices] template = template[nbefore - cut_out_before : nbefore + cut_out_after] sample_centered = sample_index - start_frame @@ -393,7 +390,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, cut_out_before, cut_out_after, ) @@ -410,14 +407,14 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) ### Collision handling ### -def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): +def _are_unit_indices_overlapping(sparsity_mask, i, j): """ Returns True if the unit indices i and j are overlapping, False otherwise Parameters ---------- - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask i: int The first unit index j: int @@ -428,13 +425,13 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): bool True if the unit indices i and j are overlapping, False otherwise """ - if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + if np.sum(np.logical_and(sparsity_mask[i], sparsity_mask[j])) > 0: return True else: return False -def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, sparsity_mask): """ Finds the collisions between spikes. @@ -446,8 +443,8 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ An array of spikes within the added margin delta_collision_samples: int The maximum number of samples between two spikes to consider them as overlapping - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices + sparsity_mask: boolean mask + The sparsity mask Returns ------- @@ -480,7 +477,7 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ # find the overlapping spikes in space as well for possible_overlapping_spike_index in possible_overlapping_spike_indices: if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, + sparsity_mask, spike["unit_index"], spikes_w_margin[possible_overlapping_spike_index]["unit_index"], ): @@ -501,7 +498,7 @@ def fit_collision( right, nbefore, all_templates, - unit_inds_to_channel_indices, + sparsity_mask, cut_out_before, cut_out_after, ): @@ -528,8 +525,8 @@ def fit_collision( The number of samples before the spike to consider for the fit. all_templates: np.ndarray A numpy array of shape (n_units, n_samples, n_channels) containing the templates. - unit_inds_to_channel_indices: dict - A dictionary mapping unit indices to channel indices. + sparsity_mask: boolean mask + The sparsity mask cut_out_before: int The number of samples to cut out before the spike. cut_out_after: int @@ -547,14 +544,15 @@ def fit_collision( sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity - sparse_indices = np.array([], dtype="int") + sparse_indices = np.zeros(sparsity_mask.shape[1], dtype="int") for spike in collision: - sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] - sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + sparse_indices_i = sparsity_mask[spike["unit_index"]] + sparse_indices = np.logical_or(sparse_indices, sparse_indices_i) local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + num_samples_local_waveform = local_waveform.shape[0] y = local_waveform.T.flatten() X = np.zeros((len(y), len(collision))) @@ -567,8 +565,10 @@ def fit_collision( # deal with borders if sample_centered - cut_out_before < 0: full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] - elif sample_centered + cut_out_after > end_frame + right: - full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + elif sample_centered + cut_out_after > num_samples_local_waveform: + full_template[sample_centered - cut_out_before :] = template_cut[ + : -(cut_out_after + sample_centered - num_samples_local_waveform) + ] else: full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() From c8579b573236a6e454e27c329e9a03482be606f7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Sep 2023 21:25:42 +0200 Subject: [PATCH 02/18] minor chnages on drift benchmark for figures --- .../benchmark/benchmark_motion_estimation.py | 33 +++++++++++-------- .../benchmark_motion_interpolation.py | 14 +++++--- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index dd35670abd..a47b97fb6d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -487,7 +487,7 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - axes[0].plot(benchmark.temporal_bins, mean_error, label=benchmark.title, color=c) + axes[0].plot(benchmark.temporal_bins, mean_error, lw=1, label=benchmark.title, color=c) parts = axes[1].violinplot(mean_error, [count], showmeans=True) if c is not None: for pc in parts["bodies"]: @@ -584,23 +584,30 @@ def plot_motions_several_benchmarks(benchmarks): _simpleaxis(ax) -def plot_speed_several_benchmarks(benchmarks, ax=None, colors=None): +def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): if ax is None: fig, ax = plt.subplots(figsize=(5, 5)) for count, benchmark in enumerate(benchmarks): color = colors[count] if colors is not None else None - bottom = 0 - i = 0 - patterns = ["/", "\\", "|", "*"] - for key, value in benchmark.run_times.items(): - if count == 0: - label = key.replace("_", " ") - else: - label = None - ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i]) - bottom += value - i += 1 + + if detailed: + bottom = 0 + i = 0 + patterns = ["/", "\\", "|", "*"] + for key, value in benchmark.run_times.items(): + if count == 0: + label = key.replace("_", " ") + else: + label = None + ax.bar([count], [value], label=label, bottom=bottom, color=color, edgecolor="black", hatch=patterns[i]) + bottom += value + i += 1 + else: + total_run_time = np.sum([value for key, value in benchmark.run_times.items()]) + ax.bar([count], [total_run_time], color=color, edgecolor="black") + + # ax.legend() ax.set_ylabel("speed (s)") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index 13a64e8168..8e5afb2e8e 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -9,7 +9,7 @@ from spikeinterface.extractors import read_mearec from spikeinterface.preprocessing import bandpass_filter, zscore, common_reference, scale, highpass_filter, whiten -from spikeinterface.sorters import run_sorter +from spikeinterface.sorters import run_sorter, read_sorter_folder from spikeinterface.widgets import plot_unit_waveforms, plot_gt_performances from spikeinterface.comparison import GroundTruthComparison @@ -184,7 +184,7 @@ def extract_waveforms(self): we.run_extract_waveforms(seed=22051977, **self.job_kwargs) self.waveforms[key] = we - def run_sorters(self): + def run_sorters(self, skip_already_done=True): for case in self.sorter_cases: label = case["label"] print("run sorter", label) @@ -192,9 +192,13 @@ def run_sorters(self): sorter_params = case["sorter_params"] recording = self.recordings[case["recording"]] output_folder = self.folder / f"tmp_sortings_{label}" - sorting = run_sorter( - sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder - ) + if output_folder.exists() and skip_already_done: + print('already done') + sorting = read_sorter_folder(output_folder) + else: + sorting = run_sorter( + sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder + ) self.sortings[label] = sorting def compute_distances_to_static(self, force=False): From e964731b33401db1757ce813d2078c00a36dcf34 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 16:36:31 +0200 Subject: [PATCH 03/18] Start refactor ipywidgets plot_traces --- src/spikeinterface/widgets/traces.py | 29 +- .../widgets/utils_ipywidgets.py | 251 ++++++++++++++++-- 2 files changed, 254 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 7bb2126744..c6e36387f8 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -276,11 +276,16 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display + import ipywidgets.widgets as W from .utils_ipywidgets import ( check_ipywidget_backend, make_timeseries_controller, make_channel_controller, make_scale_controller, + + TimeSlider, + ScaleWidget, + ) check_ipywidget_backend() @@ -308,6 +313,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): t_start = 0.0 t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() + + ts_widget, ts_controller = make_timeseries_controller( t_start, t_stop, @@ -319,6 +326,22 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm, ) + # some widgets + self.time_slider = TimeSlider( + durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], + sampling_frequency=rec0.sampling_frequency, + ) + self.layer_selector = W.Dropdown(description="layer", options=data_plot["layer_keys"], + layout=W.Layout(width="5cm"),) + self.mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=data_plot["mode"], + layout=W.Layout(width="5cm"),) + self.scaler = ScaleWidget() + left_sidebar = W.VBox( + children=[self.layer_selector, self.mode_selector, self.scaler], + layout=W.Layout(width="5cm"), + ) + + ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) @@ -346,8 +369,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.widget = widgets.AppLayout( center=self.figure.canvas, - footer=ts_widget, - left_sidebar=scale_widget, + # footer=ts_widget, + footer=self.time_slider, + # left_sidebar=scale_widget, + left_sidebar = left_sidebar, right_sidebar=ch_widget, pane_heights=[0, 6, 1], pane_widths=ratios, diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index a7c571d1f0..674a2d2cc7 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -1,4 +1,6 @@ -import ipywidgets.widgets as widgets +import ipywidgets.widgets as W +import traitlets + import numpy as np @@ -10,20 +12,20 @@ def check_ipywidget_backend(): def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): - time_slider = widgets.FloatSlider( + time_slider = W.FloatSlider( orientation="horizontal", description="time:", value=time_range[0], min=t_start, max=t_stop, continuous_update=False, - layout=widgets.Layout(width=f"{width_cm}cm"), + layout=W.Layout(width=f"{width_cm}cm"), ) - layer_selector = widgets.Dropdown(description="layer", options=layer_keys) - segment_selector = widgets.Dropdown(description="segment", options=list(range(num_segments))) - window_sizer = widgets.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") - mode_selector = widgets.Dropdown(options=["line", "map"], description="mode", value=mode) - all_layers = widgets.Checkbox(description="plot all layers", value=all_layers) + layer_selector = W.Dropdown(description="layer", options=layer_keys) + segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) + window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") + mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) + all_layers = W.Checkbox(description="plot all layers", value=all_layers) controller = { "layer_key": layer_selector, @@ -33,32 +35,32 @@ def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_r "mode": mode_selector, "all_layers": all_layers, } - widget = widgets.VBox( - [time_slider, widgets.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] + widget = W.VBox( + [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] ) return widget, controller def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): - unit_label = widgets.Label(value="units:") + unit_label = W.Label(value="units:") - unit_selector = widgets.SelectMultiple( + unit_selector = W.SelectMultiple( options=all_unit_ids, value=list(unit_ids), disabled=False, - layout=widgets.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), ) controller = {"unit_ids": unit_selector} - widget = widgets.VBox([unit_label, unit_selector]) + widget = W.VBox([unit_label, unit_selector]) return widget, controller def make_channel_controller(recording, width_cm, height_cm): - channel_label = widgets.Label("channel indices:", layout=widgets.Layout(justify_content="center")) - channel_selector = widgets.IntRangeSlider( + channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) + channel_selector = W.IntRangeSlider( value=[0, recording.get_num_channels()], min=0, max=recording.get_num_channels(), @@ -68,37 +70,238 @@ def make_channel_controller(recording, width_cm, height_cm): orientation="vertical", readout=True, readout_format="d", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), ) controller = {"channel_inds": channel_selector} - widget = widgets.VBox([channel_label, channel_selector]) + widget = W.VBox([channel_label, channel_selector]) return widget, controller def make_scale_controller(width_cm, height_cm): - scale_label = widgets.Label("Scale", layout=widgets.Layout(justify_content="center")) + scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) - plus_selector = widgets.Button( + plus_selector = W.Button( description="", disabled=False, button_style="", # 'success', 'info', 'warning', 'danger' or '' tooltip="Increase scale", icon="arrow-up", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), ) - minus_selector = widgets.Button( + minus_selector = W.Button( description="", disabled=False, button_style="", # 'success', 'info', 'warning', 'danger' or '' tooltip="Decrease scale", icon="arrow-down", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), ) controller = {"plus": plus_selector, "minus": minus_selector} - widget = widgets.VBox([scale_label, plus_selector, minus_selector]) + widget = W.VBox([scale_label, plus_selector, minus_selector]) return widget, controller + + + +class TimeSlider(W.HBox): + + position = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) + + def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): + + + self.num_segments = len(durations) + self.frame_limits = [int(sampling_frequency * d) for d in durations] + self.sampling_frequency = sampling_frequency + start_frame = int(time_range[0] * sampling_frequency) + end_frame = int(time_range[1] * sampling_frequency) + + self.frame_range = (start_frame, end_frame) + + self.segment_index = 0 + self.position = (start_frame, end_frame, self.segment_index) + + + layout = W.Layout(align_items="center", width="1.5cm", height="100%") + but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) + but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) + + but_left.on_click(self.move_left) + but_right.on_click(self.move_right) + + self.move_size = W.Dropdown(options=['10 ms', '100 ms', '1 s', '10 s', '1 m', '30 m', '1 h',], # '6 h', '24 h' + value='1 s', + description='', + layout = W.Layout(width="2cm") + ) + + # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) + self.time_label = W.Text(value=f'{time_range[0]}',description='', + disabled=False, layout=W.Layout(width='5.5cm')) + self.time_label.observe(self.time_label_changed, names='value', type="change") + + + self.slider = W.IntSlider( + orientation='horizontal', + # description='time:', + value=start_frame, + min=0, + max=self.frame_limits[self.segment_index], + readout=False, + continuous_update=False, + layout=W.Layout(width=f'70%') + ) + + self.slider.observe(self.slider_moved, names='value', type="change") + + delta_s = np.diff(self.frame_range) / sampling_frequency + + self.window_sizer = W.BoundedFloatText(value=delta_s, step=1, + min=0.01, max=30., + description='win (s)', + layout=W.Layout(width='auto') + # layout=W.Layout(width=f'10%') + ) + self.window_sizer.observe(self.win_size_changed, names='value', type="change") + + self.segment_selector = W.Dropdown(description="segment", options=list(range(self.num_segments))) + self.segment_selector.observe(self.segment_changed, names='value', type="change") + + super(W.HBox, self).__init__(children=[self.segment_selector, but_left, self.move_size, but_right, + self.slider, self.time_label, self.window_sizer], + layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs) + + self.observe(self.position_changed, names=['position'], type="change") + + def position_changed(self, change=None): + + self.unobserve(self.position_changed, names=['position'], type="change") + + start, stop, seg_index = self.position + if seg_index < 0 or seg_index >= self.num_segments: + self.position = change['old'] + return + if start < 0 or stop < 0: + self.position = change['old'] + return + if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: + self.position = change['old'] + return + + self.segment_selector.value = seg_index + self.update_time(new_frame=start, update_slider=True, update_label=True) + delta_s = (stop - start) / self.sampling_frequency + self.window_sizer.value = delta_s + + self.observe(self.position_changed, names=['position'], type="change") + + def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): + if new_frame is None and new_time is None: + start_frame = self.slider.value + elif new_frame is None: + start_frame = int(new_time * self.sampling_frequency) + else: + start_frame = new_frame + delta_s = self.window_sizer.value + end_frame = start_frame + int(delta_s * self.sampling_frequency) + + # clip + start_frame = max(0, start_frame) + end_frame = min(self.frame_limits[self.segment_index], end_frame) + + + start_time = start_frame / self.sampling_frequency + + if update_label: + self.time_label.unobserve(self.time_label_changed, names='value', type="change") + self.time_label.value = f'{start_time}' + self.time_label.observe(self.time_label_changed, names='value', type="change") + + if update_slider: + self.slider.unobserve(self.slider_moved, names='value', type="change") + self.slider.value = start_frame + self.slider.observe(self.slider_moved, names='value', type="change") + + self.frame_range = (start_frame, end_frame) + + def time_label_changed(self, change=None): + try: + new_time = float(self.time_label.value) + except: + new_time = None + if new_time is not None: + self.update_time(new_time=new_time, update_slider=True) + + + def win_size_changed(self, change=None): + self.update_time() + + def slider_moved(self, change=None): + new_frame = self.slider.value + self.update_time(new_frame=new_frame, update_label=True) + + def move(self, sign): + value, units = self.move_size.value.split(' ') + value = int(value) + delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, 's') + delta_sample = int(delta_s * self.sampling_frequency) + + new_frame = self.frame_range[0] + delta_sample + self.slider.value = new_frame + + def move_left(self, change=None): + self.move(-1) + + def move_right(self, change=None): + self.move(+1) + + def segment_changed(self, change=None): + self.segment_index = self.segment_selector.value + + self.slider.unobserve(self.slider_moved, names='value', type="change") + # self.slider.value = 0 + self.slider.max = self.frame_limits[self.segment_index] + self.slider.observe(self.slider_moved, names='value', type="change") + + self.update_time(new_frame=0, update_slider=True, update_label=True) + + + +class ScaleWidget(W.VBox): + def __init__(self, **kwargs): + scale_label = W.Label("Scale", + layout=W.Layout(layout=W.Layout(width='95%'), + justify_content="center")) + + self.plus_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Increase scale", + icon="arrow-up", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width='95%'), + ) + + self.minus_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Decrease scale", + icon="arrow-down", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width='95%'), + ) + + # controller = {"plus": plus_selector, "minus": minus_selector} + # widget = W.VBox([scale_label, plus_selector, minus_selector]) + + + super(W.VBox, self).__init__(children=[scale_label, self.plus_selector, self.minus_selector], + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs) From 389737efe1330f1f75afb73caedb41bb6bf84b4d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 20:58:38 +0200 Subject: [PATCH 04/18] wip refactor plot traces ipywidget --- src/spikeinterface/widgets/traces.py | 126 ++++++++++++++---- .../widgets/utils_ipywidgets.py | 62 ++++++--- 2 files changed, 145 insertions(+), 43 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index c6e36387f8..efd32ffb24 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -279,9 +279,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import ipywidgets.widgets as W from .utils_ipywidgets import ( check_ipywidget_backend, - make_timeseries_controller, + # make_timeseries_controller, make_channel_controller, - make_scale_controller, + # make_scale_controller, TimeSlider, ScaleWidget, @@ -315,21 +315,22 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): - ts_widget, ts_controller = make_timeseries_controller( - t_start, - t_stop, - data_plot["layer_keys"], - rec0.get_num_segments(), - data_plot["time_range"], - data_plot["mode"], - False, - width_cm, - ) + # ts_widget, ts_controller = make_timeseries_controller( + # t_start, + # t_stop, + # data_plot["layer_keys"], + # rec0.get_num_segments(), + # data_plot["time_range"], + # data_plot["mode"], + # False, + # width_cm, + # ) # some widgets self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, + # layout=W.Layout(height="2cm"), ) self.layer_selector = W.Dropdown(description="layer", options=data_plot["layer_keys"], layout=W.Layout(width="5cm"),) @@ -338,22 +339,22 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.scaler = ScaleWidget() left_sidebar = W.VBox( children=[self.layer_selector, self.mode_selector, self.scaler], - layout=W.Layout(width="5cm"), + layout=W.Layout(width="3.5cm"), ) ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) - scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) + # scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) - self.controller = ts_controller - self.controller.update(ch_controller) - self.controller.update(scale_controller) + # self.controller = ts_controller + # self.controller.update(ch_controller) + # self.controller.update(scale_controller) self.recordings = data_plot["recordings"] self.return_scaled = data_plot["return_scaled"] self.list_traces = None - self.actual_segment_index = self.controller["segment_index"].value + # self.actual_segment_index = self.controller["segment_index"].value self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] self.t_stops = [ @@ -361,11 +362,11 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): for seg_index in range(self.rec0.get_num_segments()) ] - for w in self.controller.values(): - if isinstance(w, widgets.Button): - w.on_click(self._update_ipywidget) - else: - w.observe(self._update_ipywidget) + # for w in self.controller.values(): + # if isinstance(w, widgets.Button): + # w.on_click(self._update_ipywidget) + # else: + # w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, @@ -379,12 +380,89 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - self._update_ipywidget(None) + # self._update_ipywidget(None) + + self._retrieve_traces() + self._update_plot() + + # only layer selector and time change generate a new traces retrieve + self.time_slider.observe(self._retrieve_traces, names='value', type="change") + self.layer_selector.observe(self._retrieve_traces, names='value', type="change") + # other widgets only refresh + self.scaler.observe(self._update_plot, names='value', type="change") + self.mode_selector.observe(self._update_plot, names='value', type="change") + if backend_kwargs["display"]: # self.check_backend() display(self.widget) + + + def _retrieve_traces(self, change=None): + # done when: + # * time or window is changes + # * layer is changed + + # TODO connect with channel selector + channel_ids = self.rec0.channel_ids + + # all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids + # if self.data_plot["order"] is not None: + # all_channel_ids = all_channel_ids[self.data_plot["order"]] + # channel_ids = all_channel_ids[channel_indices] + if self.data_plot["order_channel_by_depth"]: + order, _ = order_channels_by_depth(self.rec0, channel_ids) + else: + order = None + + start_frame, end_frame, segment_index = self.time_slider.value + time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency + + times, list_traces, frame_range, channel_ids = _get_trace_list( + self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled + ) + self.list_traces = list_traces + + self._update_plot() + + def _update_plot(self, change=None): + # done when: + # * time or window is changed (after _retrive_traces) + # * layer is changed (after _retrive_traces) + #  * scale is change + # * mode is change + + data_plot = self.next_data_plot + + # matplotlib next_data_plot dict update at each call + data_plot["mode"] = self.mode_selector.value + # data_plot["frame_range"] = frame_range + # data_plot["time_range"] = time_range + data_plot["with_colorbar"] = False + # data_plot["recordings"] = recordings + # data_plot["layer_keys"] = layer_keys + # data_plot["list_traces"] = list_traces_plot + # data_plot["times"] = times + # data_plot["clims"] = clims + # data_plot["channel_ids"] = channel_ids + + list_traces = [traces * self.scaler.value for traces in self.list_traces] + data_plot["list_traces"] = list_traces + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + self.ax.clear() + self.plot_matplotlib(data_plot, **backend_kwargs) + + fig = self.ax.figure + fig.canvas.draw() + fig.canvas.flush_events() + + + + def _update_ipywidget(self, change): import ipywidgets.widgets as widgets diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 674a2d2cc7..ad0ead7bc0 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -109,7 +109,7 @@ def make_scale_controller(width_cm, height_cm): class TimeSlider(W.HBox): - position = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) + value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): @@ -123,10 +123,10 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): self.frame_range = (start_frame, end_frame) self.segment_index = 0 - self.position = (start_frame, end_frame, self.segment_index) + self.value = (start_frame, end_frame, self.segment_index) - layout = W.Layout(align_items="center", width="1.5cm", height="100%") + layout = W.Layout(align_items="center", width="2cm", hight="1.5cm") but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) @@ -176,21 +176,21 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): layout=W.Layout(align_items="center", width="100%", height="100%"), **kwargs) - self.observe(self.position_changed, names=['position'], type="change") + self.observe(self.value_changed, names=['value'], type="change") - def position_changed(self, change=None): + def value_changed(self, change=None): - self.unobserve(self.position_changed, names=['position'], type="change") + self.unobserve(self.value_changed, names=['value'], type="change") - start, stop, seg_index = self.position + start, stop, seg_index = self.value if seg_index < 0 or seg_index >= self.num_segments: - self.position = change['old'] + self.value = change['old'] return if start < 0 or stop < 0: - self.position = change['old'] + self.value = change['old'] return if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: - self.position = change['old'] + self.value = change['old'] return self.segment_selector.value = seg_index @@ -198,7 +198,7 @@ def position_changed(self, change=None): delta_s = (stop - start) / self.sampling_frequency self.window_sizer.value = delta_s - self.observe(self.position_changed, names=['position'], type="change") + self.observe(self.value_changed, names=['value'], type="change") def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): if new_frame is None and new_time is None: @@ -228,6 +228,7 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update self.slider.observe(self.slider_moved, names='value', type="change") self.frame_range = (start_frame, end_frame) + self.value = (start_frame, end_frame, self.segment_index) def time_label_changed(self, change=None): try: @@ -273,8 +274,14 @@ def segment_changed(self, change=None): class ScaleWidget(W.VBox): - def __init__(self, **kwargs): - scale_label = W.Label("Scale", + value = traitlets.Float() + + def __init__(self, value=1., factor=1.2, **kwargs): + + assert factor > 1. + self.factor = factor + + self.scale_label = W.Label("Scale", layout=W.Layout(layout=W.Layout(width='95%'), justify_content="center")) @@ -285,7 +292,7 @@ def __init__(self, **kwargs): tooltip="Increase scale", icon="arrow-up", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='95%'), + layout=W.Layout(width='60%', align_self='center'), ) self.minus_selector = W.Button( @@ -295,13 +302,30 @@ def __init__(self, **kwargs): tooltip="Decrease scale", icon="arrow-down", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='95%'), + layout=W.Layout(width='60%', align_self='center'), ) - # controller = {"plus": plus_selector, "minus": minus_selector} - # widget = W.VBox([scale_label, plus_selector, minus_selector]) + self.plus_selector.on_click(self.plus_clicked) + self.minus_selector.on_click(self.minus_clicked) - - super(W.VBox, self).__init__(children=[scale_label, self.plus_selector, self.minus_selector], + self.value = 1. + super(W.VBox, self).__init__(children=[self.plus_selector, self.scale_label, self.minus_selector], # layout=W.Layout(align_items="center", width="100%", height="100%"), **kwargs) + + self.update_label() + self.observe(self.value_changed, names=['value'], type="change") + + def update_label(self): + self.scale_label.value = f"Scale: {self.value:0.2f}" + + + def plus_clicked(self, change=None): + self.value = self.value * self.factor + + def minus_clicked(self, change=None): + self.value = self.value / self.factor + + + def value_changed(self, change=None): + self.update_label() From e5995f2aa6445fd878e1c0881f11299f8ae22a2d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 22:59:53 +0200 Subject: [PATCH 05/18] ipywidget backend refactor wip --- src/spikeinterface/widgets/traces.py | 298 +++++------------- .../widgets/utils_ipywidgets.py | 175 ++++++---- 2 files changed, 190 insertions(+), 283 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index efd32ffb24..d107c5cb23 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -280,23 +280,23 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): from .utils_ipywidgets import ( check_ipywidget_backend, # make_timeseries_controller, - make_channel_controller, + # make_channel_controller, # make_scale_controller, - TimeSlider, + ChannelSelector, ScaleWidget, - ) check_ipywidget_backend() self.next_data_plot = data_plot.copy() - self.next_data_plot["add_legend"] = False + - recordings = data_plot["recordings"] + self.recordings = data_plot["recordings"] # first layer - rec0 = recordings[data_plot["layer_keys"][0]] + # rec0 = recordings[data_plot["layer_keys"][0]] + rec0 = self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] cm = 1 / 2.54 @@ -310,107 +310,92 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.figure, self.ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm)) plt.show() - t_start = 0.0 - t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() - - - - # ts_widget, ts_controller = make_timeseries_controller( - # t_start, - # t_stop, - # data_plot["layer_keys"], - # rec0.get_num_segments(), - # data_plot["time_range"], - # data_plot["mode"], - # False, - # width_cm, - # ) - # some widgets self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, # layout=W.Layout(height="2cm"), ) - self.layer_selector = W.Dropdown(description="layer", options=data_plot["layer_keys"], - layout=W.Layout(width="5cm"),) - self.mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=data_plot["mode"], - layout=W.Layout(width="5cm"),) + + start_frame = int(data_plot["time_range"][0] * rec0.sampling_frequency) + end_frame = int(data_plot["time_range"][1] * rec0.sampling_frequency) + + self.time_slider.value = start_frame, end_frame, data_plot["segment_index"] + + _layer_keys = data_plot["layer_keys"] + if len(_layer_keys) > 1: + _layer_keys = ['ALL'] + _layer_keys + self.layer_selector = W.Dropdown(options=_layer_keys, + layout=W.Layout(width="95%"), + ) + self.mode_selector = W.Dropdown(options=["line", "map"], value=data_plot["mode"], + # layout=W.Layout(width="5cm"), + layout=W.Layout(width="95%"), + ) self.scaler = ScaleWidget() + self.channel_selector = ChannelSelector(self.rec0.channel_ids) + left_sidebar = W.VBox( - children=[self.layer_selector, self.mode_selector, self.scaler], + children=[ + W.Label(value="layer"), + self.layer_selector, + W.Label(value="mode"), + self.mode_selector, + self.scaler, + # self.channel_selector, + ], layout=W.Layout(width="3.5cm"), + align_items='center', ) - - ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) - - # scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) - - # self.controller = ts_controller - # self.controller.update(ch_controller) - # self.controller.update(scale_controller) - - self.recordings = data_plot["recordings"] self.return_scaled = data_plot["return_scaled"] - self.list_traces = None - # self.actual_segment_index = self.controller["segment_index"].value - - self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] - self.t_stops = [ - self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency() - for seg_index in range(self.rec0.get_num_segments()) - ] - - # for w in self.controller.values(): - # if isinstance(w, widgets.Button): - # w.on_click(self._update_ipywidget) - # else: - # w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, - # footer=ts_widget, footer=self.time_slider, - # left_sidebar=scale_widget, left_sidebar = left_sidebar, - right_sidebar=ch_widget, + right_sidebar=self.channel_selector, pane_heights=[0, 6, 1], pane_widths=ratios, ) # a first update - # self._update_ipywidget(None) - self._retrieve_traces() self._update_plot() - # only layer selector and time change generate a new traces retrieve + # callbacks: + # some widgets generate a full retrieve + refresh self.time_slider.observe(self._retrieve_traces, names='value', type="change") self.layer_selector.observe(self._retrieve_traces, names='value', type="change") + self.channel_selector.observe(self._retrieve_traces, names='value', type="change") # other widgets only refresh self.scaler.observe(self._update_plot, names='value', type="change") - self.mode_selector.observe(self._update_plot, names='value', type="change") + # map is a special case because needs to check layer also + self.mode_selector.observe(self._mode_changed, names='value', type="change") - if backend_kwargs["display"]: # self.check_backend() display(self.widget) - + def _get_layers(self): + layer = self.layer_selector.value + if layer == 'ALL': + layer_keys = self.data_plot["layer_keys"] + else: + layer_keys = [layer] + if self.mode_selector.value == "map": + layer_keys = layer_keys[:1] + return layer_keys + + def _mode_changed(self, change=None): + if self.mode_selector.value == "map" and self.layer_selector.value == "ALL": + self.layer_selector.value = self.data_plot["layer_keys"][0] + else: + self._update_plot() def _retrieve_traces(self, change=None): - # done when: - # * time or window is changes - # * layer is changed + channel_ids = np.array(self.channel_selector.value) - # TODO connect with channel selector - channel_ids = self.rec0.channel_ids - - # all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - # if self.data_plot["order"] is not None: - # all_channel_ids = all_channel_ids[self.data_plot["order"]] - # channel_ids = all_channel_ids[channel_indices] if self.data_plot["order_channel_by_depth"]: order, _ = order_channels_by_depth(self.rec0, channel_ids) else: @@ -419,176 +404,61 @@ def _retrieve_traces(self, change=None): start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency + self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled + self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled ) - self.list_traces = list_traces + + self._channel_ids = channel_ids + self._list_traces = list_traces + self._times = times + self._time_range = time_range + self._frame_range = (start_frame, end_frame) + self._segment_index = segment_index self._update_plot() def _update_plot(self, change=None): - # done when: - # * time or window is changed (after _retrive_traces) - # * layer is changed (after _retrive_traces) - #  * scale is change - # * mode is change - data_plot = self.next_data_plot # matplotlib next_data_plot dict update at each call - data_plot["mode"] = self.mode_selector.value - # data_plot["frame_range"] = frame_range - # data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - # data_plot["recordings"] = recordings - # data_plot["layer_keys"] = layer_keys - # data_plot["list_traces"] = list_traces_plot - # data_plot["times"] = times - # data_plot["clims"] = clims - # data_plot["channel_ids"] = channel_ids - - list_traces = [traces * self.scaler.value for traces in self.list_traces] - data_plot["list_traces"] = list_traces - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.ax.clear() - self.plot_matplotlib(data_plot, **backend_kwargs) - - fig = self.ax.figure - fig.canvas.draw() - fig.canvas.flush_events() - - - - - def _update_ipywidget(self, change): - import ipywidgets.widgets as widgets - - # if changing the layer_key, no need to retrieve and process traces - retrieve_traces = True - scale_up = False - scale_down = False - if change is not None: - for cname, c in self.controller.items(): - if isinstance(change, dict): - if change["owner"] is c and cname == "layer_key": - retrieve_traces = False - elif isinstance(change, widgets.Button): - if change is c and cname == "plus": - scale_up = True - if change is c and cname == "minus": - scale_down = True - - t_start = self.controller["t_start"].value - window = self.controller["window"].value - layer_key = self.controller["layer_key"].value - segment_index = self.controller["segment_index"].value - mode = self.controller["mode"].value - chan_start, chan_stop = self.controller["channel_inds"].value - - if mode == "line": - self.controller["all_layers"].layout.visibility = "visible" - all_layers = self.controller["all_layers"].value - elif mode == "map": - self.controller["all_layers"].layout.visibility = "hidden" - all_layers = False - - if all_layers: - self.controller["layer_key"].layout.visibility = "hidden" - else: - self.controller["layer_key"].layout.visibility = "visible" - - if chan_start == chan_stop: - chan_stop += 1 - channel_indices = np.arange(chan_start, chan_stop) - - t_stop = self.t_stops[segment_index] - if self.actual_segment_index != segment_index: - # change time_slider limits - self.controller["t_start"].max = t_stop - self.actual_segment_index = segment_index - - # protect limits - if t_start >= t_stop - window: - t_start = t_stop - window - - time_range = np.array([t_start, t_start + window]) - data_plot = self.next_data_plot + mode = self.mode_selector.value + layer_keys = self._get_layers() - if retrieve_traces: - all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - if self.data_plot["order"] is not None: - all_channel_ids = all_channel_ids[self.data_plot["order"]] - channel_ids = all_channel_ids[channel_indices] - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None - times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled - ) - self.list_traces = list_traces - else: - times = data_plot["times"] - list_traces = data_plot["list_traces"] - frame_range = data_plot["frame_range"] - channel_ids = data_plot["channel_ids"] - - if all_layers: - layer_keys = self.data_plot["layer_keys"] - recordings = self.recordings - list_traces_plot = self.list_traces - else: - layer_keys = [layer_key] - recordings = {layer_key: self.recordings[layer_key]} - list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]] - - if scale_up: - if mode == "line": - data_plot["vspacing"] *= 0.8 - elif mode == "map": - data_plot["clims"] = { - layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items() - } - if scale_down: - if mode == "line": - data_plot["vspacing"] *= 1.2 - elif mode == "map": - data_plot["clims"] = { - layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items() - } - - self.next_data_plot["vspacing"] = data_plot["vspacing"] - self.next_data_plot["clims"] = data_plot["clims"] + data_plot["mode"] = mode + data_plot["frame_range"] = self._frame_range + data_plot["time_range"] = self._time_range + data_plot["with_colorbar"] = False + data_plot["recordings"] = self._selected_recordings + data_plot["add_legend"] = False if mode == "line": clims = None elif mode == "map": - clims = {layer_key: self.data_plot["clims"][layer_key]} + clims = {k: self.data_plot["clims"][k] for k in layer_keys} - # matplotlib next_data_plot dict update at each call - data_plot["mode"] = mode - data_plot["frame_range"] = frame_range - data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - data_plot["recordings"] = recordings - data_plot["layer_keys"] = layer_keys - data_plot["list_traces"] = list_traces_plot - data_plot["times"] = times data_plot["clims"] = clims - data_plot["channel_ids"] = channel_ids + data_plot["channel_ids"] = self._channel_ids + + data_plot["layer_keys"] = layer_keys + data_plot["colors"] = {k:self.data_plot["colors"][k] for k in layer_keys} + + list_traces = [traces * self.scaler.value for traces in self._list_traces] + data_plot["list_traces"] = list_traces + data_plot["times"] = self._times backend_kwargs = {} backend_kwargs["ax"] = self.ax + self.ax.clear() self.plot_matplotlib(data_plot, **backend_kwargs) + self.ax.set_title("") fig = self.ax.figure fig.canvas.draw() fig.canvas.flush_events() + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import handle_display_and_url diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index ad0ead7bc0..ab2b51a7bb 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -11,35 +11,35 @@ def check_ipywidget_backend(): assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" -def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): - time_slider = W.FloatSlider( - orientation="horizontal", - description="time:", - value=time_range[0], - min=t_start, - max=t_stop, - continuous_update=False, - layout=W.Layout(width=f"{width_cm}cm"), - ) - layer_selector = W.Dropdown(description="layer", options=layer_keys) - segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) - window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") - mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) - all_layers = W.Checkbox(description="plot all layers", value=all_layers) - - controller = { - "layer_key": layer_selector, - "segment_index": segment_selector, - "window": window_sizer, - "t_start": time_slider, - "mode": mode_selector, - "all_layers": all_layers, - } - widget = W.VBox( - [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] - ) - - return widget, controller +# def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): +# time_slider = W.FloatSlider( +# orientation="horizontal", +# description="time:", +# value=time_range[0], +# min=t_start, +# max=t_stop, +# continuous_update=False, +# layout=W.Layout(width=f"{width_cm}cm"), +# ) +# layer_selector = W.Dropdown(description="layer", options=layer_keys) +# segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) +# window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") +# mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) +# all_layers = W.Checkbox(description="plot all layers", value=all_layers) + +# controller = { +# "layer_key": layer_selector, +# "segment_index": segment_selector, +# "window": window_sizer, +# "t_start": time_slider, +# "mode": mode_selector, +# "all_layers": all_layers, +# } +# widget = W.VBox( +# [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] +# ) + +# return widget, controller def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): @@ -58,52 +58,52 @@ def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): return widget, controller -def make_channel_controller(recording, width_cm, height_cm): - channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) - channel_selector = W.IntRangeSlider( - value=[0, recording.get_num_channels()], - min=0, - max=recording.get_num_channels(), - step=1, - disabled=False, - continuous_update=False, - orientation="vertical", - readout=True, - readout_format="d", - layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), - ) +# def make_channel_controller(recording, width_cm, height_cm): +# channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) +# channel_selector = W.IntRangeSlider( +# value=[0, recording.get_num_channels()], +# min=0, +# max=recording.get_num_channels(), +# step=1, +# disabled=False, +# continuous_update=False, +# orientation="vertical", +# readout=True, +# readout_format="d", +# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), +# ) - controller = {"channel_inds": channel_selector} - widget = W.VBox([channel_label, channel_selector]) +# controller = {"channel_inds": channel_selector} +# widget = W.VBox([channel_label, channel_selector]) - return widget, controller +# return widget, controller -def make_scale_controller(width_cm, height_cm): - scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) +# def make_scale_controller(width_cm, height_cm): +# scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) - plus_selector = W.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Increase scale", - icon="arrow-up", - layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) +# plus_selector = W.Button( +# description="", +# disabled=False, +# button_style="", # 'success', 'info', 'warning', 'danger' or '' +# tooltip="Increase scale", +# icon="arrow-up", +# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), +# ) - minus_selector = W.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Decrease scale", - icon="arrow-down", - layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) +# minus_selector = W.Button( +# description="", +# disabled=False, +# button_style="", # 'success', 'info', 'warning', 'danger' or '' +# tooltip="Decrease scale", +# icon="arrow-down", +# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), +# ) - controller = {"plus": plus_selector, "minus": minus_selector} - widget = W.VBox([scale_label, plus_selector, minus_selector]) +# controller = {"plus": plus_selector, "minus": minus_selector} +# widget = W.VBox([scale_label, plus_selector, minus_selector]) - return widget, controller +# return widget, controller @@ -126,7 +126,7 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): self.value = (start_frame, end_frame, self.segment_index) - layout = W.Layout(align_items="center", width="2cm", hight="1.5cm") + layout = W.Layout(align_items="center", width="2.5cm", height="1.cm") but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) @@ -141,7 +141,7 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) self.time_label = W.Text(value=f'{time_range[0]}',description='', - disabled=False, layout=W.Layout(width='5.5cm')) + disabled=False, layout=W.Layout(width='2.5cm')) self.time_label.observe(self.time_label_changed, names='value', type="change") @@ -271,6 +271,43 @@ def segment_changed(self, change=None): self.update_time(new_frame=0, update_slider=True, update_label=True) +class ChannelSelector(W.VBox): + value = traitlets.List() + + def __init__(self, channel_ids, **kwargs): + self.channel_ids = list(channel_ids) + self.value = self.channel_ids + + channel_label = W.Label("Channels", layout=W.Layout(justify_content="center")) + n = len(channel_ids) + self.slider = W.IntRangeSlider( + value=[0, n], + min=0, + max=n, + step=1, + disabled=False, + continuous_update=False, + orientation="vertical", + readout=True, + readout_format="d", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(height="100%"), + ) + + + + super(W.VBox, self).__init__(children=[channel_label, self.slider], + layout=W.Layout(align_items="center"), + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs) + self.slider.observe(self.on_slider_changed, names=['value'], type="change") + # self.update_label() + # self.observe(self.value_changed, names=['value'], type="change") + + def on_slider_changed(self, change=None): + i0, i1 = self.slider.value + self.value = self.channel_ids[i0:i1] + class ScaleWidget(W.VBox): From 7b92c2153d4fad412823100fd77079e3cf286138 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 08:06:37 +0200 Subject: [PATCH 06/18] improve channel selector --- .../widgets/utils_ipywidgets.py | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index ab2b51a7bb..705dd09f23 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -294,20 +294,52 @@ def __init__(self, channel_ids, **kwargs): layout=W.Layout(height="100%"), ) + # first channel are bottom: need reverse + self.selector = W.SelectMultiple( + options=self.channel_ids[::-1], + value=self.channel_ids[::-1], + disabled=False, + # layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), + layout=W.Layout(height="100%", width="2cm"), + ) + hbox = W.HBox(children=[self.slider, self.selector]) - - super(W.VBox, self).__init__(children=[channel_label, self.slider], + super(W.VBox, self).__init__(children=[channel_label, hbox], layout=W.Layout(align_items="center"), # layout=W.Layout(align_items="center", width="100%", height="100%"), **kwargs) self.slider.observe(self.on_slider_changed, names=['value'], type="change") - # self.update_label() + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + + # TODO external value change # self.observe(self.value_changed, names=['value'], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value + + self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") + self.selector.value = self.channel_ids[i0:i1][::-1] + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + self.value = self.channel_ids[i0:i1] + def on_selector_changed(self, change=None): + channel_ids = self.selector.value + channel_ids = channel_ids[::-1] + + if len(channel_ids) > 0: + self.slider.unobserve(self.on_slider_changed, names=['value'], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=['value'], type="change") + + self.value = channel_ids + + + + + class ScaleWidget(W.VBox): From c46a7cba4b1e937d40050d0061017256ab5dade3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 10:31:05 +0200 Subject: [PATCH 07/18] Allow to restrict sparsity --- .../postprocessing/amplitude_scalings.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 4dab68fdf8..3eac333781 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -68,7 +68,6 @@ def _run(self, **job_kwargs): delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) return_scaled = we._params["return_scaled"] - unit_ids = we.unit_ids if ms_before is not None: assert ( @@ -82,9 +81,16 @@ def _run(self, **job_kwargs): cut_out_before = int(ms_before / 1000 * we.sampling_frequency) if ms_before is not None else nbefore cut_out_after = int(ms_after / 1000 * we.sampling_frequency) if ms_after is not None else nafter - if we.is_sparse(): + if we.is_sparse() and self._params["sparsity"] is None: sparsity = we.sparsity - elif self._params["sparsity"] is not None: + elif we.is_sparse() and self._params["sparsity"] is not None: + sparsity = self._params["sparsity"] + # assert provided sparsity is sparser than the one in the waveform extractor + waveform_sparsity = we.sparsity + assert np.all( + np.sum(waveform_sparsity.mask, 1) - np.sum(sparsity.mask, 1) > 0 + ), "The provided sparsity needs to be sparser than the one in the waveform extractor!" + elif not we.is_sparse() and self._params["sparsity"] is not None: sparsity = self._params["sparsity"] else: if self._params["max_dense_channels"] is not None: @@ -362,7 +368,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) template = template[cut_out_before - sample_index :] elif sample_index + cut_out_after > end_frame + right: local_waveform = traces_with_margin[cut_out_start:, sparse_indices] - template = template[: -(sample_index + cut_out_after - end_frame)] + template = template[: -(sample_index + cut_out_after - end_frame - right)] else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape From 2e305586d5b39bb8bfa89280057579a97726e93a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 11:09:05 +0200 Subject: [PATCH 08/18] ipywidgets backend start UnitCOntroller --- src/spikeinterface/widgets/amplitudes.py | 69 ++++++++++--------- .../widgets/utils_ipywidgets.py | 39 +++++++++-- 2 files changed, 71 insertions(+), 37 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 7ef6e0ff61..b60de98cb0 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -171,9 +171,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - import ipywidgets.widgets as widgets + # import ipywidgets.widgets as widgets + import ipywidgets.widgets as W from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller, UnitSelector check_ipywidget_backend() @@ -188,60 +189,62 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ratios = [0.15, 0.85] with plt.ioff(): - output = widgets.Output() + output = W.Output() with output: self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.unit_selector = UnitSelector(we.unit_ids) + self.unit_selector.value = list(we.unit_ids)[:1] - plot_histograms = widgets.Checkbox( + self.checkbox_histograms = W.Checkbox( value=data_plot["plot_histograms"], - description="plot histograms", - disabled=False, + description="hist", + # disabled=False, ) - footer = plot_histograms - - self.controller = {"plot_histograms": plot_histograms} - self.controller.update(unit_controller) - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + left_sidebar = W.VBox( + children=[ + self.unit_selector, + self.checkbox_histograms, + ], + layout = W.Layout(align_items="center", width="4cm", height="100%"), + ) - self.widget = widgets.AppLayout( + self.widget = W.AppLayout( center=self.figure.canvas, - left_sidebar=unit_widget, + left_sidebar=left_sidebar, pane_widths=ratios + [0], - footer=footer, ) # a first update - self._update_ipywidget(None) + self._full_update_plot() + + self.unit_selector.observe(self._update_plot, names='value', type="change") + self.checkbox_histograms.observe(self._full_update_plot, names='value', type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _full_update_plot(self, change=None): self.figure.clear() + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + + backend_kwargs = dict(figure=self.figure, axes=None, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + self._update_plot() - unit_ids = self.controller["unit_ids"].value - plot_histograms = self.controller["plot_histograms"].value + def _update_plot(self, change=None): + for ax in self.axes.flatten(): + ax.clear() - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_histograms"] = plot_histograms - - backend_kwargs = {} - # backend_kwargs["figure"] = self.fig - backend_kwargs["figure"] = self.figure - backend_kwargs["axes"] = None - backend_kwargs["ax"] = None + data_plot["unit_ids"] = self.unit_selector.value + data_plot["plot_histograms"] = self.checkbox_histograms.value + backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) self.figure.canvas.draw() diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 705dd09f23..d2c41f234a 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -338,10 +338,6 @@ def on_selector_changed(self, change=None): - - - - class ScaleWidget(W.VBox): value = traitlets.Float() @@ -398,3 +394,38 @@ def minus_clicked(self, change=None): def value_changed(self, change=None): self.update_label() + + +class UnitSelector(W.VBox): + value = traitlets.List() + + def __init__(self, unit_ids, **kwargs): + self.unit_ids = list(unit_ids) + self.value = self.unit_ids + + label = W.Label("Units", layout=W.Layout(justify_content="center")) + + self.selector = W.SelectMultiple( + options=self.unit_ids, + value=self.unit_ids, + disabled=False, + layout=W.Layout(height="100%", width="2cm"), + ) + + super(W.VBox, self).__init__(children=[label, self.selector], + layout=W.Layout(align_items="center"), + **kwargs) + + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + + self.observe(self.value_changed, names=['value'], type="change") + + def on_selector_changed(self, change=None): + unit_ids = self.selector.value + self.value = unit_ids + + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") + self.selector.value = change['new'] + self.selector.observe(self.on_selector_changed, names=['value'], type="change") + From 4e31329d9aed376ecc41c4238a2f4836f94054ea Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 11:37:18 +0200 Subject: [PATCH 09/18] Add spikes on border when generating sorting, PCA sparse return fixes --- src/spikeinterface/core/generate.py | 28 +++++++++++++++++ .../core/tests/test_generate.py | 30 +++++++++++++++++-- .../postprocessing/amplitude_scalings.py | 12 ++++---- .../postprocessing/principal_component.py | 15 ++++++++-- .../tests/common_extension_tests.py | 26 ++++++++++++++-- .../tests/test_principal_component.py | 12 ++++---- 6 files changed, 104 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..741dd20000 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -123,6 +123,9 @@ def generate_sorting( firing_rates=3.0, empty_units=None, refractory_period_ms=3.0, # in ms + add_spikes_on_borders=False, + num_spikes_per_border=3, + border_size_samples=20, seed=None, ): """ @@ -142,6 +145,12 @@ def generate_sorting( List of units that will have no spikes. (used for testing mainly). refractory_period_ms : float, default: 3.0 The refractory period in ms + add_spikes_on_borders : bool, default: False + If True, spikes will be added close to the borders of the segments. + num_spikes_per_border : int, default: 3 + The number of spikes to add close to the borders of the segments. + border_size_samples : int, default: 20 + The size of the border in samples to add border spikes. seed : int, default: None The random seed @@ -151,11 +160,13 @@ def generate_sorting( The sorting object """ seed = _ensure_seed(seed) + rng = np.random.default_rng(seed) num_segments = len(durations) unit_ids = np.arange(num_units) spikes = [] for segment_index in range(num_segments): + num_samples = int(sampling_frequency * durations[segment_index]) times, labels = synthesize_random_firings( num_units=num_units, sampling_frequency=sampling_frequency, @@ -175,7 +186,23 @@ def generate_sorting( spikes_in_seg["unit_index"] = labels spikes_in_seg["segment_index"] = segment_index spikes.append(spikes_in_seg) + + if add_spikes_on_borders: + spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype) + spikes_on_borders["segment_index"] = segment_index + spikes_on_borders["unit_index"] = rng.choice(num_units, size=2 * num_spikes_per_border, replace=True) + # at start + spikes_on_borders["sample_index"][:num_spikes_per_border] = rng.integers( + 0, border_size_samples, num_spikes_per_border + ) + # at end + spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers( + num_samples - border_size_samples, num_samples, num_spikes_per_border + ) + spikes.append(spikes_on_borders) + spikes = np.concatenate(spikes) + spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) @@ -596,6 +623,7 @@ def __init__( dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") + assert strategy in ("tile_pregenerated", "on_the_fly"), "'strategy' must be 'tile_pregenerated' or 'on_the_fly'" BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 9ba5de42d6..3844e421ac 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -26,15 +26,38 @@ def test_generate_recording(): - # TODO even this is extenssivly tested in all other function + # TODO even this is extensively tested in all other functions pass def test_generate_sorting(): - # TODO even this is extenssivly tested in all other function + # TODO even this is extensively tested in all other functions pass +def test_generate_sorting_with_spikes_on_borders(): + num_spikes_on_borders = 10 + border_size_samples = 10 + segment_duration = 10 + for nseg in [1, 2, 3]: + sorting = generate_sorting( + durations=[segment_duration] * nseg, + sampling_frequency=30000, + num_units=10, + add_spikes_on_borders=True, + num_spikes_per_border=num_spikes_on_borders, + border_size_samples=border_size_samples, + ) + spikes = sorting.to_spike_vector(concatenated=False) + # at least num_border spikes at borders for all segments + for i, spikes_in_segment in enumerate(spikes): + num_samples = int(segment_duration * 30000) + assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders + assert ( + np.sum(spikes_in_segment["sample_index"] >= num_samples - border_size_samples) >= num_spikes_on_borders + ) + + def measure_memory_allocation(measure_in_process: bool = True) -> float: """ A local utility to measure memory allocation at a specific point in time. @@ -399,7 +422,7 @@ def test_generate_ground_truth_recording(): if __name__ == "__main__": strategy = "tile_pregenerated" # strategy = "on_the_fly" - test_noise_generator_memory() + # test_noise_generator_memory() # test_noise_generator_under_giga() # test_noise_generator_correct_shape(strategy) # test_noise_generator_consistency_across_calls(strategy, 0, 5) @@ -410,3 +433,4 @@ def test_generate_ground_truth_recording(): # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() + test_generate_sorting_with_spikes_on_borders() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 3eac333781..c86337a30d 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -16,6 +16,7 @@ class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): """ extension_name = "amplitude_scalings" + handle_sparsity = True def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) @@ -357,7 +358,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] - sparse_indices = sparsity_mask[unit_index] + (sparse_indices,) = np.nonzero(sparsity_mask[unit_index]) template = all_templates[unit_index][:, sparse_indices] template = template[nbefore - cut_out_before : nbefore + cut_out_after] sample_centered = sample_index - start_frame @@ -368,7 +369,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) template = template[cut_out_before - sample_index :] elif sample_index + cut_out_after > end_frame + right: local_waveform = traces_with_margin[cut_out_start:, sparse_indices] - template = template[: -(sample_index + cut_out_after - end_frame - right)] + template = template[: -(sample_index + cut_out_after - (end_frame + right))] else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape @@ -550,10 +551,11 @@ def fit_collision( sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity - sparse_indices = np.zeros(sparsity_mask.shape[1], dtype="int") + common_sparse_mask = np.zeros(sparsity_mask.shape[1], dtype="int") for spike in collision: - sparse_indices_i = sparsity_mask[spike["unit_index"]] - sparse_indices = np.logical_or(sparse_indices, sparse_indices_i) + mask_i = sparsity_mask[spike["unit_index"]] + common_sparse_mask = np.logical_or(common_sparse_mask, mask_i) + (sparse_indices,) = np.nonzero(common_sparse_mask) local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 233625e09e..1214b84ac4 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -84,9 +84,16 @@ def get_projections(self, unit_id): Returns ------- proj: np.array - The PCA projections (num_waveforms, num_components, num_channels) + The PCA projections (num_waveforms, num_components, num_channels). + In case sparsity is used, only the projections on sparse channels are returned. """ - return self._extension_data[f"pca_{unit_id}"] + projections = self._extension_data[f"pca_{unit_id}"] + mode = self._params["mode"] + if mode in ("by_channel_local", "by_channel_global"): + sparsity = self.get_sparsity() + if sparsity is not None: + projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] + return projections def get_pca_model(self): """ @@ -211,6 +218,10 @@ def project_new(self, new_waveforms, unit_id=None): wfs_flat = new_waveforms.reshape(new_waveforms.shape[0], -1) projections = pca_model.transform(wfs_flat) + # take care of sparsity (not in case of concatenated) + if mode in ("by_channel_local", "by_channel_global"): + if sparsity is not None: + projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] return projections def get_sparsity(self): diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b9c72f9b99..8657d1dced 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -5,7 +5,7 @@ from pathlib import Path from spikeinterface import extract_waveforms, load_extractor, compute_sparsity -from spikeinterface.extractors import toy_example +from spikeinterface.core.generate import generate_ground_truth_recording if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "postprocessing" @@ -26,7 +26,18 @@ def setUp(self): self.cache_folder = cache_folder # 1-segment - recording, sorting = toy_example(num_segments=1, num_units=10, num_channels=12) + recording, sorting = generate_ground_truth_recording( + durations=[10], + sampling_frequency=30000, + num_channels=12, + num_units=10, + dtype="float32", + seed=91, + generate_sorting_kwargs=dict(add_spikes_on_borders=True), + noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), + ) + + # add gains and offsets and save gain = 0.1 recording.set_channel_gains(gain) recording.set_channel_offsets(0) @@ -53,7 +64,16 @@ def setUp(self): self.sparsity1 = compute_sparsity(we1, method="radius", radius_um=50) # 2-segments - recording, sorting = toy_example(num_segments=2, num_units=10) + recording, sorting = generate_ground_truth_recording( + durations=[10, 5], + sampling_frequency=30000, + num_channels=12, + num_units=10, + dtype="float32", + seed=91, + generate_sorting_kwargs=dict(add_spikes_on_borders=True), + noise_kwargs=dict(noise_level=10.0, strategy="tile_pregenerated"), + ) recording.set_channel_gains(gain) recording.set_channel_offsets(0) if (cache_folder / "toy_rec_2seg").is_dir(): diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 5d64525b52..04ce42b70e 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -87,13 +87,13 @@ def test_sparse(self): pc.run() for i, unit_id in enumerate(unit_ids): proj = pc.get_projections(unit_id) - assert proj.shape[1:] == (5, 4) + assert proj.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) # test project_new unit_id = 3 new_wfs = we.get_waveforms(unit_id) new_proj = pc.project_new(new_wfs, unit_id=unit_id) - assert new_proj.shape == (new_wfs.shape[0], 5, 4) + assert new_proj.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) if DEBUG: import matplotlib.pyplot as plt @@ -197,8 +197,8 @@ def test_project_new(self): if __name__ == "__main__": test = PrincipalComponentsExtensionTest() test.setUp() - test.test_extension() - test.test_shapes() - test.test_compute_for_all_spikes() + # test.test_extension() + # test.test_shapes() + # test.test_compute_for_all_spikes() test.test_sparse() - test.test_project_new() + # test.test_project_new() From 73ceaacefecc4426d994ebca4ca006d667dada42 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 12:06:15 +0200 Subject: [PATCH 10/18] Extend PCA to be able to return sparse projections and fix tests --- .../postprocessing/principal_component.py | 16 ++++++++++------ .../tests/test_principal_component.py | 12 ++++++++---- .../tests/test_quality_metric_calculator.py | 7 ++++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 5d62216c20..8383dcbb43 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -72,7 +72,7 @@ def _select_extension_data(self, unit_ids): new_extension_data[k] = v return new_extension_data - def get_projections(self, unit_id): + def get_projections(self, unit_id, sparse=False): """ Returns the computed projections for the sampled waveforms of a unit id. @@ -80,16 +80,18 @@ def get_projections(self, unit_id): ---------- unit_id : int or str The unit id to return PCA projections for + sparse: bool, default False + If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- - proj: np.array + projections: np.array The PCA projections (num_waveforms, num_components, num_channels). In case sparsity is used, only the projections on sparse channels are returned. """ projections = self._extension_data[f"pca_{unit_id}"] mode = self._params["mode"] - if mode in ("by_channel_local", "by_channel_global"): + if mode in ("by_channel_local", "by_channel_global") and sparse: sparsity = self.get_sparsity() if sparsity is not None: projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] @@ -141,7 +143,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): all_labels = [] #  can be unit_id or unit_index all_projections = [] for unit_index, unit_id in enumerate(unit_ids): - proj = self.get_projections(unit_id) + proj = self.get_projections(unit_id, sparse=False) if channel_ids is not None: chan_inds = self.waveform_extractor.channel_ids_to_indices(channel_ids) proj = proj[:, :, chan_inds] @@ -158,7 +160,7 @@ def get_all_projections(self, channel_ids=None, unit_ids=None, outputs="id"): return all_labels, all_projections - def project_new(self, new_waveforms, unit_id=None): + def project_new(self, new_waveforms, unit_id=None, sparse=False): """ Projects new waveforms or traces snippets on the PC components. @@ -168,6 +170,8 @@ def project_new(self, new_waveforms, unit_id=None): Array with new waveforms to project with shape (num_waveforms, num_samples, num_channels) unit_id: int or str In case PCA is sparse and mode is by_channel_local, the unit_id of 'new_waveforms' + sparse: bool, default: False + If True, and sparsity is not None, only projections on sparse channels are returned. Returns ------- @@ -219,7 +223,7 @@ def project_new(self, new_waveforms, unit_id=None): projections = pca_model.transform(wfs_flat) # take care of sparsity (not in case of concatenated) - if mode in ("by_channel_local", "by_channel_global"): + if mode in ("by_channel_local", "by_channel_global") and sparse: if sparsity is not None: projections = projections[:, :, sparsity.unit_id_to_channel_indices[unit_id]] return projections diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 04ce42b70e..49591d9b89 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -86,14 +86,18 @@ def test_sparse(self): pc.set_params(n_components=5, mode=mode, sparsity=sparsity) pc.run() for i, unit_id in enumerate(unit_ids): - proj = pc.get_projections(unit_id) - assert proj.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + proj_sparse = pc.get_projections(unit_id, sparse=True) + assert proj_sparse.shape[1:] == (5, len(sparsity.unit_id_to_channel_ids[unit_id])) + proj_dense = pc.get_projections(unit_id, sparse=False) + assert proj_dense.shape[1:] == (5, num_channels) # test project_new unit_id = 3 new_wfs = we.get_waveforms(unit_id) - new_proj = pc.project_new(new_wfs, unit_id=unit_id) - assert new_proj.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + new_proj_sparse = pc.project_new(new_wfs, unit_id=unit_id, sparse=True) + assert new_proj_sparse.shape == (new_wfs.shape[0], 5, len(sparsity.unit_id_to_channel_ids[unit_id])) + new_proj_dense = pc.project_new(new_wfs, unit_id=unit_id, sparse=False) + assert new_proj_dense.shape == (new_wfs.shape[0], 5, num_channels) if DEBUG: import matplotlib.pyplot as plt diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 4fa65993d1..977beca210 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -261,7 +261,8 @@ def test_nn_metrics(self): we_sparse, metric_names=metric_names, sparsity=None, seed=0, n_jobs=2 ) for metric_name in metrics.columns: - assert np.allclose(metrics[metric_name], metrics_par[metric_name]) + # NaNs are skipped + assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) def test_recordingless(self): we = self.we_long @@ -305,7 +306,7 @@ def test_empty_units(self): test.setUp() # test.test_drift_metrics() # test.test_extension() - # test.test_nn_metrics() + test.test_nn_metrics() # test.test_peak_sign() # test.test_empty_units() - test.test_recordingless() + # test.test_recordingless() From b9b6c15b42a64d877ea9fad9fca84424e2c97edf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 12:12:21 +0200 Subject: [PATCH 11/18] Add test to check correct order of spikes with borders --- src/spikeinterface/core/tests/test_generate.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 3844e421ac..9a9c61766f 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -48,9 +48,15 @@ def test_generate_sorting_with_spikes_on_borders(): num_spikes_per_border=num_spikes_on_borders, border_size_samples=border_size_samples, ) + # check that segments are correctly sorted + all_spikes = sorting.to_spike_vector() + np.testing.assert_array_equal(all_spikes["segment_index"], np.sort(all_spikes["segment_index"])) + spikes = sorting.to_spike_vector(concatenated=False) # at least num_border spikes at borders for all segments - for i, spikes_in_segment in enumerate(spikes): + for spikes_in_segment in spikes: + # check that sample indices are correctly sorted within segments + np.testing.assert_array_equal(spikes_in_segment["sample_index"], np.sort(spikes_in_segment["sample_index"])) num_samples = int(segment_duration * 30000) assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders assert ( From 4e79b5811d41e6343391a3a6b26fab97f657368b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 13:32:51 +0200 Subject: [PATCH 12/18] propagate UnitSelector to others ipywidgets --- src/spikeinterface/widgets/amplitudes.py | 12 ++- src/spikeinterface/widgets/base.py | 3 +- src/spikeinterface/widgets/metrics.py | 21 ++-- src/spikeinterface/widgets/spike_locations.py | 34 +++---- .../widgets/spikes_on_traces.py | 87 ++++++++++------- src/spikeinterface/widgets/unit_locations.py | 29 +++--- src/spikeinterface/widgets/unit_waveforms.py | 50 +++++----- .../widgets/utils_ipywidgets.py | 96 ------------------- 8 files changed, 121 insertions(+), 211 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index b60de98cb0..5aa090b1b4 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -147,13 +147,16 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: bins = dp.bins ax_hist = self.axes.flatten()[1] - ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + # this is super slow, using plot and np.histogram is really much faster (and nicer!) + # ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + count, bins = np.histogram(amps, bins=bins) + ax_hist.plot(count, bins[:-1], color=dp.unit_colors[unit_id], alpha=0.8) if dp.plot_histograms: ax_hist = self.axes.flatten()[1] ax_hist.set_ylim(scatter_ax.get_ylim()) ax_hist.axis("off") - self.figure.tight_layout() + # self.figure.tight_layout() if dp.plot_legend: if hasattr(self, "legend") and self.legend is not None: @@ -174,7 +177,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # import ipywidgets.widgets as widgets import ipywidgets.widgets as W from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller, UnitSelector + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -200,7 +203,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.checkbox_histograms = W.Checkbox( value=data_plot["plot_histograms"], description="hist", - # disabled=False, ) left_sidebar = W.VBox( @@ -231,6 +233,7 @@ def _full_update_plot(self, change=None): data_plot = self.next_data_plot data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False backend_kwargs = dict(figure=self.figure, axes=None, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) @@ -243,6 +246,7 @@ def _update_plot(self, change=None): data_plot = self.next_data_plot data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_histograms"] = self.checkbox_histograms.value + data_plot["plot_legend"] = False backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 4ed83fcca9..1ff691320a 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -38,6 +38,7 @@ def set_default_plotter_backend(backend): "width_cm": "Width of the figure in cm (default 10)", "height_cm": "Height of the figure in cm (default 6)", "display": "If True, widgets are immediately displayed", + # "controllers": "" }, "ephyviewer": {}, } @@ -45,7 +46,7 @@ def set_default_plotter_backend(backend): default_backend_kwargs = { "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, - "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True}, + "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True, "controllers": None}, "ephyviewer": {}, } diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 9dc51f522e..604da35e65 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -128,7 +128,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -147,34 +147,29 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): with output: self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - if data_plot["unit_ids"] is None: - data_plot["unit_ids"] = [] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller + self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids) + self.unit_selector.value = [ ] - for w in self.controller.values(): - w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=self.figure.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update self._update_ipywidget(None) + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + if backend_kwargs["display"]: display(self.widget) def _update_ipywidget(self, change): from matplotlib.lines import Line2D - unit_ids = self.controller["unit_ids"].value + unit_ids = self.unit_selector.value unit_colors = self.data_plot["unit_colors"] # matplotlib next_data_plot dict update at each call @@ -198,6 +193,7 @@ def _update_ipywidget(self, change): self.plot_matplotlib(self.data_plot, **backend_kwargs) if len(unit_ids) > 0: + # TODO later make option to control legend or not for l in self.figure.legends: l.remove() handles = [ @@ -212,6 +208,7 @@ def _update_ipywidget(self, change): self.figure.canvas.draw() self.figure.canvas.flush_events() + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 9771b2c0e9..926051b8f9 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -191,7 +191,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -210,48 +210,36 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], - list(data_plot["unit_colors"].keys()), - ratios[0] * width_cm, - height_cm, - ) - - self.controller = unit_controller - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.widget = widgets.AppLayout( center=fig.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_all_units"] = True + # TODO add an option checkbox for legend data_plot["plot_legend"] = True data_plot["hide_axis"] = True - backend_kwargs = {} - backend_kwargs["ax"] = self.ax + backend_kwargs = dict(ax=self.ax) - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() fig.canvas.draw() diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index ae036d1ba1..2f748cc0fc 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -149,20 +149,20 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = we.sorting # first plot time series - ts_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) - self.ax = ts_widget.ax - self.axes = ts_widget.axes - self.figure = ts_widget.figure + traces_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) + self.ax = traces_widget.ax + self.axes = traces_widget.axes + self.figure = traces_widget.figure ax = self.ax - frame_range = ts_widget.data_plot["frame_range"] - segment_index = ts_widget.data_plot["segment_index"] - min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) - max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) + frame_range = traces_widget.data_plot["frame_range"] + segment_index = traces_widget.data_plot["segment_index"] + min_y = np.min(traces_widget.data_plot["channel_locations"][:, 1]) + max_y = np.max(traces_widget.data_plot["channel_locations"][:, 1]) - n = len(ts_widget.data_plot["channel_ids"]) - order = ts_widget.data_plot["order"] + n = len(traces_widget.data_plot["channel_ids"]) + order = traces_widget.data_plot["order"] if order is None: order = np.arange(n) @@ -210,13 +210,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # construct waveforms label_set = False if len(spike_frames_to_plot) > 0: - vspacing = ts_widget.data_plot["vspacing"] - traces = ts_widget.data_plot["list_traces"][0] + vspacing = traces_widget.data_plot["vspacing"] + traces = traces_widget.data_plot["list_traces"][0] waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) + waveform_idxs = np.clip(waveform_idxs, 0, len(traces_widget.data_plot["times"]) - 1) - times = ts_widget.data_plot["times"][waveform_idxs] + times = traces_widget.data_plot["times"][waveform_idxs] # discontinuity times[:, -1] = np.nan @@ -224,7 +224,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): waveforms = traces[waveform_idxs] # [:, :, order] waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) - for i, chan_id in enumerate(ts_widget.data_plot["channel_ids"]): + for i, chan_id in enumerate(traces_widget.data_plot["channel_ids"]): offset = vspacing * i if chan_id in chan_ids: l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit]) @@ -232,13 +232,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): handles.append(l[0]) labels.append(unit) label_set = True - ax.legend(handles, labels) + # ax.legend(handles, labels) def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -256,37 +256,58 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - ts_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) - self.ax = ts_widget.ax - self.axes = ts_widget.axes - self.figure = ts_widget.figure + self._traces_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + self.ax = self._traces_widget.ax + self.axes = self._traces_widget.axes + self.figure = self._traces_widget.figure - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.sampling_frequency = self._traces_widget.rec0.sampling_frequency - self.controller = dict() - self.controller.update(ts_widget.controller) - self.controller.update(unit_controller) + self.time_slider = self._traces_widget.time_slider - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) + self.widget = widgets.AppLayout(center=self._traces_widget.widget, + left_sidebar=self.unit_selector, + pane_widths=ratios + [0]) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + # remove callback from traces_widget + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self._traces_widget.time_slider.observe(self._update_ipywidget, names='value', type="change") + self._traces_widget.channel_selector.observe(self._update_ipywidget, names='value', type="change") + self._traces_widget.scaler.observe(self._update_ipywidget, names='value', type="change") + if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value + # TODO later: this is still a bit buggy because it make double refresh one from _traces_widget and one internal + + unit_ids = self.unit_selector.value + start_frame, end_frame, segment_index = self._traces_widget.time_slider.value + channel_ids = self._traces_widget.channel_selector.value + mode = self._traces_widget.mode_selector.value data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids + data_plot["options"].update( + dict( + channel_ids=channel_ids, + segment_index=segment_index, + # frame_range=(start_frame, end_frame), + time_range=np.array([start_frame, end_frame]) / self.sampling_frequency, + mode=mode, + with_colorbar=False, + ) + ) + backend_kwargs = {} backend_kwargs["ax"] = self.ax diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 42267e711f..8526a95d60 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -167,7 +167,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -186,42 +186,35 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.widget = widgets.AppLayout( center=fig.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, pane_widths=ratios + [0], ) # a first update - self._update_ipywidget(None) + self._update_ipywidget() + + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change): + def _update_ipywidget(self, change=None): self.ax.clear() - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids + data_plot["unit_ids"] = self.unit_selector.value data_plot["plot_all_units"] = True + # TODO later add an option checkbox for legend data_plot["plot_legend"] = True data_plot["hide_axis"] = True - backend_kwargs = {} - backend_kwargs["ax"] = self.ax + backend_kwargs = dict(ax=self.ax) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index e64765b44b..f01c842b66 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -250,7 +250,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector check_ipywidget_backend() @@ -274,44 +274,33 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.fig_probe, self.ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) * cm, height_cm * cm)) plt.show() - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] + - same_axis_button = widgets.Checkbox( + self.same_axis_button = widgets.Checkbox( value=False, description="same axis", disabled=False, ) - plot_templates_button = widgets.Checkbox( + self.plot_templates_button = widgets.Checkbox( value=True, description="plot templates", disabled=False, ) - hide_axis_button = widgets.Checkbox( + self.hide_axis_button = widgets.Checkbox( value=True, description="hide axis", disabled=False, ) - footer = widgets.HBox([same_axis_button, plot_templates_button, hide_axis_button]) - - self.controller = { - "same_axis": same_axis_button, - "plot_templates": plot_templates_button, - "hide_axis": hide_axis_button, - } - self.controller.update(unit_controller) - - for w in self.controller.values(): - w.observe(self._update_ipywidget) + footer = widgets.HBox([self.same_axis_button, self.plot_templates_button, self.hide_axis_button]) self.widget = widgets.AppLayout( center=self.fig_wf.canvas, - left_sidebar=unit_widget, + left_sidebar=self.unit_selector, right_sidebar=self.fig_probe.canvas, pane_widths=ratios, footer=footer, @@ -320,6 +309,11 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) + self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + for w in self.same_axis_button, self.plot_templates_button, self.hide_axis_button: + w.observe(self._update_ipywidget, names='value', type="change") + + if backend_kwargs["display"]: display(self.widget) @@ -327,10 +321,15 @@ def _update_ipywidget(self, change): self.fig_wf.clear() self.ax_probe.clear() - unit_ids = self.controller["unit_ids"].value - same_axis = self.controller["same_axis"].value - plot_templates = self.controller["plot_templates"].value - hide_axis = self.controller["hide_axis"].value + # unit_ids = self.controller["unit_ids"].value + unit_ids = self.unit_selector.value + # same_axis = self.controller["same_axis"].value + # plot_templates = self.controller["plot_templates"].value + # hide_axis = self.controller["hide_axis"].value + + same_axis = self.same_axis_button.value + plot_templates = self.plot_templates_button.value + hide_axis = self.hide_axis_button.value # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot @@ -341,6 +340,8 @@ def _update_ipywidget(self, change): data_plot["plot_templates"] = plot_templates if data_plot["plot_waveforms"]: data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} + + # TODO option for plot_legend backend_kwargs = {} @@ -369,6 +370,7 @@ def _update_ipywidget(self, change): self.ax_probe.axis("off") self.ax_probe.axis("equal") + # TODO this could be done with probeinterface plotting plotting tools!! for unit in unit_ids: channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] self.ax_probe.plot( diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index d2c41f234a..57550c0910 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -11,102 +11,6 @@ def check_ipywidget_backend(): assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" -# def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): -# time_slider = W.FloatSlider( -# orientation="horizontal", -# description="time:", -# value=time_range[0], -# min=t_start, -# max=t_stop, -# continuous_update=False, -# layout=W.Layout(width=f"{width_cm}cm"), -# ) -# layer_selector = W.Dropdown(description="layer", options=layer_keys) -# segment_selector = W.Dropdown(description="segment", options=list(range(num_segments))) -# window_sizer = W.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") -# mode_selector = W.Dropdown(options=["line", "map"], description="mode", value=mode) -# all_layers = W.Checkbox(description="plot all layers", value=all_layers) - -# controller = { -# "layer_key": layer_selector, -# "segment_index": segment_selector, -# "window": window_sizer, -# "t_start": time_slider, -# "mode": mode_selector, -# "all_layers": all_layers, -# } -# widget = W.VBox( -# [time_slider, W.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] -# ) - -# return widget, controller - - -def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): - unit_label = W.Label(value="units:") - - unit_selector = W.SelectMultiple( - options=all_unit_ids, - value=list(unit_ids), - disabled=False, - layout=W.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), - ) - - controller = {"unit_ids": unit_selector} - widget = W.VBox([unit_label, unit_selector]) - - return widget, controller - - -# def make_channel_controller(recording, width_cm, height_cm): -# channel_label = W.Label("channel indices:", layout=W.Layout(justify_content="center")) -# channel_selector = W.IntRangeSlider( -# value=[0, recording.get_num_channels()], -# min=0, -# max=recording.get_num_channels(), -# step=1, -# disabled=False, -# continuous_update=False, -# orientation="vertical", -# readout=True, -# readout_format="d", -# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), -# ) - -# controller = {"channel_inds": channel_selector} -# widget = W.VBox([channel_label, channel_selector]) - -# return widget, controller - - -# def make_scale_controller(width_cm, height_cm): -# scale_label = W.Label("Scale", layout=W.Layout(justify_content="center")) - -# plus_selector = W.Button( -# description="", -# disabled=False, -# button_style="", # 'success', 'info', 'warning', 'danger' or '' -# tooltip="Increase scale", -# icon="arrow-up", -# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), -# ) - -# minus_selector = W.Button( -# description="", -# disabled=False, -# button_style="", # 'success', 'info', 'warning', 'danger' or '' -# tooltip="Decrease scale", -# icon="arrow-down", -# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), -# ) - -# controller = {"plus": plus_selector, "minus": minus_selector} -# widget = W.VBox([scale_label, plus_selector, minus_selector]) - -# return widget, controller - - - class TimeSlider(W.HBox): value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) From f315594b0b88bed01f01232688d62c4c2e4bc0fe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 15:49:47 +0200 Subject: [PATCH 13/18] protect TimeSlider on the upper limit to avoid border effect on window size --- src/spikeinterface/widgets/utils_ipywidgets.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 57550c0910..ee6133a990 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -54,7 +54,7 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): # description='time:', value=start_frame, min=0, - max=self.frame_limits[self.segment_index], + max=self.frame_limits[self.segment_index] - 1, readout=False, continuous_update=False, layout=W.Layout(width=f'70%') @@ -112,10 +112,13 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update else: start_frame = new_frame delta_s = self.window_sizer.value - end_frame = start_frame + int(delta_s * self.sampling_frequency) - + delta = int(delta_s * self.sampling_frequency) + # clip + start_frame = min(self.frame_limits[self.segment_index] - delta, start_frame) start_frame = max(0, start_frame) + end_frame = start_frame + delta + end_frame = min(self.frame_limits[self.segment_index], end_frame) @@ -170,7 +173,7 @@ def segment_changed(self, change=None): self.slider.unobserve(self.slider_moved, names='value', type="change") # self.slider.value = 0 - self.slider.max = self.frame_limits[self.segment_index] + self.slider.max = self.frame_limits[self.segment_index] - 1 self.slider.observe(self.slider_moved, names='value', type="change") self.update_time(new_frame=0, update_slider=True, update_label=True) From 2c015f78e9311e106e9d2fda4e4026a61ca68c5b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 09:28:28 +0000 Subject: [PATCH 14/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/amplitudes.py | 7 +- src/spikeinterface/widgets/base.py | 2 +- src/spikeinterface/widgets/metrics.py | 6 +- src/spikeinterface/widgets/spike_locations.py | 2 +- .../widgets/spikes_on_traces.py | 20 +- src/spikeinterface/widgets/traces.py | 51 ++-- src/spikeinterface/widgets/unit_locations.py | 2 +- src/spikeinterface/widgets/unit_waveforms.py | 8 +- .../widgets/utils_ipywidgets.py | 222 +++++++++--------- 9 files changed, 163 insertions(+), 157 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 5aa090b1b4..6b6496a577 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -174,6 +174,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt + # import ipywidgets.widgets as widgets import ipywidgets.widgets as W from IPython.display import display @@ -210,7 +211,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector, self.checkbox_histograms, ], - layout = W.Layout(align_items="center", width="4cm", height="100%"), + layout=W.Layout(align_items="center", width="4cm", height="100%"), ) self.widget = W.AppLayout( @@ -222,8 +223,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._full_update_plot() - self.unit_selector.observe(self._update_plot, names='value', type="change") - self.checkbox_histograms.observe(self._full_update_plot, names='value', type="change") + self.unit_selector.observe(self._update_plot, names="value", type="change") + self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change") if backend_kwargs["display"]: display(self.widget) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 1ff691320a..9fc7b73707 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -38,7 +38,7 @@ def set_default_plotter_backend(backend): "width_cm": "Width of the figure in cm (default 10)", "height_cm": "Height of the figure in cm (default 6)", "display": "If True, widgets are immediately displayed", - # "controllers": "" + # "controllers": "" }, "ephyviewer": {}, } diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 604da35e65..c7b701c8b0 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -149,8 +149,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): plt.show() self.unit_selector = UnitSelector(data_plot["sorting"].unit_ids) - self.unit_selector.value = [ ] - + self.unit_selector.value = [] self.widget = widgets.AppLayout( center=self.figure.canvas, @@ -161,7 +160,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -208,7 +207,6 @@ def _update_ipywidget(self, change): self.figure.canvas.draw() self.figure.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 926051b8f9..fda2356105 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -222,7 +222,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget() - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 2f748cc0fc..c2bed8fe41 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -232,7 +232,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): handles.append(l[0]) labels.append(unit) label_set = True - # ax.legend(handles, labels) + # ax.legend(handles, labels) def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt @@ -268,19 +268,18 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector = UnitSelector(data_plot["unit_ids"]) self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.widget = widgets.AppLayout(center=self._traces_widget.widget, - left_sidebar=self.unit_selector, - pane_widths=ratios + [0]) + self.widget = widgets.AppLayout( + center=self._traces_widget.widget, left_sidebar=self.unit_selector, pane_widths=ratios + [0] + ) # a first update self._update_ipywidget() # remove callback from traces_widget - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") - self._traces_widget.time_slider.observe(self._update_ipywidget, names='value', type="change") - self._traces_widget.channel_selector.observe(self._update_ipywidget, names='value', type="change") - self._traces_widget.scaler.observe(self._update_ipywidget, names='value', type="change") - + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.time_slider.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.channel_selector.observe(self._update_ipywidget, names="value", type="change") + self._traces_widget.scaler.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -305,10 +304,9 @@ def _update_ipywidget(self, change=None): time_range=np.array([start_frame, end_frame]) / self.sampling_frequency, mode=mode, with_colorbar=False, - ) + ) ) - backend_kwargs = {} backend_kwargs["ax"] = self.ax diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index d107c5cb23..9b6716e8f3 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -290,7 +290,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): check_ipywidget_backend() self.next_data_plot = data_plot.copy() - self.recordings = data_plot["recordings"] @@ -314,7 +313,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, - # layout=W.Layout(height="2cm"), + # layout=W.Layout(height="2cm"), ) start_frame = int(data_plot["time_range"][0] * rec0.sampling_frequency) @@ -324,14 +323,17 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): _layer_keys = data_plot["layer_keys"] if len(_layer_keys) > 1: - _layer_keys = ['ALL'] + _layer_keys - self.layer_selector = W.Dropdown(options=_layer_keys, - layout=W.Layout(width="95%"), - ) - self.mode_selector = W.Dropdown(options=["line", "map"], value=data_plot["mode"], - # layout=W.Layout(width="5cm"), - layout=W.Layout(width="95%"), - ) + _layer_keys = ["ALL"] + _layer_keys + self.layer_selector = W.Dropdown( + options=_layer_keys, + layout=W.Layout(width="95%"), + ) + self.mode_selector = W.Dropdown( + options=["line", "map"], + value=data_plot["mode"], + # layout=W.Layout(width="5cm"), + layout=W.Layout(width="95%"), + ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) @@ -343,9 +345,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.mode_selector, self.scaler, # self.channel_selector, - ], + ], layout=W.Layout(width="3.5cm"), - align_items='center', + align_items="center", ) self.return_scaled = data_plot["return_scaled"] @@ -353,7 +355,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.widget = widgets.AppLayout( center=self.figure.canvas, footer=self.time_slider, - left_sidebar = left_sidebar, + left_sidebar=left_sidebar, right_sidebar=self.channel_selector, pane_heights=[0, 6, 1], pane_widths=ratios, @@ -365,28 +367,28 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # callbacks: # some widgets generate a full retrieve + refresh - self.time_slider.observe(self._retrieve_traces, names='value', type="change") - self.layer_selector.observe(self._retrieve_traces, names='value', type="change") - self.channel_selector.observe(self._retrieve_traces, names='value', type="change") + self.time_slider.observe(self._retrieve_traces, names="value", type="change") + self.layer_selector.observe(self._retrieve_traces, names="value", type="change") + self.channel_selector.observe(self._retrieve_traces, names="value", type="change") # other widgets only refresh - self.scaler.observe(self._update_plot, names='value', type="change") + self.scaler.observe(self._update_plot, names="value", type="change") # map is a special case because needs to check layer also - self.mode_selector.observe(self._mode_changed, names='value', type="change") - + self.mode_selector.observe(self._mode_changed, names="value", type="change") + if backend_kwargs["display"]: # self.check_backend() display(self.widget) def _get_layers(self): layer = self.layer_selector.value - if layer == 'ALL': + if layer == "ALL": layer_keys = self.data_plot["layer_keys"] else: layer_keys = [layer] if self.mode_selector.value == "map": layer_keys = layer_keys[:1] return layer_keys - + def _mode_changed(self, change=None): if self.mode_selector.value == "map" and self.layer_selector.value == "ALL": self.layer_selector.value = self.data_plot["layer_keys"][0] @@ -400,7 +402,7 @@ def _retrieve_traces(self, change=None): order, _ = order_channels_by_depth(self.rec0, channel_ids) else: order = None - + start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency @@ -439,9 +441,9 @@ def _update_plot(self, change=None): data_plot["clims"] = clims data_plot["channel_ids"] = self._channel_ids - + data_plot["layer_keys"] = layer_keys - data_plot["colors"] = {k:self.data_plot["colors"][k] for k in layer_keys} + data_plot["colors"] = {k: self.data_plot["colors"][k] for k in layer_keys} list_traces = [traces * self.scaler.value for traces in self._list_traces] data_plot["list_traces"] = list_traces @@ -458,7 +460,6 @@ def _update_plot(self, change=None): fig.canvas.draw() fig.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import handle_display_and_url diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 8526a95d60..b41ee3508b 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -198,7 +198,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget() - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index f01c842b66..8ffc931bf2 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -277,7 +277,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector = UnitSelector(data_plot["unit_ids"]) self.unit_selector.value = list(data_plot["unit_ids"])[:1] - self.same_axis_button = widgets.Checkbox( value=False, description="same axis", @@ -309,10 +308,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # a first update self._update_ipywidget(None) - self.unit_selector.observe(self._update_ipywidget, names='value', type="change") + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") for w in self.same_axis_button, self.plot_templates_button, self.hide_axis_button: - w.observe(self._update_ipywidget, names='value', type="change") - + w.observe(self._update_ipywidget, names="value", type="change") if backend_kwargs["display"]: display(self.widget) @@ -340,7 +338,7 @@ def _update_ipywidget(self, change): data_plot["plot_templates"] = plot_templates if data_plot["plot_waveforms"]: data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} - + # TODO option for plot_legend backend_kwargs = {} diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index ee6133a990..6e872eca55 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -12,12 +12,9 @@ def check_ipywidget_backend(): class TimeSlider(W.HBox): - value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) - - def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): - - + + def __init__(self, durations, sampling_frequency, time_range=(0, 1.0), **kwargs): self.num_segments = len(durations) self.frame_limits = [int(sampling_frequency * d) for d in durations] self.sampling_frequency = sampling_frequency @@ -28,81 +25,100 @@ def __init__(self, durations, sampling_frequency, time_range=(0, 1.), **kwargs): self.segment_index = 0 self.value = (start_frame, end_frame, self.segment_index) - - + layout = W.Layout(align_items="center", width="2.5cm", height="1.cm") - but_left = W.Button(description='', disabled=False, button_style='', icon='arrow-left', layout=layout) - but_right = W.Button(description='', disabled=False, button_style='', icon='arrow-right', layout=layout) - + but_left = W.Button(description="", disabled=False, button_style="", icon="arrow-left", layout=layout) + but_right = W.Button(description="", disabled=False, button_style="", icon="arrow-right", layout=layout) + but_left.on_click(self.move_left) but_right.on_click(self.move_right) - self.move_size = W.Dropdown(options=['10 ms', '100 ms', '1 s', '10 s', '1 m', '30 m', '1 h',], # '6 h', '24 h' - value='1 s', - description='', - layout = W.Layout(width="2cm") - ) + self.move_size = W.Dropdown( + options=[ + "10 ms", + "100 ms", + "1 s", + "10 s", + "1 m", + "30 m", + "1 h", + ], # '6 h', '24 h' + value="1 s", + description="", + layout=W.Layout(width="2cm"), + ) # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) - self.time_label = W.Text(value=f'{time_range[0]}',description='', - disabled=False, layout=W.Layout(width='2.5cm')) - self.time_label.observe(self.time_label_changed, names='value', type="change") - + self.time_label = W.Text( + value=f"{time_range[0]}", description="", disabled=False, layout=W.Layout(width="2.5cm") + ) + self.time_label.observe(self.time_label_changed, names="value", type="change") self.slider = W.IntSlider( - orientation='horizontal', - # description='time:', + orientation="horizontal", + # description='time:', value=start_frame, min=0, max=self.frame_limits[self.segment_index] - 1, readout=False, continuous_update=False, - layout=W.Layout(width=f'70%') + layout=W.Layout(width=f"70%"), ) - - self.slider.observe(self.slider_moved, names='value', type="change") - + + self.slider.observe(self.slider_moved, names="value", type="change") + delta_s = np.diff(self.frame_range) / sampling_frequency - - self.window_sizer = W.BoundedFloatText(value=delta_s, step=1, - min=0.01, max=30., - description='win (s)', - layout=W.Layout(width='auto') - # layout=W.Layout(width=f'10%') - ) - self.window_sizer.observe(self.win_size_changed, names='value', type="change") + + self.window_sizer = W.BoundedFloatText( + value=delta_s, + step=1, + min=0.01, + max=30.0, + description="win (s)", + layout=W.Layout(width="auto") + # layout=W.Layout(width=f'10%') + ) + self.window_sizer.observe(self.win_size_changed, names="value", type="change") self.segment_selector = W.Dropdown(description="segment", options=list(range(self.num_segments))) - self.segment_selector.observe(self.segment_changed, names='value', type="change") + self.segment_selector.observe(self.segment_changed, names="value", type="change") + + super(W.HBox, self).__init__( + children=[ + self.segment_selector, + but_left, + self.move_size, + but_right, + self.slider, + self.time_label, + self.window_sizer, + ], + layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) - super(W.HBox, self).__init__(children=[self.segment_selector, but_left, self.move_size, but_right, - self.slider, self.time_label, self.window_sizer], - layout=W.Layout(align_items="center", width="100%", height="100%"), - **kwargs) - - self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def value_changed(self, change=None): - - self.unobserve(self.value_changed, names=['value'], type="change") + self.unobserve(self.value_changed, names=["value"], type="change") start, stop, seg_index = self.value if seg_index < 0 or seg_index >= self.num_segments: - self.value = change['old'] + self.value = change["old"] return if start < 0 or stop < 0: - self.value = change['old'] + self.value = change["old"] return if start >= self.frame_limits[seg_index] or start > self.frame_limits[seg_index]: - self.value = change['old'] + self.value = change["old"] return - + self.segment_selector.value = seg_index self.update_time(new_frame=start, update_slider=True, update_label=True) delta_s = (stop - start) / self.sampling_frequency self.window_sizer.value = delta_s - self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def update_time(self, new_frame=None, new_time=None, update_slider=False, update_label=False): if new_frame is None and new_time is None: @@ -118,25 +134,24 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update start_frame = min(self.frame_limits[self.segment_index] - delta, start_frame) start_frame = max(0, start_frame) end_frame = start_frame + delta - + end_frame = min(self.frame_limits[self.segment_index], end_frame) - start_time = start_frame / self.sampling_frequency if update_label: - self.time_label.unobserve(self.time_label_changed, names='value', type="change") - self.time_label.value = f'{start_time}' - self.time_label.observe(self.time_label_changed, names='value', type="change") + self.time_label.unobserve(self.time_label_changed, names="value", type="change") + self.time_label.value = f"{start_time}" + self.time_label.observe(self.time_label_changed, names="value", type="change") if update_slider: - self.slider.unobserve(self.slider_moved, names='value', type="change") + self.slider.unobserve(self.slider_moved, names="value", type="change") self.slider.value = start_frame - self.slider.observe(self.slider_moved, names='value', type="change") - + self.slider.observe(self.slider_moved, names="value", type="change") + self.frame_range = (start_frame, end_frame) self.value = (start_frame, end_frame, self.segment_index) - + def time_label_changed(self, change=None): try: new_time = float(self.time_label.value) @@ -145,39 +160,39 @@ def time_label_changed(self, change=None): if new_time is not None: self.update_time(new_time=new_time, update_slider=True) - def win_size_changed(self, change=None): self.update_time() - + def slider_moved(self, change=None): new_frame = self.slider.value self.update_time(new_frame=new_frame, update_label=True) - + def move(self, sign): - value, units = self.move_size.value.split(' ') + value, units = self.move_size.value.split(" ") value = int(value) - delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, 's') + delta_s = (sign * np.timedelta64(value, units)) / np.timedelta64(1, "s") delta_sample = int(delta_s * self.sampling_frequency) new_frame = self.frame_range[0] + delta_sample self.slider.value = new_frame - + def move_left(self, change=None): self.move(-1) def move_right(self, change=None): self.move(+1) - + def segment_changed(self, change=None): self.segment_index = self.segment_selector.value - self.slider.unobserve(self.slider_moved, names='value', type="change") + self.slider.unobserve(self.slider_moved, names="value", type="change") # self.slider.value = 0 self.slider.max = self.frame_limits[self.segment_index] - 1 - self.slider.observe(self.slider_moved, names='value', type="change") + self.slider.observe(self.slider_moved, names="value", type="change") self.update_time(new_frame=0, update_slider=True, update_label=True) + class ChannelSelector(W.VBox): value = traitlets.List() @@ -211,22 +226,24 @@ def __init__(self, channel_ids, **kwargs): ) hbox = W.HBox(children=[self.slider, self.selector]) - super(W.VBox, self).__init__(children=[channel_label, hbox], - layout=W.Layout(align_items="center"), - # layout=W.Layout(align_items="center", width="100%", height="100%"), - **kwargs) - self.slider.observe(self.on_slider_changed, names=['value'], type="change") - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + super(W.VBox, self).__init__( + children=[channel_label, hbox], + layout=W.Layout(align_items="center"), + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + self.selector.observe(self.on_selector_changed, names=["value"], type="change") # TODO external value change # self.observe(self.value_changed, names=['value'], type="change") - + def on_slider_changed(self, change=None): i0, i1 = self.slider.value - - self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") + + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") self.selector.value = self.channel_ids[i0:i1][::-1] - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + self.selector.observe(self.on_selector_changed, names=["value"], type="change") self.value = self.channel_ids[i0:i1] @@ -235,27 +252,23 @@ def on_selector_changed(self, change=None): channel_ids = channel_ids[::-1] if len(channel_ids) > 0: - self.slider.unobserve(self.on_slider_changed, names=['value'], type="change") + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") i0 = self.channel_ids.index(channel_ids[0]) i1 = self.channel_ids.index(channel_ids[-1]) + 1 self.slider.value = (i0, i1) - self.slider.observe(self.on_slider_changed, names=['value'], type="change") + self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.value = channel_ids - class ScaleWidget(W.VBox): value = traitlets.Float() - def __init__(self, value=1., factor=1.2, **kwargs): - - assert factor > 1. + def __init__(self, value=1.0, factor=1.2, **kwargs): + assert factor > 1.0 self.factor = factor - self.scale_label = W.Label("Scale", - layout=W.Layout(layout=W.Layout(width='95%'), - justify_content="center")) + self.scale_label = W.Label("Scale", layout=W.Layout(layout=W.Layout(width="95%"), justify_content="center")) self.plus_selector = W.Button( description="", @@ -264,7 +277,7 @@ def __init__(self, value=1., factor=1.2, **kwargs): tooltip="Increase scale", icon="arrow-up", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='60%', align_self='center'), + layout=W.Layout(width="60%", align_self="center"), ) self.minus_selector = W.Button( @@ -274,31 +287,31 @@ def __init__(self, value=1., factor=1.2, **kwargs): tooltip="Decrease scale", icon="arrow-down", # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - layout=W.Layout(width='60%', align_self='center'), + layout=W.Layout(width="60%", align_self="center"), ) self.plus_selector.on_click(self.plus_clicked) self.minus_selector.on_click(self.minus_clicked) - self.value = 1. - super(W.VBox, self).__init__(children=[self.plus_selector, self.scale_label, self.minus_selector], - # layout=W.Layout(align_items="center", width="100%", height="100%"), - **kwargs) + self.value = 1.0 + super(W.VBox, self).__init__( + children=[self.plus_selector, self.scale_label, self.minus_selector], + # layout=W.Layout(align_items="center", width="100%", height="100%"), + **kwargs, + ) self.update_label() - self.observe(self.value_changed, names=['value'], type="change") - + self.observe(self.value_changed, names=["value"], type="change") + def update_label(self): self.scale_label.value = f"Scale: {self.value:0.2f}" - def plus_clicked(self, change=None): self.value = self.value * self.factor def minus_clicked(self, change=None): self.value = self.value / self.factor - def value_changed(self, change=None): self.update_label() @@ -319,20 +332,17 @@ def __init__(self, unit_ids, **kwargs): layout=W.Layout(height="100%", width="2cm"), ) - super(W.VBox, self).__init__(children=[label, self.selector], - layout=W.Layout(align_items="center"), - **kwargs) - - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + super(W.VBox, self).__init__(children=[label, self.selector], layout=W.Layout(align_items="center"), **kwargs) + + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + self.observe(self.value_changed, names=["value"], type="change") - self.observe(self.value_changed, names=['value'], type="change") - def on_selector_changed(self, change=None): unit_ids = self.selector.value self.value = unit_ids - - def value_changed(self, change=None): - self.selector.unobserve(self.on_selector_changed, names=['value'], type="change") - self.selector.value = change['new'] - self.selector.observe(self.on_selector_changed, names=['value'], type="change") + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") From 8e4b43a4f67a92a1497eda5d53f2be2e04f7779f Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Wed, 27 Sep 2023 11:37:12 +0200 Subject: [PATCH 15/18] Update src/spikeinterface/postprocessing/amplitude_scalings.py --- src/spikeinterface/postprocessing/amplitude_scalings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 8823fd6257..7e6c95a875 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -431,7 +431,7 @@ def _are_unit_indices_overlapping(sparsity_mask, i, j): bool True if the unit indices i and j are overlapping, False otherwise """ - if np.sum(np.logical_and(sparsity_mask[i], sparsity_mask[j])) > 0: + if np.any(sparsity_mask[i] & sparsity_mask[j]): return True else: return False From 9dde3760dd62803ea54d5c1f42d560fd907380a0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 21:31:11 +0200 Subject: [PATCH 16/18] title --- .../benchmark/benchmark_motion_estimation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index a47b97fb6d..c505676c05 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -500,8 +500,8 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo axes[2].plot(benchmark.spatial_bins, depth_error, label=benchmark.title, color=c) ax0 = ax = axes[0] - ax.set_xlabel("time [s]") - ax.set_ylabel("error [um]") + ax.set_xlabel("Time [s]") + ax.set_ylabel("Error [μm]") if show_legend: ax.legend() _simpleaxis(ax) @@ -514,7 +514,7 @@ def plot_errors_several_benchmarks(benchmarks, axes=None, show_legend=True, colo ax2 = axes[2] ax2.set_yticks([]) - ax2.set_xlabel("depth [um]") + ax2.set_xlabel("Depth [μm]") # ax.set_ylabel('error') channel_positions = benchmark.recording.get_channel_locations() probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() From e0bcb28fc019e7ecde6df3ecdeb504e3c719fccc Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 29 Sep 2023 07:22:13 +0200 Subject: [PATCH 17/18] move import in --- src/spikeinterface/extractors/cbin_ibl.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 3dde998ca1..bd56208ebe 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -6,13 +6,6 @@ from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts from spikeinterface.core.core_tools import define_function_from_class -try: - import mtscomp - - HAVE_MTSCOMP = True -except: - HAVE_MTSCOMP = False - class CompressedBinaryIblExtractor(BaseRecording): """Load IBL data as an extractor object. @@ -42,7 +35,6 @@ class CompressedBinaryIblExtractor(BaseRecording): """ extractor_name = "CompressedBinaryIbl" - installed = HAVE_MTSCOMP mode = "folder" installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" name = "cbin_ibl" @@ -51,7 +43,10 @@ def __init__(self, folder_path, load_sync_channel=False, stream_name="ap"): # this work only for future neo from neo.rawio.spikeglxrawio import read_meta_file, extract_stream_info - assert HAVE_MTSCOMP + try: + import mtscomp + except: + raise ImportError(self.installation_mesg) folder_path = Path(folder_path) # check bands From c8be1a0def93d4a639370a146c5d3244234049c0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 29 Sep 2023 14:17:37 +0200 Subject: [PATCH 18/18] Fix firing range when bin size is to small (#2054) * Fix firing range when bin size is to small * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/spikeinterface/qualitymetrics/misc_metrics.py | 9 +++++++++ .../qualitymetrics/tests/test_metrics_functions.py | 8 ++++++-- .../benchmark/benchmark_motion_estimation.py | 6 ++---- .../benchmark/benchmark_motion_interpolation.py | 8 ++++++-- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index e9726a16da..d3f875959e 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -602,6 +602,15 @@ def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), if unit_ids is None: unit_ids = sorting.unit_ids + if all( + [ + waveform_extractor.get_num_samples(segment_index) < bin_size_samples + for segment_index in range(waveform_extractor.get_num_segments()) + ] + ): + warnings.warn(f"Bin size of {bin_size_s}s is larger than each segment duration. Firing ranges are set to NaN.") + return {unit_id: np.nan for unit_id in unit_ids} + # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} for segment_index in range(waveform_extractor.get_num_segments()): diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 2d63a06b17..8a32c4cee8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -220,6 +220,10 @@ def test_calculate_firing_range(waveform_extractor_simple): firing_ranges = compute_firing_ranges(we) print(firing_ranges) + with pytest.warns(UserWarning) as w: + firing_ranges_nan = compute_firing_ranges(we, bin_size_s=we.get_total_duration() + 1) + assert np.all([np.isnan(f) for f in firing_ranges_nan.values()]) + def test_calculate_amplitude_cutoff(waveform_extractor_simple): we = waveform_extractor_simple @@ -378,7 +382,7 @@ def test_calculate_drift_metrics(waveform_extractor_simple): if __name__ == "__main__": sim_data = _simulated_data() we = _waveform_extractor_simple() - we_violations = _waveform_extractor_violations(sim_data) + # we_violations = _waveform_extractor_violations(sim_data) # test_calculate_amplitude_cutoff(we) # test_calculate_presence_ratio(we) # test_calculate_amplitude_median(we) @@ -387,4 +391,4 @@ def test_calculate_drift_metrics(waveform_extractor_simple): # test_calculate_drift_metrics(we) # test_synchrony_metrics(we) test_calculate_firing_range(we) - test_calculate_amplitude_cv_metrics(we) + # test_calculate_amplitude_cv_metrics(we) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index c505676c05..abf40b2da6 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -584,13 +584,13 @@ def plot_motions_several_benchmarks(benchmarks): _simpleaxis(ax) -def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): +def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=None): if ax is None: fig, ax = plt.subplots(figsize=(5, 5)) for count, benchmark in enumerate(benchmarks): color = colors[count] if colors is not None else None - + if detailed: bottom = 0 i = 0 @@ -606,8 +606,6 @@ def plot_speed_several_benchmarks(benchmarks, detailed=True, ax=None, colors=No else: total_run_time = np.sum([value for key, value in benchmark.run_times.items()]) ax.bar([count], [total_run_time], color=color, edgecolor="black") - - # ax.legend() ax.set_ylabel("speed (s)") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index 8e5afb2e8e..b28b29f17c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -193,11 +193,15 @@ def run_sorters(self, skip_already_done=True): recording = self.recordings[case["recording"]] output_folder = self.folder / f"tmp_sortings_{label}" if output_folder.exists() and skip_already_done: - print('already done') + print("already done") sorting = read_sorter_folder(output_folder) else: sorting = run_sorter( - sorter_name, recording, output_folder, **sorter_params, delete_output_folder=self.delete_output_folder + sorter_name, + recording, + output_folder, + **sorter_params, + delete_output_folder=self.delete_output_folder, ) self.sortings[label] = sorting