Skip to content
This repository has been archived by the owner on Jan 8, 2025. It is now read-only.

Commit

Permalink
Fix a bug in retrieve model (ml4ai#331)
Browse files Browse the repository at this point in the history
If the model path is None or doesn't exist when calling the retrieve_model
function, download the default model from the server.
  • Loading branch information
ualiangzhang authored Jul 14, 2023
1 parent 9f4c8e1 commit abef48f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions skema/img2mml/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def retrieve_model(model_path=None) -> str:
cwd = Path(__file__).parents[0]
MODEL_BASE_ADDRESS = "https://artifacts.askem.lum.ai/skema/img2mml/models"
MODEL_NAME = "cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt"

if model_path is None:
# If the model path is none or doesn't exist, the default model will be downloaded from server.
if model_path is None or not os.path.exists(model_path):
model_path = cwd / "trained_models" / MODEL_NAME

# Check if the model file already exists
Expand Down

0 comments on commit abef48f

Please sign in to comment.