diff --git a/skema/img2mml/api.py b/skema/img2mml/api.py index 07b74ede5a6..501fe1fabe6 100644 --- a/skema/img2mml/api.py +++ b/skema/img2mml/api.py @@ -25,8 +25,8 @@ def retrieve_model(model_path=None) -> str: cwd = Path(__file__).parents[0] MODEL_BASE_ADDRESS = "https://artifacts.askem.lum.ai/skema/img2mml/models" MODEL_NAME = "cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt" - - if model_path is None: + # If the model path is none or doesn't exist, the default model will be downloaded from server. + if model_path is None or not os.path.exists(model_path): model_path = cwd / "trained_models" / MODEL_NAME # Check if the model file already exists