Skip to content

Commit

Permalink
Clean-up in merge_sortings
Browse files Browse the repository at this point in the history
  • Loading branch information
DradeAW committed Jan 12, 2024
1 parent 2d6f208 commit e740735
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 11 deletions.
4 changes: 1 addition & 3 deletions params_examples/params_cerebellar_cortex[beta].json
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@
},
"correlogram_validation": {
"max_time": 400,
"max_difference": 0.35,
"gaussian_std": 10.0,
"gaussian_truncate": 5.0
}
Expand All @@ -240,8 +239,7 @@
"min_similarity": 0.4
},
"correlogram_validation": {
"max_time": 75.0,
"max_difference": 0.14
"max_time": 75.0
}
}
},
Expand Down
12 changes: 8 additions & 4 deletions src/lussac/modules/merge_sortings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,10 @@ def default_params(self) -> dict[str, Any]:
},
'correlogram_validation': {
'max_time': 70.0,
'max_difference': 0.25,
'gaussian_std': 0.6,
'gaussian_truncate': 5.0
},
'waveform_validation': {
'max_difference': 0.20,
'wvf_extraction': {
'ms_before': 1.0,
'ms_after': 2.0,
Expand Down Expand Up @@ -439,7 +437,7 @@ def clean_edges(self, graph: nx.Graph, cross_shifts: dict[str, dict[str, np.ndar

cross_cont, p_value = utils.estimate_cross_contamination(spike_train1, spike_train2, params['refractory_period'], limit=0.06) # TODO: make 'limit' a parameter.

if p_value > 5e-3 or data['temp_diff'] > 0.10 or data['corr_diff'] > 0.12: # Make these parameters.
if p_value > 5e-3 or data['temp_diff'] < 0.10 or data['corr_diff'] < 0.12: # TODO: Make these parameters.
continue

# From this point on, the edge is treated as problematic.
Expand All @@ -459,7 +457,13 @@ def clean_edges(self, graph: nx.Graph, cross_shifts: dict[str, dict[str, np.ndar
@staticmethod
def separate_communities(graph: nx.Graph) -> None:
"""
TODO
Looks for all subgraphs (connected component) and uses the Louvain algorithm to check if
multiple communities are found. If so, the edges between communities are removed.
Additionally, small communities are removed.
Warning: it's recommended to run this function if there are at least 4 analyses.
@param graph: nx.Graph
The graph containing all the units and connected by their similarity.
"""

for nodes in list(nx.connected_components(graph)):
Expand Down
2 changes: 1 addition & 1 deletion src/lussac/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def binom_sf(x: int, n: float, p: float) -> float:
"""

n_array = np.arange(math.floor(n-2), math.ceil(n+3), 1)
n_array = n_array[n_array >= 1]
n_array = n_array[n_array >= 0]

res = [scipy.stats.binom.sf(x, n_, p) for n_ in n_array]
f = scipy.interpolate.interp1d(n_array, res, kind="quadratic")
Expand Down
43 changes: 40 additions & 3 deletions tests/modules/test_merge_sortings.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,56 @@ def test_compute_difference(merge_sortings_module: MergeSortings) -> None:
assert graph[('ks2_low_thresh', 64)][('ms3_best', 80)]['temp_diff'] > 0.65


def test_clean_edges(data: LussacData) -> None:
    """
    Builds a small hand-crafted similarity graph (two real units found by both analyses,
    one noisy unit, one erroneous link) and checks that 'clean_edges' removes the bad
    node and the wrong edge while keeping the correct matches.
    """
    data = data.clone()
    data.sortings = {
        'ks2_best': data.sortings['ks2_best'].select_units([13, 22, 41]),
        'ms3_best': data.sortings['ms3_best'].select_units([14, 71])
    }
    multi_data = MultiSortingsData(data, data.sortings)
    module = MergeSortings("merge_sortings_edges", multi_data, "all")

    # Graph with approximate, hand-picked attributes for testing.
    graph = nx.Graph()
    node_attrs = [
        (('ks2_best', 41), 0.002, 6.26, 1.05),  # Beautiful SSpk.
        (('ks2_best', 13), 0.001, 3.94, 1.11),  # Beautiful mossy fiber.
        (('ks2_best', 22), 0.316, 4.27, 1.35),  # Noisy unit.
        (('ms3_best', 71), 0.000, 6.35, 0.89),  # Same SSpk (but spikes missing).
        (('ms3_best', 14), 0.001, 4.35, 1.16),  # Same mossy fiber.
    ]
    for node, cont, snr, sd in node_attrs:
        graph.add_node(node, contamination=cont, SNR=snr, sd_ratio=sd)

    edge_attrs = [
        (('ks2_best', 41), ('ms3_best', 71), 0.998, 0.008, 0.051),  # Linking SSpk together.
        (('ks2_best', 13), ('ms3_best', 14), 0.964, 0.081, 0.074),  # Linking MF together.
        (('ks2_best', 41), ('ms3_best', 14), 0.052, 0.723, 0.947),  # Erroneous link: edge should be removed but not the nodes.
        (('ks2_best', 22), ('ms3_best', 71), 0.030, 0.733, 0.969),  # node1 is bad --> should get removed
    ]
    for node1, node2, sim, corr, temp in edge_attrs:
        graph.add_edge(node1, node2, similarity=sim, corr_diff=corr, temp_diff=temp)

    # Running "clean_edges"
    cross_shifts = module.compute_cross_shifts(30)
    module.clean_edges(graph, cross_shifts, module.update_params({}))

    # Making sure everything is as expected.
    assert graph.number_of_nodes() == 4
    assert graph.number_of_edges() == 2
    assert ('ks2_best', 41) in graph
    assert ('ms3_best', 14) in graph
    assert ('ks2_best', 22) not in graph
    assert ('ms3_best', 71) in graph
    assert graph.has_edge(('ks2_best', 41), ('ms3_best', 71))
    assert not graph.has_edge(('ks2_best', 41), ('ms3_best', 14))


def test_separate_communities() -> None:
    """
    Checks 'separate_communities' on a hand-made graph:
    - a dense community {0..4} linked to community {5, 6, 7} by the single edge (4, 5),
    - a small community {8, 9} hanging off node 1 (too small --> should be deleted),
    - an isolated pair {10, 11} that must be left untouched.
    """
    # 12 nodes, 17 edges before cleaning.
    graph = nx.from_edgelist([(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), (4, 5), (5, 6), (5, 7), (6, 7), (1, 8), (8, 9), (10, 11)])
    MergeSortings.separate_communities(graph)

    # Only nodes '8' and '9' need to be removed (their community is too small).
    assert graph.number_of_nodes() == 10
    assert 1 in graph
    assert 8 not in graph
    assert 9 not in graph

    # Only edges (1, 8), (8, 9) and (4, 5) need to be removed: 17 - 3 = 14.
    assert graph.number_of_edges() == 14
    assert not graph.has_edge(4, 5)


Expand Down

0 comments on commit e740735

Please sign in to comment.