diff --git a/yolov5/helpers.py b/yolov5/helpers.py index fbf2798..9c0c442 100644 --- a/yolov5/helpers.py +++ b/yolov5/helpers.py @@ -9,7 +9,7 @@ def load_model( - model_path, device=None, autoshape=True, verbose=False, hf_token: str = None + model_path, device=None, autoshape=True, verbose=False, hf_token: str = None, revision: str = None ): """ Creates a specified YOLOv5 model @@ -21,6 +21,7 @@ def load_model( autoshape (bool): make model ready for inference verbose (bool): if False, yolov5 logs will be silent hf_token (str): huggingface read token for private models + revision (str): huggingface model revision Returns: pytorch model @@ -36,7 +37,7 @@ def load_model( try: model = DetectMultiBackend( - model_path, device=device, fuse=autoshape, hf_token=hf_token + model_path, device=device, fuse=autoshape, hf_token=hf_token, revision=revision ) # detection model if autoshape: if model.pt and isinstance(model.model, ClassificationModel): diff --git a/yolov5/models/common.py b/yolov5/models/common.py index 8ec382a..e997d29 100644 --- a/yolov5/models/common.py +++ b/yolov5/models/common.py @@ -315,7 +315,7 @@ def forward(self, x): class DetectMultiBackend(nn.Module): # YOLOv5 MultiBackend class for python inference on various backends - def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True, hf_token=None): + def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True, hf_token=None, revision=None): # Usage: # PyTorch: weights = *.pt # TorchScript: *.torchscript @@ -335,7 +335,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, w = str(weights[0] if isinstance(weights, list) else weights) # try to dowload from hf hub - result = attempt_download_from_hub(w, hf_token=hf_token) + result = attempt_download_from_hub(w, hf_token=hf_token, revision=revision) if result is not None: w = result diff --git a/yolov5/utils/downloads.py b/yolov5/utils/downloads.py index e3270ef..e07acb8 100644 --- a/yolov5/utils/downloads.py +++ b/yolov5/utils/downloads.py @@ -145,7 +145,7 @@ def get_model_filename_from_hfhub(repo_id, hf_token=None): return None -def attempt_download_from_hub(repo_id, hf_token=None): +def attempt_download_from_hub(repo_id, hf_token=None, revision=None): from huggingface_hub import hf_hub_download, list_repo_files from huggingface_hub.utils._errors import RepositoryNotFoundError from huggingface_hub.utils._validators import HFValidationError @@ -161,6 +161,7 @@ def attempt_download_from_hub(repo_id, hf_token=None): filename=config_file, repo_type='model', token=hf_token, + revision=revision, ) # download model file @@ -170,6 +171,7 @@ def attempt_download_from_hub(repo_id, hf_token=None): filename=model_file, repo_type='model', token=hf_token, + revision=revision, ) return file except (RepositoryNotFoundError, HFValidationError):