From a097ee2126002384065291fa6e1667047f92a5b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E7=92=90?= Date: Sat, 27 Jul 2024 10:27:31 +0800 Subject: [PATCH] refine: Allow specifying the model config when loading a local model. --- paddleclas.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/paddleclas.py b/paddleclas.py index d7db179053..d0aa082d33 100644 --- a/paddleclas.py +++ b/paddleclas.py @@ -621,7 +621,24 @@ def _check_input_model(self, model_name, inference_model_dir): all_pulc_model_names = PULC_MODELS all_shitu_model_names = SHITU_MODELS - if model_name: + if inference_model_dir: + model_file_path = os.path.join(inference_model_dir, + "inference.pdmodel") + params_file_path = os.path.join(inference_model_dir, + "inference.pdiparams") + if not os.path.isfile(model_file_path) or not os.path.isfile( + params_file_path): + err = f"There is no model file or params file in this directory: {inference_model_dir}" + raise InputModelError(err) + model_type = "custom" + if model_name in all_imn_model_names: + model_type = "imn" + elif model_name in all_pulc_model_names: + model_type = "pulc" + elif model_name in all_shitu_model_names: + model_type = "shitu" + return model_type, inference_model_dir + elif model_name: if model_name in all_imn_model_names: inference_model_dir = check_model_file("imn", model_name) return "imn", inference_model_dir @@ -646,16 +663,6 @@ def _check_input_model(self, model_name, inference_model_dir): similar_pulc_names) err = f"{model_name} is not provided by PaddleClas. \nMaybe you want the : [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!" raise InputModelError(err) - elif inference_model_dir: - model_file_path = os.path.join(inference_model_dir, - "inference.pdmodel") - params_file_path = os.path.join(inference_model_dir, - "inference.pdiparams") - if not os.path.isfile(model_file_path) or not os.path.isfile( - params_file_path): - err = f"There is no model file or params file in this directory: {inference_model_dir}" - raise InputModelError(err) - return "custom", inference_model_dir else: err = "Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)." raise InputModelError(err)