Skip to content

Commit

Permalink
change model_folder_path to model_folder
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Sep 9, 2024
1 parent 28918a5 commit 45fa7bc
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/spikeinterface/curation/model_based_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _export_to_phy(self, classified_units):

def auto_label_units(
sorting_analyzer: SortingAnalyzer,
model_folder_path=None,
model_folder=None,
model_name=None,
repo_id=None,
label_conversion=None,
Expand All @@ -226,7 +226,7 @@ def auto_label_units(
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting results.
model_folder_path : str or Path, defualt: None
model_folder : str or Path, defualt: None
The path to the folder containing the model
repo_id : str | Path, default: None
Hugging face repo id which contains the model e.g. 'username/model'
Expand All @@ -251,7 +251,7 @@ def auto_label_units(
"""
from sklearn.pipeline import Pipeline

model, model_info = load_model(model_folder_path=model_folder_path, repo_id=repo_id, model_name=model_name)
model, model_info = load_model(model_folder=model_folder, repo_id=repo_id, model_name=model_name)

if not isinstance(model, Pipeline):
raise ValueError("The model must be an instance of sklearn.pipeline.Pipeline")
Expand All @@ -265,13 +265,13 @@ def auto_label_units(
return classified_units


def load_model(model_folder_path=None, repo_id=None, model_name=None):
def load_model(model_folder=None, repo_id=None, model_name=None):
"""
Loads a model and model_info from a folder or a huggingface repo
Parameters
----------
model_folder_path : str | Path, default: None
model_folder : str | Path, default: None
Path to the folder containing the model
repo_id : str | Path, default: None
Hugging face repo id e.g. 'username/model'
Expand All @@ -284,12 +284,12 @@ def load_model(model_folder_path=None, repo_id=None, model_name=None):
A model and metadata about the model
"""

if model_folder_path is None and repo_id is None:
raise ValueError("Please provide a 'model_folder_path' or a 'repo_id'.")
elif model_folder_path is not None and repo_id is not None:
raise ValueError("Please only provide one of 'model_folder_path' or 'repo_id'.")
elif model_folder_path is not None:
model, model_info = _load_model_from_folder(model_folder_path=model_folder_path, model_name=model_name)
if model_folder is None and repo_id is None:
raise ValueError("Please provide a 'model_folder' or a 'repo_id'.")
elif model_folder is not None and repo_id is not None:
raise ValueError("Please only provide one of 'model_folder' or 'repo_id'.")
elif model_folder is not None:
model, model_info = _load_model_from_folder(model_folder=model_folder, model_name=model_name)
else:
model, model_info = _load_model_from_huggingface(repo_id=repo_id, model_name=model_name)

Expand All @@ -308,20 +308,20 @@ def _load_model_from_huggingface(repo_id=None, model_name=None):
for filename in repo_filenames:
if Path(filename).suffix in [".skops", ".json"]:
full_path = hf_hub_download(repo_id=repo_id, filename=filename)
model_folder_path = Path(full_path).parent
model_folder = Path(full_path).parent

model, model_info = _load_model_from_folder(model_folder_path=model_folder_path, model_name=model_name)
model, model_info = _load_model_from_folder(model_folder=model_folder, model_name=model_name)

return model, model_info


def _load_model_from_folder(model_folder_path=None, model_name=None):
def _load_model_from_folder(model_folder=None, model_name=None):
"""
Loads a model and model_info from a folder
Parameters
----------
model_folder_path : str | Path, default: None
model_folder : str | Path, default: None
Path to the folder or HuggingFace directory containing the model
model_name: str | Path, default: None
Filename of model e.g. 'my_model.skops'. If None, uses first model found in directory
Expand All @@ -334,11 +334,11 @@ def _load_model_from_folder(model_folder_path=None, model_name=None):

import skops.io as skio

folder = Path(model_folder_path)
folder = Path(model_folder)
assert folder.is_dir(), f"The folder {folder}, does not exist."

if model_name is not None:
skops_file = Path(model_folder_path) / Path(model_name)
skops_file = Path(model_folder) / Path(model_name)
assert skops_file.is_file(), f"Model file {skops_file} not found."
else:
# look for any .skops files
Expand Down

0 comments on commit 45fa7bc

Please sign in to comment.