diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index 1e19a3a68a..0bc45e98ef 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -215,6 +215,7 @@ def auto_label_units( model_name=None, repo_id=None, label_conversion=None, + trusted=None, export_to_phy=False, ): """ @@ -237,6 +238,8 @@ def auto_label_units( tries to extract from `model_info.json` file. The dictionary should have the format {old_label: new_label}. export_to_phy : bool, default: False Whether to export the results to Phy format. Default is False. + trusted : list of str, default: None + Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file. Returns ------- @@ -251,7 +254,7 @@ def auto_label_units( """ from sklearn.pipeline import Pipeline - model, model_info = load_model(model_folder=model_folder, repo_id=repo_id, model_name=model_name) + model, model_info = load_model(model_folder=model_folder, repo_id=repo_id, model_name=model_name, trusted=trusted) if not isinstance(model, Pipeline): raise ValueError("The model must be an instance of sklearn.pipeline.Pipeline") @@ -265,19 +268,10 @@ def auto_label_units( return classified_units -def load_model(model_folder=None, repo_id=None, model_name=None): +def load_model(model_folder=None, repo_id=None, model_name=None, trusted=None): """ Loads a model and model_info from a folder or a huggingface repo - Parameters - ---------- - model_folder : str | Path, default: None - Path to the folder containing the model - repo_id : str | Path, default: None - Hugging face repo id e.g. 'username/model' - model_name: str | Path, default: None - Filename of model e.g. 'my_model.skops'. If None, uses first model found. - Returns ------- model, model_info @@ -296,7 +290,15 @@ def load_model(model_folder=None, repo_id=None, model_name=None): return model, model_info -def _load_model_from_huggingface(repo_id=None, model_name=None): +def _load_model_from_huggingface(repo_id=None, model_name=None, trusted=None): + """ + Loads a model from a huggingface repo + + Returns + ------- + model, model_info + A model and metadata about the model + """ from huggingface_hub import list_repo_files from huggingface_hub import hf_hub_download @@ -310,22 +312,15 @@ def _load_model_from_huggingface(repo_id=None, model_name=None): full_path = hf_hub_download(repo_id=repo_id, filename=filename) model_folder = Path(full_path).parent - model, model_info = _load_model_from_folder(model_folder=model_folder, model_name=model_name) + model, model_info = _load_model_from_folder(model_folder=model_folder, model_name=model_name, trusted=trusted) return model, model_info -def _load_model_from_folder(model_folder=None, model_name=None): +def _load_model_from_folder(model_folder=None, model_name=None, trusted=None): """ Loads a model and model_info from a folder - Parameters - ---------- - model_folder : str | Path, default: None - Path to the folder or HuggingFace directory containing the model - model_name: str | Path, default: None - Filename of model e.g. 'my_model.skops'. If None, uses first model found in directory - Returns ------- model, model_info @@ -351,7 +346,7 @@ def _load_model_from_folder(model_folder=None, model_name=None): skops_file = skops_files[0] - model = skio.load(skops_file, trusted="numpy.dtype") + model = skio.load(skops_file, trusted=trusted) model_info_path = folder / "model_info.json" if not model_info_path.is_file():