diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 47294b3cf7..a54dbf7290 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -91,6 +91,69 @@ def test_export_to_phy_by_property(sorting_analyzer_with_group_for_export, creat assert template_inds.shape == (sorting_analyzer.unit_ids.size, 4) +def test_export_to_phy_metrics(sorting_analyzer_sparse_for_export, create_cache_folder): + cache_folder = create_cache_folder + + sorting_analyzer = sorting_analyzer_sparse_for_export + + # quality metrics are computed already + qm = sorting_analyzer.get_extension("quality_metrics").get_data() + output_folder = cache_folder / "phy_output_qm" + export_to_phy( + sorting_analyzer, + output_folder, + compute_pc_features=False, + compute_amplitudes=False, + n_jobs=1, + chunk_size=10000, + progress_bar=True, + add_quality_metrics=True, + ) + for col_name in qm.columns: + assert (output_folder / f"cluster_{col_name}.tsv").is_file() + + # quality metrics are computed already + tm_ext = sorting_analyzer.compute("template_metrics") + tm = tm_ext.get_data() + output_folder = cache_folder / "phy_output_tm_not_qm" + export_to_phy( + sorting_analyzer, + output_folder, + compute_pc_features=False, + compute_amplitudes=False, + n_jobs=1, + chunk_size=10000, + progress_bar=True, + add_quality_metrics=False, + add_template_metrics=True, + ) + for col_name in tm.columns: + assert (output_folder / f"cluster_{col_name}.tsv").is_file() + for col_name in qm.columns: + assert not (output_folder / f"cluster_{col_name}.tsv").is_file() + + # custom metrics + sorting_analyzer.sorting.set_property("custom_metric", np.random.rand(sorting_analyzer.unit_ids.size)) + output_folder = cache_folder / "phy_output_custom" + export_to_phy( + sorting_analyzer, + output_folder, + compute_pc_features=False, + compute_amplitudes=False, + n_jobs=1, + chunk_size=10000, + progress_bar=True, + add_quality_metrics=False, + add_template_metrics=False, + additional_properties=["custom_metric"], + ) + assert (output_folder / "cluster_custom_metric.tsv").is_file() + for col_name in tm.columns: + assert not (output_folder / f"cluster_{col_name}.tsv").is_file() + for col_name in qm.columns: + assert not (output_folder / f"cluster_{col_name}.tsv").is_file() + + if __name__ == "__main__": sorting_analyzer_sparse = make_sorting_analyzer(sparse=True) sorting_analyzer_group = make_sorting_analyzer(sparse=False, with_group=True) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index d7be6c1ba3..7b3c7daab0 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -25,8 +25,10 @@ def export_to_phy( sparsity: Optional[ChannelSparsity] = None, copy_binary: bool = True, remove_if_exists: bool = False, - peak_sign: Literal["both", "neg", "pos"] = "neg", template_mode: str = "average", + add_quality_metrics: bool = True, + add_template_metrics: bool = True, + additional_properties: list | None = None, dtype: Optional[npt.DTypeLike] = None, verbose: bool = True, use_relative_path: bool = False, @@ -51,10 +53,14 @@ def export_to_phy( If True, the recording is copied and saved in the phy "output_folder" remove_if_exists : bool, default: False If True and "output_folder" exists, it is removed and overwritten - peak_sign : "neg" | "pos" | "both", default: "neg" - Used by compute_spike_amplitudes template_mode : str, default: "average" Parameter "mode" to be given to SortingAnalyzer.get_template() + add_quality_metrics : bool, default: True + If True, quality metrics (if computed) are saved as Phy tsv and will appear in the ClusterView. + add_template_metrics : bool, default: True + If True, template metrics (if computed) are saved as Phy tsv and will appear in the ClusterView. + additional_properties : list | None, default: None + List of additional properties to be saved as Phy tsv and will appear in the ClusterView. dtype : dtype or None, default: None Dtype to save binary data verbose : bool, default: True @@ -244,7 +250,7 @@ def export_to_phy( channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups}) channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False) - if sorting_analyzer.has_extension("quality_metrics"): + if sorting_analyzer.has_extension("quality_metrics") and add_quality_metrics: qm_data = sorting_analyzer.get_extension("quality_metrics").get_data() for column_name in qm_data.columns: # already computed by phy @@ -253,6 +259,19 @@ def export_to_phy( {"cluster_id": [i for i in range(len(unit_ids))], column_name: qm_data[column_name].values} ) metric.to_csv(output_folder / f"cluster_{column_name}.tsv", sep="\t", index=False) + if sorting_analyzer.has_extension("template_metrics") and add_template_metrics: + tm_data = sorting_analyzer.get_extension("template_metrics").get_data() + for column_name in tm_data.columns: + metric = pd.DataFrame( + {"cluster_id": [i for i in range(len(unit_ids))], column_name: tm_data[column_name].values} + ) + metric.to_csv(output_folder / f"cluster_{column_name}.tsv", sep="\t", index=False) + if additional_properties is not None: + for prop_name in additional_properties: + prop_data = sorting.get_property(prop_name) + if prop_data is not None: + prop = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], prop_name: prop_data}) + prop.to_csv(output_folder / f"cluster_{prop_name}.tsv", sep="\t", index=False) if verbose: print("Run:\nphy template-gui ", str(output_folder / "params.py"))