diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index aa59de8f60..9d28340352 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -40,9 +40,13 @@ def interpolate_templates(templates_array, source_locations, dest_locations, int source_locations = np.asarray(source_locations) dest_locations = np.asarray(dest_locations) if dest_locations.ndim == 2: - new_shape = templates_array.shape + new_shape = (*templates_array.shape[:2], len(dest_locations)) elif dest_locations.ndim == 3: - new_shape = (dest_locations.shape[0],) + templates_array.shape + new_shape = ( + dest_locations.shape[0], + *templates_array.shape[:2], + dest_locations.shape[1], + ) else: raise ValueError(f"Incorrect dimensions for dest_locations: {dest_locations.ndim}. Dimensions can be 2 or 3. ")