diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index f02353edff..04fe01a50e 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -30,4 +30,5 @@ ) from .file_download import cached_download, hf_hub_url from .hf_api import HfApi, HfFolder +from .hub_mixin import ModelHubMixin from .repository import Repository diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py new file mode 100644 index 0000000000..081be72bfc --- /dev/null +++ b/src/huggingface_hub/hub_mixin.py @@ -0,0 +1,246 @@ +import json +import logging +import os +from typing import Dict, Optional + +import requests + +from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME +from .file_download import cached_download, hf_hub_url, is_torch_available +from .hf_api import HfApi, HfFolder +from .repository import Repository + + +if is_torch_available(): + import torch + + +logger = logging.getLogger(__name__) + + +class ModelHubMixin(object): + def __init__(self, *args, **kwargs): + """ + Mix this class with your torch-model class for ease process of saving & loading from huggingface-hub + + Example:: + + >>> from huggingface_hub import ModelHubMixin + + >>> class MyModel(nn.Module, ModelHubMixin): + ... def __init__(self, **kwargs): + ... super().__init__() + ... self.config = kwargs.pop("config", None) + ... self.layer = ... + ... def forward(self, ...) + ... return ... + + >>> model = MyModel() + >>> model.save_pretrained("mymodel", push_to_hub=False) # Saving model weights in the directory + >>> model.push_to_hub("mymodel", "model-1") # Pushing model-weights to hf-hub + + >>> # Downloading weights from hf-hub & model will be initialized from those weights + >>> model = MyModel.from_pretrained("username/mymodel@main") + """ + + def save_pretrained( + self, + save_directory: str, + config: Optional[dict] = None, + push_to_hub: bool = False, + **kwargs, + ): + """ + Saving weights in local directory. + + Parameters: + save_directory (:obj:`str`): + Specify directory in which you want to save weights. + config (:obj:`dict`, `optional`): + specify config (must be dict) incase you want to save it. + push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set it to `True` in case you want to push your weights to huggingface_hub + model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`): + Repo name in huggingface_hub. If not specified, repo name will be same as `save_directory` + kwargs (:obj:`Dict`, `optional`): + kwargs will be passed to `push_to_hub` + """ + + os.makedirs(save_directory, exist_ok=True) + + # saving config + if isinstance(config, dict): + path = os.path.join(save_directory, CONFIG_NAME) + with open(path, "w") as f: + json.dump(config, f) + + # saving model weights + path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME) + self._save_pretrained(path) + + if push_to_hub: + return self.push_to_hub(save_directory, **kwargs) + + def _save_pretrained(self, path): + """ + Overwrite this method in case you don't want to save complete model, rather some specific layers + """ + model_to_save = self.module if hasattr(self, "module") else self + torch.save(model_to_save.state_dict(), path) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[str], + strict: bool = True, + map_location: Optional[str] = "cpu", + force_download: bool = False, + resume_download: bool = False, + proxies: Dict = None, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, + local_files_only: bool = False, + **model_kwargs, + ): + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration from huggingface-hub. + The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated). To + train the model, you should first set it back in training mode with ``model.train()``. + + Parameters: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`): + Can be either: + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - You can add `revision` by appending `@` at the end of model_id simply like this: ``dbmdz/bert-base-german-cased@main`` + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id, + since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git. + - A path to a `directory` containing model weights saved using + :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword + arguments ``config`` and ``state_dict``). + cache_dir (:obj:`Union[str, os.PathLike]`, `optional`): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (:obj:`Dict[str, str], `optional`): + A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). + model_kwargs (:obj:`Dict`, `optional`):: + model_kwargs will be passed to the model during initialization + .. note:: + Passing :obj:`use_auth_token=True` is required when you want to use a private model. + """ + + model_id = pretrained_model_name_or_path + map_location = torch.device(map_location) + + revision = None + if len(model_id.split("@")) == 2: + model_id, revision = model_id.split("@") + + if model_id in os.listdir() and CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, CONFIG_NAME) + else: + try: + config_url = hf_hub_url( + model_id, filename=CONFIG_NAME, revision=revision + ) + config_file = cached_download( + config_url, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + except requests.exceptions.RequestException: + logger.warning("config.json NOT FOUND in HuggingFace Hub") + config_file = None + + if model_id in os.listdir(): + print("LOADING weights from local directory") + model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) + else: + model_url = hf_hub_url( + model_id, filename=PYTORCH_WEIGHTS_NAME, revision=revision + ) + model_file = cached_download( + model_url, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + + if config_file is not None: + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + model_kwargs.update({"config": config}) + + model = cls(**model_kwargs) + + state_dict = torch.load(model_file, map_location=map_location) + model.load_state_dict(state_dict, strict=strict) + model.eval() + + return model + + @staticmethod + def push_to_hub( + save_directory: Optional[str], + model_id: Optional[str] = None, + repo_url: Optional[str] = None, + commit_message: Optional[str] = "add model", + organization: Optional[str] = None, + private: bool = None, + ) -> str: + """ + Parameters: + save_directory (:obj:`Union[str, os.PathLike]`): + Directory having model weights & config. + model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`): + Repo name in huggingface_hub. If not specified, repo name will be same as `save_directory` + repo_url (:obj:`str`, `optional`): + Specify this in case you want to push to existing repo in hub. + organization (:obj:`str`, `optional`): + Organization in which you want to push your model. + private (:obj:`bool`, `optional`): + private: Whether the model repo should be private (requires a paid huggingface.co account) + commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`): + Message to commit while pushing + + Returns: + url to commit on remote repo. + """ + if model_id is None: + model_id = save_directory + + token = HfFolder.get_token() + if repo_url is None: + repo_url = HfApi().create_repo( + token, + model_id, + organization=organization, + private=private, + repo_type=None, + exist_ok=True, + ) + + repo = Repository(save_directory, clone_from=repo_url, use_auth_token=token) + + return repo.push_to_hub(commit_message=commit_message) diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py index 2eb7a42228..e1607d36b5 100644 --- a/src/huggingface_hub/repository.py +++ b/src/huggingface_hub/repository.py @@ -159,14 +159,26 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non encoding="utf-8", cwd=self.local_dir, ) - subprocess.run( - ["git", "remote", "add", "origin", repo_url], + + output = subprocess.run( + "git remote -v".split(), stderr=subprocess.PIPE, stdout=subprocess.PIPE, check=True, encoding="utf-8", cwd=self.local_dir, ) + + if "origin" not in output.stdout.split(): + subprocess.run( + ["git", "remote", "add", "origin", repo_url], + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + check=True, + encoding="utf-8", + cwd=self.local_dir, + ) + subprocess.run( "git fetch".split(), stderr=subprocess.PIPE, @@ -183,15 +195,27 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non check=True, cwd=self.local_dir, ) - # TODO(check if we really want the --force flag) - subprocess.run( - "git checkout origin/main -ft".split(), + + output = subprocess.run( + "git branch".split(), stderr=subprocess.PIPE, stdout=subprocess.PIPE, - encoding="utf-8", check=True, + encoding="utf-8", cwd=self.local_dir, ) + + if "main" not in output.stdout.split(): + # TODO(check if we really want the --force flag) + subprocess.run( + "git checkout origin/main -ft".split(), + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + encoding="utf-8", + check=True, + cwd=self.local_dir, + ) + except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py new file mode 100644 index 0000000000..23a3e5a249 --- /dev/null +++ b/tests/test_hubmixin.py @@ -0,0 +1,66 @@ +import unittest + +from huggingface_hub.file_download import is_torch_available +from huggingface_hub.hub_mixin import ModelHubMixin + + +if is_torch_available(): + import torch.nn as nn + + +HUGGINGFACE_ID = "vasudevgupta" +DUMMY_REPO_NAME = "dummy" + + +def require_torch(test_case): + """ + Decorator marking a test that requires PyTorch. + + These tests are skipped when PyTorch isn't installed. + + """ + if not is_torch_available(): + return unittest.skip("test requires PyTorch")(test_case) + else: + return test_case + + +@require_torch +class DummyModel(ModelHubMixin): + def __init__(self, **kwargs): + super().__init__() + self.config = kwargs.pop("config", None) + self.l1 = nn.Linear(2, 2) + + def forward(self, x): + return self.l1(x) + + +@require_torch +class DummyModelTest(unittest.TestCase): + def test_save_pretrained(self): + model = DummyModel() + model.save_pretrained(DUMMY_REPO_NAME) + model.save_pretrained( + DUMMY_REPO_NAME, config={"num": 12, "act": "gelu"}, push_to_hub=True + ) + model.save_pretrained( + DUMMY_REPO_NAME, config={"num": 24, "act": "relu"}, push_to_hub=True + ) + model.save_pretrained( + "dummy-wts", config=None, push_to_hub=True, model_id=DUMMY_REPO_NAME + ) + + def test_from_pretrained(self): + model = DummyModel() + model.save_pretrained( + DUMMY_REPO_NAME, config={"num": 7, "act": "gelu_fast"}, push_to_hub=True + ) + + model = DummyModel.from_pretrained(f"{HUGGINGFACE_ID}/{DUMMY_REPO_NAME}@main") + self.assertTrue(model.config == {"num": 7, "act": "gelu_fast"}) + + def test_push_to_hub(self): + model = DummyModel() + model.save_pretrained("dummy-wts", push_to_hub=False) + model.push_to_hub("dummy-wts", model_id=DUMMY_REPO_NAME)