Commit: Testing new merge_sortings
Testing a new way to remove bad units (i.e. merged units, MUA) in the
`merge_sortings` module.
DradeAW committed Dec 15, 2023
1 parent cffd7cc commit a9dee33
Showing 2 changed files with 63 additions and 51 deletions.
110 changes: 61 additions & 49 deletions src/lussac/modules/merge_sortings.py
@@ -53,7 +53,7 @@ def default_params(self) -> dict[str, Any]:
                 'num_channels': 5
             },
             'merge_check': {
-                'cross_cont_limit': 0.22
+                'cross_cont_limit': 0.10
             }
         }
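
The default `cross_cont_limit` drops from 22% to 10%. This limit is the threshold passed to `utils.estimate_cross_contamination` in the merge check below, and the module's own comments indicate that a low p-value means two units are probably a split of the same neuron, so a lower limit makes that call stricter. A rough usage sketch; the fake spike trains, the import path, and the (censored_ms, refractory_ms) period format are assumptions, not taken from this commit:

import numpy as np
import lussac.utils as utils  # assumed import path

rng = np.random.default_rng(0)
# Two fake spike trains (in samples) standing in for two units of one sorting.
train1 = np.sort(rng.integers(0, 30_000 * 600, size=20_000))
train2 = np.sort(rng.integers(0, 30_000 * 600, size=15_000))

# Same call pattern as in remove_merged_units below, with the new 10% limit.
cross_cont, p_value = utils.estimate_cross_contamination(train1, train2, (0.2, 1.0), limit=0.10)
if p_value < 5e-3:  # Significance convention used later in this diff.
    print(f"cross-cont = {cross_cont:.2%} -> probably two pieces of the same neuron")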

@@ -83,7 +83,7 @@ def run(self, params: dict[str, Any]) -> dict[str, si.BaseSorting]:
 
         if params['merge_check']:
             self.remove_merged_units(graph, cross_shifts, params['refractory_period'], params['merge_check'])
-        self.clean_graph(graph, params)
+        self.clean_graph(graph, cross_shifts, params)
         self._save_graph(graph, "final_graph")
 
         merged_sorting = self.merge_sortings(graph, params)
@@ -176,12 +176,12 @@ def _compute_graph(self, similarity_matrices: dict[str, dict[str, np.ndarray]],
             wvf_extractor = si.extract_waveforms(recording_f, sorting, folder=self.tmp_folder / f"wvfs_{name}", ms_before=1.5, ms_after=1.5, max_spikes_per_unit=150, sparse=False)
             spost.compute_spike_amplitudes(wvf_extractor, peak_sign="both", return_scaled=recording_f.has_scaled())
             contamination, _ = sqm.compute_refrac_period_violations(wvf_extractor, refractory_period_ms=refractory_period, censored_period_ms=censored_period)
-            # sd_ratio = sqm.compute_sd_ratio(wvf_extractor)
+            sd_ratio = sqm.compute_sd_ratio(wvf_extractor)
 
             for unit_id in sorting.unit_ids:
                 attr = {key: sorting.get_unit_property(unit_id, key) for key in sorting.get_property_keys() if key.startswith('gt_')}
                 attr['contamination'] = contamination[unit_id]
-                # attr['sd_ratio'] = sd_ratio[unit_id]
+                attr['sd_ratio'] = sd_ratio[unit_id]
                 graph.add_node((name, unit_id), **attr)
 
             del wvf_extractor
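
With `compute_sd_ratio` now enabled (it was commented out), every node in the graph carries two quality metrics: `contamination`, from refractory-period violations, and `sd_ratio`, the spread of the unit's spike amplitudes relative to the noise level, which is roughly 1 for a well-isolated unit and tends to be larger for merged units or MUA. A minimal sketch of the resulting data structure; the sorter names and metric values are made up:

import networkx as nx

graph = nx.Graph()
# Nodes are (sorting_name, unit_id) pairs, as in _compute_graph above.
graph.add_node(("sorter_A", 7), contamination=0.03, sd_ratio=1.02)
graph.add_node(("sorter_B", 12), contamination=0.15, sd_ratio=1.48)
graph.add_edge(("sorter_A", 7), ("sorter_B", 12))

# remove_merged_units and clean_graph read the metrics back like this:
node = ("sorter_B", 12)
C = graph.nodes[node]['contamination']
sd = graph.nodes[node]['sd_ratio']
print(f"C = {C:.1%} ; sd = {sd:.3f}")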
@@ -255,43 +255,33 @@ def remove_merged_units(self, graph: nx.Graph, cross_shifts: dict[str, dict[str,
                 if sorting1_name != sorting2_name:
                     continue
 
+                spike_train = self.sortings[sorting_name].get_unit_spike_train(unit_id)
                 spike_train1 = self.sortings[sorting1_name].get_unit_spike_train(unit_id1) + cross_shifts[sorting_name][sorting1_name][unit_ind, unit_ind1]
                 spike_train2 = self.sortings[sorting2_name].get_unit_spike_train(unit_id2) + cross_shifts[sorting_name][sorting2_name][unit_ind, unit_ind2]
-                C1 = graph.nodes[(sorting1_name, unit_id1)]['contamination']
-                C2 = graph.nodes[(sorting2_name, unit_id2)]['contamination']
+                C = graph.nodes[node]['contamination']
+                C1 = graph.nodes[node1]['contamination']
+                C2 = graph.nodes[node2]['contamination']
+                sd = graph.nodes[node]['sd_ratio']
+                sd1 = graph.nodes[node1]['sd_ratio']
+                sd2 = graph.nodes[node2]['sd_ratio']
                 if C2 < C1:
                     spike_train1, spike_train2 = spike_train2, spike_train1
                 cross_cont, p_value = utils.estimate_cross_contamination(spike_train1, spike_train2, refractory_period, limit=params['cross_cont_limit'])
 
                 logs.write(f"\nUnit {node} is connected to {node1} and {node2}:\n")
-                logs.write(f"\tcross-cont = {cross_cont:.2%} (p_value={p_value:.3f})\n")
-                logs.write(f"\tC1 = {C1:.1%} ; C2 = {C2:.1%}\n")
-                if p_value > 5e-3:  # No problem, node1 and node2 are probably just a split.
+                logs.write(f"\tcross-cont = {cross_cont:.2%} (p_value={p_value:.2e})\n")
+                logs.write(f"\tC = {C:.1%} ; C1 = {C1:.1%} ; C2 = {C2:.1%}\n")
+                logs.write(f"\tsd = {sd:.3f} ; sd1 = {sd1:.3f} ; sd2 = {sd2:.3f}\n")
+                if p_value > 5e-3:  # No problem detected, node1 and node2 are probably just a split.
                     continue
 
-                spike_train = self.sortings[sorting_name].get_unit_spike_train(unit_id)
-                C = graph.nodes[(sorting_name, unit_id)]['contamination']
-                cross_cont1, p_value1 = utils.estimate_cross_contamination(spike_train1, spike_train, refractory_period, limit=0.1)
-                cross_cont2, p_value2 = utils.estimate_cross_contamination(spike_train2, spike_train, refractory_period, limit=0.1)
-                p_value1, p_value2 = 1 - p_value1, 1 - p_value2  # Reverse the p-values because we want to know the probability <= and not >=.
-
-                logs.write(f"\tcheck1 = {cross_cont1:.2%} (p_value={p_value1:.3f})\n")
-                logs.write(f"\tcheck2 = {cross_cont2:.2%} (p_value={p_value2:.3f})\n")
-
-                if p_value1 < 1e-3:  # node2 is the problematic unit.
-                    if node2 not in nodes_to_remove and C2 > C + 0.02:
-                        nodes_to_remove.append(node2)
-                        logs.write(f"\t-> Unit {node2} is considered a problematic unit.\n")
-                    continue
-                elif p_value2 < 1e-3:  # node1 is the problematic unit.
-                    if node1 not in nodes_to_remove and C1 > C + 0.02:
-                        nodes_to_remove.append(node1)
-                        logs.write(f"\t-> Unit {node1} is considered a problematic unit.\n")
-                    continue
-                elif min(C1, C2) < 0.1 and C > min(C1, C2) + 0.02:  # node is probably a merged unit.
-                    if node not in nodes_to_remove:
-                        nodes_to_remove.append(node)
-                        logs.write(f"\t-> Unit {node} is considered a merged unit.\n")
+                cases_removed = np.array([  # If at least one is true, 'node' is removed
+                    p_value < 1e-8 and C > max(C1, C2),
+                    p_value < 5e-3 and sd > max(sd1, sd2) and C > max(C1, C2)
+                ], dtype=bool)
+                if np.any(cases_removed):
+                    nodes_to_remove.append(node)
+                    break  # Don't need to check the other ones.
 
         if len(nodes_to_remove) > 0:
             logs.write("\nRemoved units:\n")
@@ -398,39 +398,61 @@ def compute_waveform_difference(self, graph: nx.Graph, cross_shifts: dict[str, d
                 template_diff = np.sum(np.abs(template1 - template2)) / np.sum(np.abs(template1) + np.abs(template2))
                 graph[node1][node2]['temp_diff'] = template_diff
 
-    def clean_graph(self, graph: nx.Graph, params: dict) -> None:  # pragma: no cover (not implemented yet)
+    def clean_graph(self, graph: nx.Graph, cross_shifts: dict[str, dict[str, np.ndarray]], params: dict) -> None:
         """
-        TODO
-        @param graph:
-        @param params:
-        @return:
+        @param graph: nx.Graph
+            The graph containing all the units and connected by their similarity.
+        @param cross_shifts: dict[str, dict[str, np.ndarray]]
+            The cross-shifts between units.
+        @param params: dict
+            The parameters of the merge_sorting module.
         """
+        # TODO: Make these parameters.
+        sd_threshold = 0.20
+        C_threshold = 0.08
 
         nodes_to_remove = []
+        edges_to_remove = []
 
         for node1, node2, data in list(graph.edges(data=True)):
-            if 'corr_diff' not in data and 'temp_diff' not in data:
+            if node1 in nodes_to_remove or node2 in nodes_to_remove:
                 continue
+            sorting_name1, unit_id1 = node1
+            sorting_name2, unit_id2 = node2
 
-            if ('corr_diff' in data and data['corr_diff'] > 0.25) or ('temp_diff' in data and data['temp_diff'] > 0.20):
-                data['problem'] = True
-                sorting1_name, unit_id1 = node1
-                sorting2_name, unit_id2 = node2
+            unit_ind1 = self.sortings[sorting_name1].id_to_index(unit_id1)
+            unit_ind2 = self.sortings[sorting_name2].id_to_index(unit_id2)
 
-                C1 = utils.estimate_contamination(self.sortings[sorting1_name].get_unit_spike_train(unit_id1), refractory_period=params['refractory_period'])
-                C2 = utils.estimate_contamination(self.sortings[sorting2_name].get_unit_spike_train(unit_id2), refractory_period=params['refractory_period'])
+            C1 = graph.nodes[node1]['contamination']
+            C2 = graph.nodes[node2]['contamination']
+            sd1 = graph.nodes[node1]['sd_ratio']
+            sd2 = graph.nodes[node2]['sd_ratio']
 
-                """if C1 > C2 + 0.04:
-                    nodes_to_remove.append(node1)
-                elif C2 > C1 + 0.04:
-                    nodes_to_remove.append(node2)"""
+            spike_train1 = self.sortings[sorting_name1].get_unit_spike_train(unit_id1).astype(np.int64)
+            spike_train2 = self.sortings[sorting_name2].get_unit_spike_train(unit_id2).astype(np.int64) + cross_shifts[sorting_name1][sorting_name2][unit_ind1, unit_ind2]
+            if C2 < C1:
+                spike_train1, spike_train2 = spike_train2, spike_train1
+
+            cross_cont, p_value = utils.estimate_cross_contamination(spike_train1, spike_train2, params['refractory_period'], limit=0.06)  # TODO: make 'limit' a parameter.
+
+            if p_value > 5e-3 or data['temp_diff'] > 0.10 or data['corr_diff'] > 0.12:  # TODO: make these parameters.
+                continue
+
+            # From this point on, the edge is treated as problematic.
+            if sd1 - sd2 > (C2 - C1) * sd_threshold/C_threshold + sd_threshold and sd1 > 1.05:  # node 1 is problematic
+                if node1 not in nodes_to_remove:
+                    nodes_to_remove.append(node1)
+            elif sd2 - sd1 > (C1 - C2) * sd_threshold/C_threshold + sd_threshold and sd2 > 1.05:  # node 2 is problematic
+                if node2 not in nodes_to_remove:
+                    nodes_to_remove.append(node2)
+            else:  # Couldn't decide which one is problematic; remove the connection between them.
+                edges_to_remove.append((node1, node2))
 
         for node in nodes_to_remove:
            graph.remove_node(node)
+        for edge in edges_to_remove:
+            graph.remove_edge(*edge)
 
     def merge_sortings(self, graph: nx.Graph, params: dict) -> si.NpzSortingExtractor:
         """
4 changes: 2 additions & 2 deletions tests/core/test_modulefactory.py
@@ -18,10 +18,10 @@ def test_get_module_member() -> None:
     with pytest.raises(ModuleNotFoundError):
         ModuleFactory._get_module_member("not_a_module")
 
-    with pytest.raises(Exception):
+    with pytest.raises(Exception):  # Fails because it contains no module.
         ModuleFactory._get_module_member("lussac.core.lussac_data")
 
-    with pytest.raises(Exception):
+    with pytest.raises(Exception):  # Fails because it contains two modules.
         ModuleFactory._get_module_member("tests.core.test_modulefactory")
 
     assert ModuleFactory._get_module_member("lussac.modules.export_to_phy") == ExportToPhy
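
The two new comments pin down the contract of `_get_module_member`: a file must expose exactly one module class, so `lussac.core.lussac_data` (no module) and the test file itself (two modules) both raise. A hedged sketch of what such a lookup can look like; this is illustrative and not lussac's actual implementation:

import importlib
import inspect

def get_single_module_class(module_name: str, base_class: type) -> type:
    """Import 'module_name' and return its unique subclass of 'base_class'."""
    mod = importlib.import_module(module_name)
    candidates = [cls for _, cls in inspect.getmembers(mod, inspect.isclass)
                  if issubclass(cls, base_class) and cls is not base_class]
    if len(candidates) != 1:  # No module, or more than one: ambiguous file.
        raise Exception(f"Expected exactly one module in '{module_name}', found {len(candidates)}.")
    return candidates[0]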
