diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 79c784491a..5af20d79b5 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self): indexes = np.arange(scores.shape[1]) order1 = [] for r in range(scores.shape[0]): - possible = indexes[~np.in1d(indexes, order1)] + possible = indexes[~np.isin(indexes, order1)] if possible.size > 0: ind = np.argmax(scores.iloc[r, possible].values) order1.append(possible[ind]) - remain = indexes[~np.in1d(indexes, order1)] + remain = indexes[~np.isin(indexes, order1)] order1.extend(remain) scores = scores.iloc[:, order1] diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index db45e2b25b..20ee7910b4 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun matched_units2 = match_12[match_12 != -1].values unmatched_units1 = match_12[match_12 == -1].index - unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)] + unmatched_units2 = unit2_ids[~np.isin(unit2_ids, matched_units2)] ordered_units1 = np.hstack([matched_units1, unmatched_units1]) ordered_units2 = np.hstack([matched_units2, unmatched_units2]) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index af4970a4ad..08f187895b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -592,7 +592,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceRecording - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 737087abc1..f35bc2b266 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -139,7 +139,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceSnippets - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceSnippets(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 52f71c2399..056134a24e 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids): """ from spikeinterface import UnitsSelectionSorting - new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)] + new_unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] new_sorting = UnitsSelectionSorting(self, new_unit_ids) return new_sorting diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..07837bcef7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -166,7 +166,7 @@ def generate_sorting( ) if empty_units is not None: - keep = ~np.in1d(labels, empty_units) + keep = ~np.isin(labels, empty_units) times = times[keep] labels = labels[keep] @@ -219,7 +219,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] + units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])] if len(units_not_used) == 0: continue diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a6b94c9b84..75182bf532 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -34,7 +34,7 @@ def test_ChannelSparsity(): for key, v in sparsity.unit_id_to_channel_ids.items(): assert key in unit_ids - assert np.all(np.in1d(v, channel_ids)) + assert np.all(np.isin(v, channel_ids)) for key, v in sparsity.unit_id_to_channel_indices.items(): assert key in unit_ids diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 264ac3a56d..2d20a58453 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -59,7 +59,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties else: # we cannot automatically find new names new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError( "Unable to find 'new_unit_ids' because it is a string and parents " "already contain merges. Pass a list of 'new_unit_ids' as an argument." @@ -68,7 +68,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties # dtype int new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) else: - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones") assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 02e7d5677d..8b70722652 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids): contact_ids = channels["contact_id"].values.astype("U") # extracting information of requested channels - keep = np.in1d(channel_ids, recording_channel_ids) + keep = np.isin(channel_ids, recording_channel_ids) channel_ids = channel_ids[keep] contact_ids = contact_ids[keep] diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..5a3542cdf9 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -47,9 +47,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask] return dict(amplitude_scalings=new_amplitude_scalings) diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 62a4e2c320..b6f25cda95 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids): # load filter and save amplitude files sorting = self.waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) - (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) + (keep_unit_indices,) = np.nonzero(np.isin(sorting.unit_ids, unit_ids)) new_extension_data = dict() for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) + filtered_idxs = np.isin(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..4cbe4d665e 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_spike_locations = self._extension_data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index e634d55e7f..95ecd0fe52 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non self.bad_channel_ids = bad_channel_ids self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids) - self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs) + self._good_channel_idxs = ~np.isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs) self._bad_channel_idxs.setflags(write=False) if sigma_um is None: diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index ee28485983..4e871492f8 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -544,7 +544,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue - spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] + spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 59000211d4..ed06f7d738 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -152,8 +152,8 @@ def calculate_pc_metrics( neighbor_unit_ids = unit_ids neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids) - labels = all_labels[np.in1d(all_labels, neighbor_unit_ids)] - pcs = all_pcs[np.in1d(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] + pcs = all_pcs[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) func_args = ( @@ -506,7 +506,7 @@ def nearest_neighbors_isolation( other_units_ids = [ unit_id for unit_id in other_units_ids - if np.sum(np.in1d(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) + if np.sum(np.isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) >= (n_channels_target_unit * min_spatial_overlap) ] @@ -536,10 +536,10 @@ def nearest_neighbors_isolation( if waveform_extractor.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ - :, :, np.in1d(closest_chans_target_unit, common_channel_idxs) + :, :, np.isin(closest_chans_target_unit, common_channel_idxs) ] waveforms_other_unit_sampled = waveforms_other_unit_sampled[ - :, :, np.in1d(closest_chans_other_unit, common_channel_idxs) + :, :, np.isin(closest_chans_other_unit, common_channel_idxs) ] else: waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 07c7db155c..772c99bc0a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -502,7 +502,7 @@ def plot_errors_matching(benchmark, comp, unit_id, nb_spikes=200, metric="cosine seg_num = 0 # TODO: make compatible with multiple segments idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] if len(intersection) == 0: print(f"No {label}s found for unit {unit_id}") @@ -552,7 +552,7 @@ def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cos for label in ["TP", "FN"]: idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] wfs_sliced = wfs[intersection, :, :] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1514a63dd4..73497a59fd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -133,7 +133,7 @@ def run(self, peaks=None, positions=None, delta=0.2): matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] - garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) + garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) garbage_channels = self.peaks["channel_index"][garbage_matches] garbage_peaks = times2[garbage_matches] nb_garbage = len(garbage_peaks) @@ -365,7 +365,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["full_gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["full_gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.gt_peaks["amplitude"][mask]) ax.scatter(self.gt_positions["x"][mask], self.gt_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.gt_positions["x"][mask].mean(), self.gt_positions["y"][mask].mean()) @@ -391,7 +391,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.sliced_gt_peaks["amplitude"][mask]) ax.scatter( self.sliced_gt_positions["x"][mask], self.sliced_gt_positions["y"][mask], c=colors, s=1, alpha=0.5 @@ -420,7 +420,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["garbage"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["garbage"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.garbage_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.garbage_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.garbage_peaks["amplitude"][mask]) ax.scatter(self.garbage_positions["x"][mask], self.garbage_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.garbage_positions["x"][mask].mean(), self.garbage_positions["y"][mask].mean()) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6edf5af16b..23fdbf1979 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -30,7 +30,7 @@ def _split_waveforms( local_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr) - local_labels_with_noise[~np.in1d(local_labels_with_noise, persistent_clusters)] = -1 + local_labels_with_noise[~np.isin(local_labels_with_noise, persistent_clusters)] = -1 # remove super small cluster labels, count = np.unique(local_labels_with_noise[:valid_size], return_counts=True) @@ -43,7 +43,7 @@ def _split_waveforms( to_remove = labels[(count / valid_size) < minimum_cluster_size_ratio] # ~ print('to_remove', to_remove, count / valid_size) if to_remove.size > 0: - local_labels_with_noise[np.in1d(local_labels_with_noise, to_remove)] = -1 + local_labels_with_noise[np.isin(local_labels_with_noise, to_remove)] = -1 local_labels_with_noise[valid_size:] = -2 @@ -123,7 +123,7 @@ def _split_waveforms_nested( active_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr) - active_labels_with_noise[~np.in1d(active_labels_with_noise, persistent_clusters)] = -1 + active_labels_with_noise[~np.isin(active_labels_with_noise, persistent_clusters)] = -1 active_labels = active_labels_with_noise[active_ind < valid_size] active_labels_set = np.unique(active_labels) @@ -381,9 +381,9 @@ def auto_clean_clustering( continue wfs0 = wfs_arrays[label0] - wfs0 = wfs0[:, :, np.in1d(channel_inds0, used_chans)] + wfs0 = wfs0[:, :, np.isin(channel_inds0, used_chans)] wfs1 = wfs_arrays[label1] - wfs1 = wfs1[:, :, np.in1d(channel_inds1, used_chans)] + wfs1 = wfs1[:, :, np.isin(channel_inds1, used_chans)] # TODO : remove assert wfs0.shape[2] == wfs1.shape[2] diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index aeec14158f..08ce9f6791 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -198,7 +198,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): for chan_ind in prev_local_chan_inds: if total_count[chan_ind] == 0: continue - # ~ inds, = np.nonzero(np.in1d(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) + # ~ inds, = np.nonzero(np.isin(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) (inds,) = np.nonzero((peaks["channel_index"] == chan_ind) & (peak_labels == 0)) if inds.size <= d["min_spike_on_channel"]: chan_amps[chan_ind] = 0.0 @@ -235,12 +235,12 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # TODO: only for debug, remove later - assert np.all(np.in1d(local_chan_inds, wf_chans)) + assert np.all(np.isin(local_chan_inds, wf_chans)) # none label spikes wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, local_chan_inds)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, local_chan_inds)] wfs.append(wfs_chan) # put noise to enhance clusters @@ -517,7 +517,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # print('wf_chans', wf_chans) # TODO: only for debug, remove later - assert np.all(np.in1d(wanted_chans, wf_chans)) + assert np.all(np.isin(wanted_chans, wf_chans)) wfs_chan = wfs_arrays[chan_ind] # TODO: only for debug, remove later @@ -525,7 +525,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, wanted_chans)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, wanted_chans)] wfs.append(wfs_chan) wfs = np.concatenate(wfs, axis=0) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index e8a6868e92..b3391c0712 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -103,7 +103,7 @@ def __init__( if same_axis and not np.array_equal(chan_inds, shared_chan_inds): # add more channels if necessary wfs_ = np.zeros((wfs.shape[0], wfs.shape[1], shared_chan_inds.size), dtype=float) - mask = np.in1d(shared_chan_inds, chan_inds) + mask = np.isin(shared_chan_inds, chan_inds) wfs_[:, :, mask] = wfs wfs_[:, :, ~mask] = np.nan wfs = wfs_