From a9dee3349c5d3b694a8e5713a7baed337356014a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?=
Date: Fri, 15 Dec 2023 14:23:17 +0100
Subject: [PATCH] Testing new `merge_sortings`

Testing a new way to remove bad units (i.e. merged units or MUA) in the
`merge_sortings` module.
---
 src/lussac/modules/merge_sortings.py | 110 +++++++++++++++------------
 tests/core/test_modulefactory.py     |   4 +-
 2 files changed, 63 insertions(+), 51 deletions(-)

diff --git a/src/lussac/modules/merge_sortings.py b/src/lussac/modules/merge_sortings.py
index b94b157..494340c 100644
--- a/src/lussac/modules/merge_sortings.py
+++ b/src/lussac/modules/merge_sortings.py
@@ -53,7 +53,7 @@ def default_params(self) -> dict[str, Any]:
 				'num_channels': 5
 			},
 			'merge_check': {
-				'cross_cont_limit': 0.22
+				'cross_cont_limit': 0.10
 			}
 		}
 
@@ -83,7 +83,7 @@ def run(self, params: dict[str, Any]) -> dict[str, si.BaseSorting]:
 		if params['merge_check']:
 			self.remove_merged_units(graph, cross_shifts, params['refractory_period'], params['merge_check'])
 
-		self.clean_graph(graph, params)
+		self.clean_graph(graph, cross_shifts, params)
 		self._save_graph(graph, "final_graph")
 
 		merged_sorting = self.merge_sortings(graph, params)
@@ -176,12 +176,12 @@ def _compute_graph(self, similarity_matrices: dict[str, dict[str, np.ndarray]],
 			wvf_extractor = si.extract_waveforms(recording_f, sorting, folder=self.tmp_folder / f"wvfs_{name}", ms_before=1.5, ms_after=1.5, max_spikes_per_unit=150, sparse=False)
 			spost.compute_spike_amplitudes(wvf_extractor, peak_sign="both", return_scaled=recording_f.has_scaled())
 			contamination, _ = sqm.compute_refrac_period_violations(wvf_extractor, refractory_period_ms=refractory_period, censored_period_ms=censored_period)
-			# sd_ratio = sqm.compute_sd_ratio(wvf_extractor)
+			sd_ratio = sqm.compute_sd_ratio(wvf_extractor)
 
 			for unit_id in sorting.unit_ids:
 				attr = {key: sorting.get_unit_property(unit_id, key) for key in sorting.get_property_keys() if key.startswith('gt_')}
 				attr['contamination'] = contamination[unit_id]
-				# attr['sd_ratio'] = sd_ratio[unit_id]
+				attr['sd_ratio'] = sd_ratio[unit_id]
 				graph.add_node((name, unit_id), **attr)
 
 			del wvf_extractor
@@ -255,43 +255,33 @@ def remove_merged_units(self, graph: nx.Graph, cross_shifts: dict[str, dict[str,
 				if sorting1_name != sorting2_name:
 					continue
 
+				spike_train = self.sortings[sorting_name].get_unit_spike_train(unit_id)
 				spike_train1 = self.sortings[sorting1_name].get_unit_spike_train(unit_id1) + cross_shifts[sorting_name][sorting1_name][unit_ind, unit_ind1]
 				spike_train2 = self.sortings[sorting2_name].get_unit_spike_train(unit_id2) + cross_shifts[sorting_name][sorting2_name][unit_ind, unit_ind2]
-				C1 = graph.nodes[(sorting1_name, unit_id1)]['contamination']
-				C2 = graph.nodes[(sorting2_name, unit_id2)]['contamination']
+				C = graph.nodes[node]['contamination']
+				C1 = graph.nodes[node1]['contamination']
+				C2 = graph.nodes[node2]['contamination']
+				sd = graph.nodes[node]['sd_ratio']
+				sd1 = graph.nodes[node1]['sd_ratio']
+				sd2 = graph.nodes[node2]['sd_ratio']
 
 				if C2 < C1:
					spike_train1, spike_train2 = spike_train2, spike_train1
 
 				cross_cont, p_value = utils.estimate_cross_contamination(spike_train1, spike_train2, refractory_period, limit=params['cross_cont_limit'])
 				logs.write(f"\nUnit {node} is connected to {node1} and {node2}:\n")
-				logs.write(f"\tcross-cont = {cross_cont:.2%} (p_value={p_value:.3f})\n")
-				logs.write(f"\tC1 = {C1:.1%} ; C2 = {C2:.1%}\n")
-				if p_value > 5e-3:  # No problem, node1 and node2 are probably just a split.
+				logs.write(f"\tcross-cont = {cross_cont:.2%} (p_value={p_value:.2e})\n")
+				logs.write(f"\tC = {C:.1%} ; C1 = {C1:.1%} ; C2 = {C2:.1%}\n")
+				logs.write(f"\tsd = {sd:.3f} ; sd1 = {sd1:.3f} ; sd2 = {sd2:.3f}\n")
+				if p_value > 5e-3:  # No problem detected: node1 and node2 are probably just a split.
 					continue
 
-				spike_train = self.sortings[sorting_name].get_unit_spike_train(unit_id)
-				C = graph.nodes[(sorting_name, unit_id)]['contamination']
-				cross_cont1, p_value1 = utils.estimate_cross_contamination(spike_train1, spike_train, refractory_period, limit=0.1)
-				cross_cont2, p_value2 = utils.estimate_cross_contamination(spike_train2, spike_train, refractory_period, limit=0.1)
-				p_value1, p_value2 = 1 - p_value1, 1 - p_value2  # Reverse the p-values because we want to know the probability <= and not >=.
-
-				logs.write(f"\tcheck1 = {cross_cont1:.2%} (p_value={p_value1:.3f})\n")
-				logs.write(f"\tcheck2 = {cross_cont2:.2%} (p_value={p_value2:.3f})\n")
-
-				if p_value1 < 1e-3:  # node2 is the problematic unit.
-					if node2 not in nodes_to_remove and C2 > C + 0.02:
-						nodes_to_remove.append(node2)
-						logs.write(f"\t-> Unit {node2} is considered a problematic unit.\n")
-					continue
-				elif p_value2 < 1e-3:  # node1 is the problematic unit.
-					if node1 not in nodes_to_remove and C1 > C + 0.02:
-						nodes_to_remove.append(node1)
-						logs.write(f"\t-> Unit {node1} is considered a problematic unit.\n")
-					continue
-				elif min(C1, C2) < 0.1 and C > min(C1, C2) + 0.02:  # node is probably a merged unit.
-					if node not in nodes_to_remove:
-						nodes_to_remove.append(node)
-						logs.write(f"\t-> Unit {node} is considered a merged unit.\n")
+				cases_removed = np.array([  # If at least one is true, 'node' is removed.
+					p_value < 1e-8 and C > max(C1, C2),
+					p_value < 5e-3 and sd > max(sd1, sd2) and C > max(C1, C2)
+				], dtype=bool)
+
+				if np.any(cases_removed):
+					nodes_to_remove.append(node)
+					break  # Don't need to check the other ones.
 
 		if len(nodes_to_remove) > 0:
 			logs.write("\nRemoved units:\n")
@@ -398,39 +388,61 @@ def compute_waveform_difference(self, graph: nx.Graph, cross_shifts: dict[str, d
 				template_diff = np.sum(np.abs(template1 - template2)) / np.sum(np.abs(template1) + np.abs(template2))
 				graph[node1][node2]['temp_diff'] = template_diff
 
-	def clean_graph(self, graph: nx.Graph, params: dict) -> None:  # pragma: no cover (not implemented yet)
+	def clean_graph(self, graph: nx.Graph, cross_shifts: dict[str, dict[str, np.ndarray]], params: dict) -> None:
 		"""
 		TODO
 
-		@param graph:
-		@param params:
-		@return:
+		@param graph: nx.Graph
+			The graph containing all the units, connected by their similarity.
+		@param cross_shifts: dict[str, dict[str, np.ndarray]]
+			The cross-shifts between units.
+		@param params: dict
+			The parameters of the merge_sortings module.
 		"""
 
+		# TODO: Make these parameters.
+		sd_threshold = 0.20
+		C_threshold = 0.08
+
 		nodes_to_remove = []
+		edges_to_remove = []
 		for node1, node2, data in list(graph.edges(data=True)):
-			if node1 in nodes_to_remove or node2 in nodes_to_remove:
-				continue
+			sorting_name1, unit_id1 = node1
+			sorting_name2, unit_id2 = node2
 
-			if 'corr_diff' not in data and 'temp_diff' not in data:
-				continue
+			unit_ind1 = self.sortings[sorting_name1].id_to_index(unit_id1)
+			unit_ind2 = self.sortings[sorting_name2].id_to_index(unit_id2)
 
-			if ('corr_diff' in data and data['corr_diff'] > 0.25) or ('temp_diff' in data and data['temp_diff'] > 0.20):
-				data['problem'] = True
-				sorting1_name, unit_id1 = node1
-				sorting2_name, unit_id2 = node2
+			C1 = graph.nodes[node1]['contamination']
+			C2 = graph.nodes[node2]['contamination']
 
-				C1 = utils.estimate_contamination(self.sortings[sorting1_name].get_unit_spike_train(unit_id1), refractory_period=params['refractory_period'])
-				C2 = utils.estimate_contamination(self.sortings[sorting2_name].get_unit_spike_train(unit_id2), refractory_period=params['refractory_period'])
+			sd1 = graph.nodes[node1]['sd_ratio']
+			sd2 = graph.nodes[node2]['sd_ratio']
+
+			spike_train1 = self.sortings[sorting_name1].get_unit_spike_train(unit_id1).astype(np.int64)
+			spike_train2 = self.sortings[sorting_name2].get_unit_spike_train(unit_id2).astype(np.int64) + cross_shifts[sorting_name1][sorting_name2][unit_ind1, unit_ind2]
+			if C2 < C1:
+				spike_train1, spike_train2 = spike_train2, spike_train1
+
+			cross_cont, p_value = utils.estimate_cross_contamination(spike_train1, spike_train2, params['refractory_period'], limit=0.06)  # TODO: make 'limit' a parameter.
+
+			if p_value > 5e-3 or data['temp_diff'] > 0.10 or data['corr_diff'] > 0.12:  # TODO: Make these parameters.
 				continue
 
-			"""if C1 > C2 + 0.04:
+			# From this point on, the edge is treated as problematic.
+			if sd1 - sd2 > (C2 - C1) * sd_threshold/C_threshold + sd_threshold and sd1 > 1.05:  # node 1 is problematic.
+				if node1 not in nodes_to_remove:
 					nodes_to_remove.append(node1)
-			elif C2 > C1 + 0.04:
-				nodes_to_remove.append(node2)"""
+			elif sd2 - sd1 > (C1 - C2) * sd_threshold/C_threshold + sd_threshold and sd2 > 1.05:  # node 2 is problematic.
+				if node2 not in nodes_to_remove:
+					nodes_to_remove.append(node2)
+			else:  # Couldn't decide which one is problematic; remove the connection between them.
+				edges_to_remove.append((node1, node2))
 
 		for node in nodes_to_remove:
 			graph.remove_node(node)
+		for edge in edges_to_remove:
+			graph.remove_edge(*edge)
 
 	def merge_sortings(self, graph: nx.Graph, params: dict) -> si.NpzSortingExtractor:
 		"""
diff --git a/tests/core/test_modulefactory.py b/tests/core/test_modulefactory.py
index 1f1f467..d08fe23 100644
--- a/tests/core/test_modulefactory.py
+++ b/tests/core/test_modulefactory.py
@@ -18,10 +18,10 @@ def test_get_module_member() -> None:
 	with pytest.raises(ModuleNotFoundError):
 		ModuleFactory._get_module_member("not_a_module")
 
-	with pytest.raises(Exception):
+	with pytest.raises(Exception):  # Fails because it contains no module.
 		ModuleFactory._get_module_member("lussac.core.lussac_data")
 
-	with pytest.raises(Exception):
+	with pytest.raises(Exception):  # Fails because it contains two modules.
 		ModuleFactory._get_module_member("tests.core.test_modulefactory")
 
 	assert ModuleFactory._get_module_member("lussac.modules.export_to_phy") == ExportToPhy
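
Note for reviewers: the removal rule introduced in `remove_merged_units` can be read in isolation. The sketch below (not part of the patch) restates the two cases as a standalone function; the name `should_remove_node` and all numbers in the usage example are illustrative, while the thresholds 1e-8 and 5e-3 are the ones hard-coded in the hunk above.

```python
# Standalone sketch of the new removal rule in remove_merged_units().
# 'node' is the candidate merged unit; node1/node2 are two of its
# neighbours, belonging to the same other sorting.

def should_remove_node(p_value: float,
                       C: float, C1: float, C2: float,
                       sd: float, sd1: float, sd2: float) -> bool:
	"""Return True if 'node' looks like a merge of its two neighbours."""
	# Case 1: the neighbours are almost certainly distinct units
	# (very low cross-contamination p-value) and 'node' is more
	# contaminated than both of them.
	case1 = p_value < 1e-8 and C > max(C1, C2)
	# Case 2: weaker evidence that the neighbours are distinct, but
	# 'node' also has a larger sd_ratio (amplitude spread) than both,
	# which is a signature of a merged unit.
	case2 = p_value < 5e-3 and sd > max(sd1, sd2) and C > max(C1, C2)
	return case1 or case2

# Example: moderate split evidence + higher contamination and sd_ratio.
assert should_remove_node(p_value=1e-4, C=0.12, C1=0.03, C2=0.05,
                          sd=1.4, sd1=1.02, sd2=1.08)
```

Requiring C > max(C1, C2) in both cases means a unit that is cleaner than either of its neighbours is never dropped, even when the cross-contamination test between the neighbours is significant.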
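The endpoint-selection rule in the reworked `clean_graph` is easiest to read as a line in (contamination, sd_ratio) space: one unit is blamed only if its extra amplitude spread exceeds what its contamination advantage can excuse. This sketch (again, not part of the patch) uses the hard-coded `sd_threshold` and `C_threshold` values from the hunk; `resolve_edge` is an illustrative name, not a function in the module.

```python
# Sketch of the edge-resolution rule in clean_graph(), applied once an
# edge is already known to be problematic (low cross-contamination
# p-value, dissimilar templates/correlograms).

sd_threshold = 0.20  # hard-coded in the patch; TODO says make it a parameter.
C_threshold = 0.08   # hard-coded in the patch; TODO says make it a parameter.

def resolve_edge(C1: float, C2: float, sd1: float, sd2: float) -> str:
	"""Return 'node1' or 'node2' if one unit is clearly bad, else 'edge'."""
	# How much extra sd_ratio a contamination advantage is allowed to buy.
	slope = sd_threshold / C_threshold
	if sd1 - sd2 > (C2 - C1) * slope + sd_threshold and sd1 > 1.05:
		return 'node1'  # node1's amplitude spread is too high to be a clean unit.
	if sd2 - sd1 > (C1 - C2) * slope + sd_threshold and sd2 > 1.05:
		return 'node2'
	return 'edge'  # Can't blame either unit; only the link is dropped.

# Equal contamination but a much larger sd_ratio on one side blames that unit.
assert resolve_edge(C1=0.02, C2=0.02, sd1=1.40, sd2=1.05) == 'node1'
# Symmetric units: neither is blamed, only the edge is removed.
assert resolve_edge(C1=0.02, C2=0.02, sd1=1.10, sd2=1.10) == 'edge'
```

When neither endpoint crosses the line, only the edge is removed, so both units survive but are no longer treated as the same neuron during the merge step.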