Skip to content

Commit

Permalink
Split PCA
Browse files: browse the repository at this point in the history
  • Loading branch information
yger committed Jan 13, 2025
1 parent 9ff3070 commit 08b13f7
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 59 deletions.
11 changes: 5 additions & 6 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CircusClustering:
"returns_split_count": True,
},
"radius_um": 100,
"n_svd": [5, 2],
"n_svd": 5,
"ms_before": 0.5,
"ms_after": 0.5,
"rank": 5,
Expand Down Expand Up @@ -94,7 +94,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
wfs = few_wfs[:, :, 0]
from sklearn.decomposition import TruncatedSVD

tsvd = TruncatedSVD(params["n_svd"][0])
tsvd = TruncatedSVD(params["n_svd"])
tsvd.fit(wfs)

model_folder = tmp_folder / "tsvd_model"
Expand Down Expand Up @@ -148,8 +148,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
sub_data = all_pc_data[mask]
sub_data = sub_data.reshape(len(sub_data), -1)

if all_pc_data.shape[1] > params["n_svd"][1]:
tsvd = PCA(params["n_svd"][1], whiten=True)
if all_pc_data.shape[1] > params["n_svd"]:
tsvd = PCA(params["n_svd"], whiten=True)
else:
tsvd = PCA(all_pc_data.shape[1], whiten=True)

Expand Down Expand Up @@ -203,8 +203,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
waveforms_sparse_mask=sparse_mask,
min_size_split=min_size,
clusterer_kwargs=d["hdbscan_kwargs"],
n_pca_features=params["n_svd"][1],
scale_n_pca_by_depth=True,
n_pca_features=[2, 4, 6, 8, 10],
),
**params["recursive_kwargs"],
**job_kwargs,
Expand Down
114 changes: 62 additions & 52 deletions src/spikeinterface/sortingcomponents/clustering/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def split(
clusterer_kwargs={"min_cluster_size": 25},
min_size_split=25,
n_pca_features=2,
scale_n_pca_by_depth=False,
minimum_overlap_ratio=0.25,
):
local_labels = np.zeros(peak_indices.size, dtype=np.int64)
Expand Down Expand Up @@ -213,74 +212,85 @@ def split(

local_labels[dont_have_channels] = -2
kept = np.flatnonzero(~dont_have_channels)
# print(recursion_level, kept.size, min_size_split)

if kept.size < min_size_split:
return False, None

aligned_wfs = aligned_wfs[kept, :, :]

if not isinstance(n_pca_features, np.ndarray):
n_pca_features = np.array([n_pca_features])

n_pca_features = n_pca_features[n_pca_features <= aligned_wfs.shape[1]]
flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1)

if flatten_features.shape[1] > n_pca_features:
from sklearn.decomposition import PCA
is_split = False

# from sklearn.decomposition import TruncatedSVD
for n_pca in n_pca_features:

if scale_n_pca_by_depth:
# tsvd = TruncatedSVD(n_pca_features * recursion_level)
tsvd = PCA(n_pca_features * recursion_level, whiten=True)
else:
if flatten_features.shape[1] > n_pca:
from sklearn.decomposition import PCA

# from sklearn.decomposition import TruncatedSVD
# tsvd = TruncatedSVD(n_pca_features)
tsvd = PCA(n_pca_features, whiten=True)
final_features = tsvd.fit_transform(flatten_features)
else:
final_features = flatten_features

if clusterer == "hdbscan":
from hdbscan import HDBSCAN

clust = HDBSCAN(**clusterer_kwargs)
clust.fit(final_features)
possible_labels = clust.labels_
is_split = np.setdiff1d(possible_labels, [-1]).size > 1
elif clusterer == "isocut5":
min_cluster_size = clusterer_kwargs["min_cluster_size"]
dipscore, cutpoint = isocut5(final_features[:, 0])
possible_labels = np.zeros(final_features.shape[0])
if dipscore > 1.5:
mask = final_features[:, 0] > cutpoint
if np.sum(mask) > min_cluster_size and np.sum(~mask):
possible_labels[mask] = 1
tsvd = PCA(n_pca, whiten=True)
final_features = tsvd.fit_transform(flatten_features)
else:
final_features = flatten_features

if clusterer == "hdbscan":
from hdbscan import HDBSCAN

clust = HDBSCAN(**clusterer_kwargs)
clust.fit(final_features)
possible_labels = clust.labels_
is_split = np.setdiff1d(possible_labels, [-1]).size > 1
elif clusterer == "isocut5":
min_cluster_size = clusterer_kwargs["min_cluster_size"]
dipscore, cutpoint = isocut5(final_features[:, 0])
possible_labels = np.zeros(final_features.shape[0])
if dipscore > 1.5:
mask = final_features[:, 0] > cutpoint
if np.sum(mask) > min_cluster_size and np.sum(~mask):
possible_labels[mask] = 1
is_split = np.setdiff1d(possible_labels, [-1]).size > 1
else:
is_split = False
else:
is_split = False
else:
raise ValueError(f"wrong clusterer {clusterer}")
raise ValueError(f"wrong clusterer {clusterer}")

# DEBUG = True
DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

labels_set = np.setdiff1d(possible_labels, [-1])
colors = plt.colormaps["tab10"].resampled(len(labels_set))
colors = {k: colors(i) for i, k in enumerate(labels_set)}
colors[-1] = "k"
fig, axs = plt.subplots(nrows=2)

# DEBUG = True
DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt
flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1)

labels_set = np.setdiff1d(possible_labels, [-1])
colors = plt.colormaps["tab10"].resampled(len(labels_set))
colors = {k: colors(i) for i, k in enumerate(labels_set)}
colors[-1] = "k"
fix, axs = plt.subplots(nrows=2)
sl = slice(None, None, 10)
for k in np.unique(possible_labels):
mask = possible_labels == k
ax = axs[0]
ax.scatter(final_features[:, 0][mask][sl], final_features[:, 1][mask][sl], s=5, color=colors[k])

flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1)
ax = axs[1]
ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5)
ax.set_xlabel("PCA features")

sl = slice(None, None, 10)
for k in np.unique(possible_labels):
mask = possible_labels == k
ax = axs[0]
ax.scatter(final_features[:, 0][mask][sl], final_features[:, 1][mask][sl], s=5, color=colors[k])
axs[0].set_title(f"{clusterer} {is_split} {peak_indices[0]} {n_pca}, recursion_level={recursion_level}")
import time

ax = axs[1]
ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5)
plt.savefig(f"split_{recursion_level}_{time.time()}.png")
plt.close()
# plt.show()

axs[0].set_title(f"{clusterer} {is_split} {peak_indices[0]} {np.unique(possible_labels)}")
plt.show()
if is_split:
break

if not is_split:
return is_split, None
Expand All @@ -293,4 +303,4 @@ def split(
split_methods_list = [
LocalFeatureClustering,
]
split_methods_dict = {e.name: e for e in split_methods_list}
split_methods_dict = {e.name: e for e in split_methods_list}
1 change: 0 additions & 1 deletion src/spikeinterface/sortingcomponents/clustering/tdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
waveforms_sparse_mask=sparse_mask,
min_size_split=min_cluster_size,
n_pca_features=3,
scale_n_pca_by_depth=True,
),
recursive=True,
recursive_depth=3,
Expand Down

0 comments on commit 08b13f7

Please sign in to comment.