From e7407356fa0cfc1b3126b449f518738b2b117ffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 12 Jan 2024 11:20:27 +0100 Subject: [PATCH] Clean-up in `merge_sortings` --- .../params_cerebellar_cortex[beta].json | 4 +- src/lussac/modules/merge_sortings.py | 12 ++++-- src/lussac/utils/misc.py | 2 +- tests/modules/test_merge_sortings.py | 43 +++++++++++++++++-- 4 files changed, 50 insertions(+), 11 deletions(-) diff --git a/params_examples/params_cerebellar_cortex[beta].json b/params_examples/params_cerebellar_cortex[beta].json index ccb5ef2..f49d018 100644 --- a/params_examples/params_cerebellar_cortex[beta].json +++ b/params_examples/params_cerebellar_cortex[beta].json @@ -228,7 +228,6 @@ }, "correlogram_validation": { "max_time": 400, - "max_difference": 0.35, "gaussian_std": 10.0, "gaussian_truncate": 5.0 } @@ -240,8 +239,7 @@ "min_similarity": 0.4 }, "correlogram_validation": { - "max_time": 75.0, - "max_difference": 0.14 + "max_time": 75.0 } } }, diff --git a/src/lussac/modules/merge_sortings.py b/src/lussac/modules/merge_sortings.py index 2d0d23e..a791b7f 100644 --- a/src/lussac/modules/merge_sortings.py +++ b/src/lussac/modules/merge_sortings.py @@ -37,12 +37,10 @@ def default_params(self) -> dict[str, Any]: }, 'correlogram_validation': { 'max_time': 70.0, - 'max_difference': 0.25, 'gaussian_std': 0.6, 'gaussian_truncate': 5.0 }, 'waveform_validation': { - 'max_difference': 0.20, 'wvf_extraction': { 'ms_before': 1.0, 'ms_after': 2.0, @@ -439,7 +437,7 @@ def clean_edges(self, graph: nx.Graph, cross_shifts: dict[str, dict[str, np.ndar cross_cont, p_value = utils.estimate_cross_contamination(spike_train1, spike_train2, params['refractory_period'], limit=0.06) # TODO: make 'limit' a parameter. - if p_value > 5e-3 or data['temp_diff'] > 0.10 or data['corr_diff'] > 0.12: # Make these parameters. + if p_value > 5e-3 or data['temp_diff'] < 0.10 or data['corr_diff'] < 0.12: # TODO: Make these parameters. continue # From this point on, the edge is treated as problematic. @@ -459,7 +457,13 @@ def clean_edges(self, graph: nx.Graph, cross_shifts: dict[str, dict[str, np.ndar @staticmethod def separate_communities(graph: nx.Graph) -> None: """ - TODO + Looks for all subgraphs (connected component) and uses the Louvain algorithm to check if + multiple communities are found. If so, the edges between communities are removed. + Additionally, small communities are removed. + Warning: it's recommended to run this function if there are at least 4 analyses. + + @param graph: nx.Graph + The graph containing all the units and connected by their similarity. """ for nodes in list(nx.connected_components(graph)): diff --git a/src/lussac/utils/misc.py b/src/lussac/utils/misc.py index 6e5e036..af8f2fd 100644 --- a/src/lussac/utils/misc.py +++ b/src/lussac/utils/misc.py @@ -133,7 +133,7 @@ def binom_sf(x: int, n: float, p: float) -> float: """ n_array = np.arange(math.floor(n-2), math.ceil(n+3), 1) - n_array = n_array[n_array >= 1] + n_array = n_array[n_array >= 0] res = [scipy.stats.binom.sf(x, n_, p) for n_ in n_array] f = scipy.interpolate.interp1d(n_array, res, kind="quadratic") diff --git a/tests/modules/test_merge_sortings.py b/tests/modules/test_merge_sortings.py index 2a6d077..0083965 100644 --- a/tests/modules/test_merge_sortings.py +++ b/tests/modules/test_merge_sortings.py @@ -162,19 +162,56 @@ def test_compute_difference(merge_sortings_module: MergeSortings) -> None: assert graph[('ks2_low_thresh', 64)][('ms3_best', 80)]['temp_diff'] > 0.65 +def test_clean_edges(data: LussacData) -> None: + data = data.clone() + data.sortings = { + 'ks2_best': data.sortings['ks2_best'].select_units([13, 22, 41]), + 'ms3_best': data.sortings['ms3_best'].select_units([14, 71]) + } + multi_sortings_data = MultiSortingsData(data, data.sortings) + merge_sortings_module = MergeSortings("merge_sortings_edges", multi_sortings_data, "all") + + # Making a graph with approximate parameters for testing. + graph = nx.Graph() + graph.add_node(('ks2_best', 41), contamination=0.002, SNR=6.26, sd_ratio=1.05) # Beautiful SSpk. + graph.add_node(('ks2_best', 13), contamination=0.001, SNR=3.94, sd_ratio=1.11) # Beautiful mossy fiber. + graph.add_node(('ks2_best', 22), contamination=0.316, SNR=4.27, sd_ratio=1.35) # Noisy unit. + graph.add_node(('ms3_best', 71), contamination=0.000, SNR=6.35, sd_ratio=0.89) # Same SSpk (but spikes missing). + graph.add_node(('ms3_best', 14), contamination=0.001, SNR=4.35, sd_ratio=1.16) # Same mossy fiber. + + graph.add_edge(('ks2_best', 41), ('ms3_best', 71), similarity=0.998, corr_diff=0.008, temp_diff=0.051) # Linking SSpk together. + graph.add_edge(('ks2_best', 13), ('ms3_best', 14), similarity=0.964, corr_diff=0.081, temp_diff=0.074) # Linking MF together. + graph.add_edge(('ks2_best', 41), ('ms3_best', 14), similarity=0.052, corr_diff=0.723, temp_diff=0.947) # Erroneous link: edge should be removed but not the nodes. + graph.add_edge(('ks2_best', 22), ('ms3_best', 71), similarity=0.030, corr_diff=0.733, temp_diff=0.969) # node1 is bad --> should get removed + + # Running "clean_edges" + cross_shifts = merge_sortings_module.compute_cross_shifts(30) + merge_sortings_module.clean_edges(graph, cross_shifts, merge_sortings_module.update_params({})) + + # Making sure everything is as expected. + assert graph.number_of_nodes() == 4 + assert graph.number_of_edges() == 2 + assert ('ks2_best', 41) in graph + assert ('ms3_best', 14) in graph + assert ('ks2_best', 22) not in graph + assert ('ms3_best', 71) in graph + assert graph.has_edge(('ks2_best', 41), ('ms3_best', 71)) + assert not graph.has_edge(('ks2_best', 41), ('ms3_best', 14)) + + def test_separate_communities() -> None: - graph = nx.from_edgelist([(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), (4, 5), (5, 6), (5, 7), (6, 7), (1, 8), (8, 9)]) + graph = nx.from_edgelist([(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), (4, 5), (5, 6), (5, 7), (6, 7), (1, 8), (8, 9), (10, 11)]) MergeSortings.separate_communities(graph) # Only nodes '8' and '9' need to be removed print(graph.nodes) - assert graph.number_of_nodes() == 8 + assert graph.number_of_nodes() == 10 assert 1 in graph assert 8 not in graph assert 9 not in graph # Only edges (1, 8), (8, 9) and (3, 4) need to be removed - assert graph.number_of_edges() == 13 + assert graph.number_of_edges() == 14 assert not graph.has_edge(4, 5)