Skip to content

Commit

Permalink
Pass trusted to skops.load
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Sep 9, 2024
1 parent 45fa7bc commit 7198d2e
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions src/spikeinterface/curation/model_based_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def auto_label_units(
model_name=None,
repo_id=None,
label_conversion=None,
trusted=None,
export_to_phy=False,
):
"""
Expand All @@ -237,6 +238,8 @@ def auto_label_units(
tries to extract from `model_info.json` file. The dictionary should have the format {old_label: new_label}.
export_to_phy : bool, default: False
Whether to export the results to Phy format. Default is False.
trusted : list of str, default: None
Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file.
Returns
-------
Expand All @@ -251,7 +254,7 @@ def auto_label_units(
"""
from sklearn.pipeline import Pipeline

model, model_info = load_model(model_folder=model_folder, repo_id=repo_id, model_name=model_name)
model, model_info = load_model(model_folder=model_folder, repo_id=repo_id, model_name=model_name, trusted=trusted)

if not isinstance(model, Pipeline):
raise ValueError("The model must be an instance of sklearn.pipeline.Pipeline")
Expand All @@ -265,19 +268,10 @@ def auto_label_units(
return classified_units


def load_model(model_folder=None, repo_id=None, model_name=None):
def load_model(model_folder=None, repo_id=None, model_name=None, trusted=None):
"""
Loads a model and model_info from a folder or a huggingface repo
Parameters
----------
model_folder : str | Path, default: None
Path to the folder containing the model
repo_id : str | Path, default: None
Hugging face repo id e.g. 'username/model'
model_name: str | Path, default: None
Filename of model e.g. 'my_model.skops'. If None, uses first model found.
Returns
-------
model, model_info
Expand All @@ -296,7 +290,15 @@ def load_model(model_folder=None, repo_id=None, model_name=None):
return model, model_info


def _load_model_from_huggingface(repo_id=None, model_name=None):
def _load_model_from_huggingface(repo_id=None, model_name=None, trusted=None):
"""
Loads a model from a huggingface repo
Returns
-------
model, model_info
A model and metadata about the model
"""

from huggingface_hub import list_repo_files
from huggingface_hub import hf_hub_download
Expand All @@ -310,22 +312,15 @@ def _load_model_from_huggingface(repo_id=None, model_name=None):
full_path = hf_hub_download(repo_id=repo_id, filename=filename)
model_folder = Path(full_path).parent

model, model_info = _load_model_from_folder(model_folder=model_folder, model_name=model_name)
model, model_info = _load_model_from_folder(model_folder=model_folder, model_name=model_name, trusted=trusted)

return model, model_info


def _load_model_from_folder(model_folder=None, model_name=None):
def _load_model_from_folder(model_folder=None, model_name=None, trusted=None):
"""
Loads a model and model_info from a folder
Parameters
----------
model_folder : str | Path, default: None
Path to the folder or HuggingFace directory containing the model
model_name: str | Path, default: None
Filename of model e.g. 'my_model.skops'. If None, uses first model found in directory
Returns
-------
model, model_info
Expand All @@ -351,7 +346,7 @@ def _load_model_from_folder(model_folder=None, model_name=None):

skops_file = skops_files[0]

model = skio.load(skops_file, trusted="numpy.dtype")
model = skio.load(skops_file, trusted=trusted)

model_info_path = folder / "model_info.json"
if not model_info_path.is_file():
Expand Down

0 comments on commit 7198d2e

Please sign in to comment.