From 45fa7bcf65c39f1745a0bfc1e8c8dbcc076e1135 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:21:03 +0100 Subject: [PATCH] change model_folder_path to model_folder --- .../curation/model_based_curation.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index 5dbe3bf203..1e19a3a68a 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -211,7 +211,7 @@ def _export_to_phy(self, classified_units): def auto_label_units( sorting_analyzer: SortingAnalyzer, - model_folder_path=None, + model_folder=None, model_name=None, repo_id=None, label_conversion=None, @@ -226,7 +226,7 @@ def auto_label_units( ---------- sorting_analyzer : SortingAnalyzer The sorting analyzer object containing the spike sorting results. - model_folder_path : str or Path, defualt: None + model_folder : str or Path, defualt: None The path to the folder containing the model repo_id : str | Path, default: None Hugging face repo id which contains the model e.g. 'username/model' @@ -251,7 +251,7 @@ def auto_label_units( """ from sklearn.pipeline import Pipeline - model, model_info = load_model(model_folder_path=model_folder_path, repo_id=repo_id, model_name=model_name) + model, model_info = load_model(model_folder=model_folder, repo_id=repo_id, model_name=model_name) if not isinstance(model, Pipeline): raise ValueError("The model must be an instance of sklearn.pipeline.Pipeline") @@ -265,13 +265,13 @@ def auto_label_units( return classified_units -def load_model(model_folder_path=None, repo_id=None, model_name=None): +def load_model(model_folder=None, repo_id=None, model_name=None): """ Loads a model and model_info from a folder or a huggingface repo Parameters ---------- - model_folder_path : str | Path, default: None + model_folder : str | Path, default: None Path to the folder containing the model repo_id : str | Path, default: None Hugging face repo id e.g. 'username/model' @@ -284,12 +284,12 @@ def load_model(model_folder_path=None, repo_id=None, model_name=None): A model and metadata about the model """ - if model_folder_path is None and repo_id is None: - raise ValueError("Please provide a 'model_folder_path' or a 'repo_id'.") - elif model_folder_path is not None and repo_id is not None: - raise ValueError("Please only provide one of 'model_folder_path' or 'repo_id'.") - elif model_folder_path is not None: - model, model_info = _load_model_from_folder(model_folder_path=model_folder_path, model_name=model_name) + if model_folder is None and repo_id is None: + raise ValueError("Please provide a 'model_folder' or a 'repo_id'.") + elif model_folder is not None and repo_id is not None: + raise ValueError("Please only provide one of 'model_folder' or 'repo_id'.") + elif model_folder is not None: + model, model_info = _load_model_from_folder(model_folder=model_folder, model_name=model_name) else: model, model_info = _load_model_from_huggingface(repo_id=repo_id, model_name=model_name) @@ -308,20 +308,20 @@ def _load_model_from_huggingface(repo_id=None, model_name=None): for filename in repo_filenames: if Path(filename).suffix in [".skops", ".json"]: full_path = hf_hub_download(repo_id=repo_id, filename=filename) - model_folder_path = Path(full_path).parent + model_folder = Path(full_path).parent - model, model_info = _load_model_from_folder(model_folder_path=model_folder_path, model_name=model_name) + model, model_info = _load_model_from_folder(model_folder=model_folder, model_name=model_name) return model, model_info -def _load_model_from_folder(model_folder_path=None, model_name=None): +def _load_model_from_folder(model_folder=None, model_name=None): """ Loads a model and model_info from a folder Parameters ---------- - model_folder_path : str | Path, default: None + model_folder : str | Path, default: None Path to the folder or HuggingFace directory containing the model model_name: str | Path, default: None Filename of model e.g. 'my_model.skops'. If None, uses first model found in directory @@ -334,11 +334,11 @@ def _load_model_from_folder(model_folder_path=None, model_name=None): import skops.io as skio - folder = Path(model_folder_path) + folder = Path(model_folder) assert folder.is_dir(), f"The folder {folder}, does not exist." if model_name is not None: - skops_file = Path(model_folder_path) / Path(model_name) + skops_file = Path(model_folder) / Path(model_name) assert skops_file.is_file(), f"Model file {skops_file} not found." else: # look for any .skops files