Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine: Allow specifying the model config when loading a local model. #3200

Open
wants to merge 1 commit into
base: release/2.5.2
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions paddleclas.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,24 @@ def _check_input_model(self, model_name, inference_model_dir):
all_pulc_model_names = PULC_MODELS
all_shitu_model_names = SHITU_MODELS

if model_name:
if inference_model_dir:
model_file_path = os.path.join(inference_model_dir,
"inference.pdmodel")
params_file_path = os.path.join(inference_model_dir,
"inference.pdiparams")
if not os.path.isfile(model_file_path) or not os.path.isfile(
params_file_path):
err = f"There is no model file or params file in this directory: {inference_model_dir}"
raise InputModelError(err)
model_type = "custom"
if model_name in all_imn_model_names:
model_type = "imn"
elif model_name in all_pulc_model_names:
model_type = "pulc"
elif model_name in all_shitu_model_names:
model_type = "shitu"
return model_type, inference_model_dir
elif model_name:
if model_name in all_imn_model_names:
inference_model_dir = check_model_file("imn", model_name)
return "imn", inference_model_dir
Expand All @@ -646,16 +663,6 @@ def _check_input_model(self, model_name, inference_model_dir):
similar_pulc_names)
err = f"{model_name} is not provided by PaddleClas. \nMaybe you want the : [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
raise InputModelError(err)
elif inference_model_dir:
model_file_path = os.path.join(inference_model_dir,
"inference.pdmodel")
params_file_path = os.path.join(inference_model_dir,
"inference.pdiparams")
if not os.path.isfile(model_file_path) or not os.path.isfile(
params_file_path):
err = f"There is no model file or params file in this directory: {inference_model_dir}"
raise InputModelError(err)
return "custom", inference_model_dir
else:
err = "Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
raise InputModelError(err)
Expand Down