diff --git a/src/spikeanalysis/spike_data.py b/src/spikeanalysis/spike_data.py index 1d8ab90..a999d69 100644 --- a/src/spikeanalysis/spike_data.py +++ b/src/spikeanalysis/spike_data.py @@ -779,6 +779,8 @@ def _isolation_distance(self, pc_feat: np.array, labels: np.array, this_id: int) md_other = np.sort(cdist(pc_other_clusters, mean_this_cluster, "mahalanobis", VI=cov_matrix)) # md_self = cdist(pc_this_cluster,mean_this_cluster,"mahalanobis", VI=cov_matrix) + if len(md_other.shape) > 1: + md_other = np.squeeze(md_other) isolation_dist = (md_other[n_spikes - 1]) ** 2 else: diff --git a/test/test_spike_data.py b/test/test_spike_data.py index 48af794..f6196fd 100644 --- a/test/test_spike_data.py +++ b/test/test_spike_data.py @@ -327,7 +327,6 @@ def test_generate_qcmetrics(spikes, tmp_path): spikes.pc_feat = pc_feat spikes.CACHING = True spikes.generate_qcmetrics() - assert isinstance(spikes.isolation_distances, np.ndarray) assert len(spikes.isolation_distances) == 2 assert isinstance(spikes.silhouette_scores, np.ndarray)