Skip to content

Commit

Permalink
Add support to huggingface hub download with revision version (#255)
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammadariffaizin authored Aug 11, 2024
1 parent e5a577b commit 7663793
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
5 changes: 3 additions & 2 deletions yolov5/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def load_model(
model_path, device=None, autoshape=True, verbose=False, hf_token: str = None
model_path, device=None, autoshape=True, verbose=False, hf_token: str = None, revision: str = None
):
"""
Creates a specified YOLOv5 model
Expand All @@ -21,6 +21,7 @@ def load_model(
autoshape (bool): make model ready for inference
verbose (bool): if False, yolov5 logs will be silent
hf_token (str): huggingface read token for private models
revision (str): huggingface model revision
Returns:
pytorch model
Expand All @@ -36,7 +37,7 @@ def load_model(

try:
model = DetectMultiBackend(
model_path, device=device, fuse=autoshape, hf_token=hf_token
model_path, device=device, fuse=autoshape, hf_token=hf_token, revision=revision
) # detection model
if autoshape:
if model.pt and isinstance(model.model, ClassificationModel):
Expand Down
4 changes: 2 additions & 2 deletions yolov5/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def forward(self, x):

class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True, hf_token=None):
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True, hf_token=None, revision=None):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript
Expand All @@ -335,7 +335,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
w = str(weights[0] if isinstance(weights, list) else weights)

# try to dowload from hf hub
result = attempt_download_from_hub(w, hf_token=hf_token)
result = attempt_download_from_hub(w, hf_token=hf_token, revision=revision)
if result is not None:
w = result

Expand Down
4 changes: 3 additions & 1 deletion yolov5/utils/downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_model_filename_from_hfhub(repo_id, hf_token=None):
return None


def attempt_download_from_hub(repo_id, hf_token=None):
def attempt_download_from_hub(repo_id, hf_token=None, revision=None):
from huggingface_hub import hf_hub_download, list_repo_files
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
Expand All @@ -161,6 +161,7 @@ def attempt_download_from_hub(repo_id, hf_token=None):
filename=config_file,
repo_type='model',
token=hf_token,
revision=revision,
)

# download model file
Expand All @@ -170,6 +171,7 @@ def attempt_download_from_hub(repo_id, hf_token=None):
filename=model_file,
repo_type='model',
token=hf_token,
revision=revision,
)
return file
except (RepositoryNotFoundError, HFValidationError):
Expand Down

1 comment on commit 7663793

@regaliaf
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from huggingface_hub.utils._errors import RepositoryNotFoundError
This module has been removed from huggingface_hub since 0.25.0

Please sign in to comment.