
Commit

Refactor to take advantage of unified quality and template metrics
chrishalcrow committed Dec 4, 2024
1 parent 955ca72 commit 233d598
Showing 2 changed files with 71 additions and 97 deletions.
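The refactor collapses the duplicated quality-metric and template-metric branches into a single loop over extension names. Below is a minimal, self-contained sketch of that pattern; FakeExtension and FakeAnalyzer are illustrative stand-ins, not SpikeInterface classes.

# Sketch of the unified pattern: handle both metric extensions with one loop
# instead of two near-identical code paths. Stand-in classes, not SpikeInterface.
class FakeExtension:
    def __init__(self, params):
        self.params = params

class FakeAnalyzer:
    def __init__(self, extensions):
        self._extensions = extensions

    def get_extension(self, name):
        # returns None when the extension has not been computed
        return self._extensions.get(name)

analyzer = FakeAnalyzer({
    "quality_metrics": FakeExtension({"metric_names": ["snr"], "metric_params": {"snr": {}}}),
})

extension_names = ["quality_metrics", "template_metrics"]
for extension_name in extension_names:
    extension = analyzer.get_extension(extension_name)
    if extension is None:
        continue  # skip extensions that were never computed
    # "quality_metrics" maps to the "quality_metric_params" key used in model_info
    params_key = extension_name[:-1] + "_params"
    print(params_key, extension.params["metric_names"])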
78 changes: 31 additions & 47 deletions src/spikeinterface/curation/model_based_curation.py
@@ -150,53 +150,37 @@ def _check_params_for_classification(self, enforce_metric_params=False, model_in
            Path to model_info.json provenance file
        """

        quality_metrics_extension = self.sorting_analyzer.get_extension("quality_metrics")
        template_metrics_extension = self.sorting_analyzer.get_extension("template_metrics")

        if quality_metrics_extension is not None:

            model_quality_metrics_params = model_info["metric_params"]["quality_metric_params"]
            quality_metrics_params = quality_metrics_extension.params["metric_params"]

            inconsistent_metrics = []
            for metric in model_quality_metrics_params["metric_names"]:
                if metric not in model_quality_metrics_params["metric_params"]:
                    inconsistent_metrics += metric
                else:
                    if quality_metrics_params[metric] != model_quality_metrics_params["metric_params"][metric]:
                        warning_message = f"Quality metric params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file."
                        if enforce_metric_params is True:
                            raise Exception(warning_message)
                        else:
                            warnings.warn(warning_message)

            if len(inconsistent_metrics) > 0:
                warning_message = (
                    f"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown."
                )
                if enforce_metric_params is True:
                    raise Exception(warning_message)
                else:
                    warnings.warn(warning_message)

        if template_metrics_extension is not None:

            model_template_metrics_params = model_info["metric_params"]["template_metric_params"]["metric_params"]
            template_metrics_params = template_metrics_extension.params["metric_params"]

            if template_metrics_params == {}:
                warning_message = "Parameters used to compute template metrics, used to train this model, are unknown."
                if enforce_metric_params is True:
                    raise Exception(warning_message)
                else:
                    warnings.warn(warning_message)

            if template_metrics_params != model_template_metrics_params:
                warning_message = "Template metrics params do not match those used to train model. Parameters can be found in the 'model_info.json' file."
                if enforce_metric_params is True:
                    raise Exception(warning_message)
                else:
                    warnings.warn(warning_message)
        extension_names = ["quality_metrics", "template_metrics"]

        metric_extensions = [self.sorting_analyzer.get_extension(extension_name) for extension_name in extension_names]

        for metric_extension, extension_name in zip(metric_extensions, extension_names):

            # strip the trailing 's': "quality_metrics" becomes the "quality_metric_params" key
            extension_name = extension_name[:-1]
            if metric_extension is not None:

                model_metric_params = model_info["metric_params"][extension_name + "_params"]
                metric_params = metric_extension.params["metric_params"]

                inconsistent_metrics = []
                for metric in model_metric_params["metric_names"]:
                    if metric not in model_metric_params["metric_params"]:
                        # append the metric name itself; += would splice in its characters
                        inconsistent_metrics.append(metric)
                    else:
                        if metric_params[metric] != model_metric_params["metric_params"][metric]:
                            warning_message = f"{extension_name} params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file."
                            if enforce_metric_params is True:
                                raise Exception(warning_message)
                            else:
                                warnings.warn(warning_message)

                if len(inconsistent_metrics) > 0:
                    warning_message = f"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown."
                    if enforce_metric_params is True:
                        raise Exception(warning_message)
                    else:
                        warnings.warn(warning_message)

    def _export_to_phy(self, classified_units):
        """Export the classified units to Phy as a cluster_prediction.tsv file"""
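For reference, _check_params_for_classification reads model_info["metric_params"] as one nested dict per extension. A hedged sketch of that layout, with made-up metric names and parameter values:

# Illustrative shape of model_info["metric_params"] as accessed above;
# the metric names and values here are examples, not a real model's provenance.
model_info = {
    "metric_params": {
        "quality_metric_params": {
            "metric_names": ["snr", "isi_violation"],
            "metric_params": {
                "snr": {"peak_sign": "neg"},
                "isi_violation": {"isi_threshold_ms": 1.5},
            },
        },
        "template_metric_params": {
            "metric_names": ["peak_to_valley"],
            "metric_params": {"peak_to_valley": {}},
        },
    }
}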
90 changes: 40 additions & 50 deletions src/spikeinterface/curation/train_manual_curation.py
@@ -237,28 +237,32 @@ def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params):
        conflicting_metrics = self._check_metrics_parameters(analyzers, enforce_metric_params)

        self.metrics_params = {}
        if analyzers[0].has_extension("quality_metrics") is True:
            self.metrics_params["quality_metric_params"] = analyzers[0].extensions["quality_metrics"].params

            # Only save metric params which are 1) consistent and 2) exist in metric_names
            qm_names = self.metrics_params["quality_metric_params"]["metric_names"]
            consistent_metrics = list(set(qm_names).difference(set(conflicting_metrics)))
            consistent_metric_params = {
                metric: analyzers[0].extensions["quality_metrics"].params["metric_params"][metric]
                for metric in consistent_metrics
            }
            self.metrics_params["quality_metric_params"]["metric_params"] = consistent_metric_params

        if analyzers[0].has_extension("template_metrics") is True:
            self.metrics_params["template_metric_params"] = deepcopy(analyzers[0].extensions["template_metrics"].params)
            if "template_metrics" in conflicting_metrics:
                self.metrics_params["template_metric_params"] = {}

        extension_names = ["quality_metrics", "template_metrics"]
        metric_extensions = [analyzers[0].get_extension(extension_name) for extension_name in extension_names]

        for metric_extension, extension_name in zip(metric_extensions, extension_names):

            # strip the trailing 's': "quality_metrics" becomes the "quality_metric_params" key
            extension_name = extension_name[:-1]
            if metric_extension is not None:
                # deepcopy so trimming metric_params below cannot mutate the analyzer's own extension params
                self.metrics_params[extension_name + "_params"] = deepcopy(metric_extension.params)

                # Only save metric params which are 1) consistent and 2) exist in metric_names
                metric_names = metric_extension.params["metric_names"]
                consistent_metrics = list(set(metric_names).difference(set(conflicting_metrics)))
                consistent_metric_params = {
                    metric: metric_extension.params["metric_params"][metric] for metric in consistent_metrics
                }
                self.metrics_params[extension_name + "_params"]["metric_params"] = consistent_metric_params

        self.process_test_data_for_classification()

    def _check_metrics_parameters(self, analyzers, enforce_metric_params):
        """Checks that the metrics of each analyzer have been calculated using the same parameters"""

        extension_names = ["quality_metrics", "template_metrics"]

        conflicting_metrics = []
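        # compare metric params for every pair of distinct analyzers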
        for analyzer_index_1, analyzer_1 in enumerate(analyzers):
            for analyzer_index_2, analyzer_2 in enumerate(analyzers):
@@ -267,29 +271,20 @@ def _check_metrics_parameters(self, analyzers, enforce_metric_params):
                    continue
                else:

                    qm_params_1 = {}
                    qm_params_2 = {}
                    tm_params_1 = {}
                    tm_params_2 = {}
                    metric_params_1 = {}
                    metric_params_2 = {}

                    if analyzer_1.has_extension("quality_metrics") is True:
                        qm_params_1 = analyzer_1.extensions["quality_metrics"].params["metric_params"]
                    if analyzer_2.has_extension("quality_metrics") is True:
                        qm_params_2 = analyzer_2.extensions["quality_metrics"].params["metric_params"]
                    if analyzer_1.has_extension("template_metrics") is True:
                        tm_params_1 = analyzer_1.extensions["template_metrics"].params["metric_params"]
                    if analyzer_2.has_extension("template_metrics") is True:
                        tm_params_2 = analyzer_2.extensions["template_metrics"].params["metric_params"]
                    for extension_name in extension_names:
                        if (extension_1 := analyzer_1.get_extension(extension_name)) is not None:
                            metric_params_1.update(extension_1.params["metric_params"])
                        if (extension_2 := analyzer_2.get_extension(extension_name)) is not None:
                            metric_params_2.update(extension_2.params["metric_params"])

                    conflicting_metrics_between_1_2 = []
                    # check quality metrics params
                    for metric, params_1 in qm_params_1.items():
                        if params_1 != qm_params_2.get(metric):
                    for metric, params_1 in metric_params_1.items():
                        if params_1 != metric_params_2.get(metric):
                            conflicting_metrics_between_1_2.append(metric)
                    # check template metric params
                    for metric, params_1 in tm_params_1.items():
                        if params_1 != tm_params_2.get(metric):
                            conflicting_metrics_between_1_2.append("template_metrics")

                    conflicting_metrics += conflicting_metrics_between_1_2

@@ -737,28 +732,23 @@ def _get_computed_metrics(sorting_analyzer):

def try_to_get_metrics_from_analyzer(sorting_analyzer):

    quality_metrics = None
    template_metrics = None
    extension_names = ["quality_metrics", "template_metrics"]
    metric_extensions = [sorting_analyzer.get_extension(extension_name) for extension_name in extension_names]

    # Try to get metrics if available
    try:
        quality_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
    except:
        pass

    try:
        template_metrics = sorting_analyzer.get_extension("template_metrics").get_data()
    except:
        pass

    # Check if at least one of the metrics is available
    if quality_metrics is None and template_metrics is None:
    if not any(metric_extensions):
        raise ValueError(
            "At least one of quality metrics or template metrics must be computed before classification. "
            "Compute both using `sorting_analyzer.compute(['quality_metrics', 'template_metrics'])`"
        )

    return quality_metrics, template_metrics
    metric_extensions_data = []
    for metric_extension in metric_extensions:
        try:
            # None is appended when an extension is missing or its data cannot be loaded
            metric_extensions_data.append(metric_extension.get_data())
        except Exception:
            metric_extensions_data.append(None)

    return metric_extensions_data


def set_default_search_kwargs(search_kwargs):
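A hedged usage sketch of the new return convention: the data come back in the fixed order of extension_names, with None for any extension that is missing or unreadable. It assumes a sorting_analyzer already exists with both metric extensions computed; the setup is abbreviated and illustrative.

# Assumes both extensions were computed first, e.g.:
#     sorting_analyzer.compute(["quality_metrics", "template_metrics"])
quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(sorting_analyzer)

if quality_metrics is not None:
    # get_data() returns a pandas DataFrame for these extensions
    print(quality_metrics.head())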
