-
Notifications
You must be signed in to change notification settings - Fork 559
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding mixin class for ease saving, uploading, downloading (as discus…
…sed in issue #9). (#11) * work initiated * start upload_to_hub * add changes * final-push * i feel this is better. * updated for Repository class * small updates * fix multiple calling * small fix * make style * add everything * minor fix * minor fix * done everything * small fix * [doc] remove mention of TF support * Fix typings (i think) * We do NOT want to have a hard requirement on torch * Fix flake8 * Fix CI Co-authored-by: Julien Chaumond <[email protected]>
- Loading branch information
1 parent
be902d8
commit 9e1d3d5
Showing
4 changed files
with
343 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
import json | ||
import logging | ||
import os | ||
from typing import Dict, Optional | ||
|
||
import requests | ||
|
||
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME | ||
from .file_download import cached_download, hf_hub_url, is_torch_available | ||
from .hf_api import HfApi, HfFolder | ||
from .repository import Repository | ||
|
||
|
||
if is_torch_available(): | ||
import torch | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ModelHubMixin(object):
    """
    Mix this class with your torch-model class for an easy process of saving
    & loading from the huggingface-hub.

    Example::

        >>> from huggingface_hub import ModelHubMixin

        >>> class MyModel(nn.Module, ModelHubMixin):
        ...     def __init__(self, **kwargs):
        ...         super().__init__()
        ...         self.config = kwargs.pop("config", None)
        ...         self.layer = ...
        ...     def forward(self, ...)
        ...         return ...

        >>> model = MyModel()
        >>> model.save_pretrained("mymodel", push_to_hub=False)  # Saving model weights in the directory
        >>> model.push_to_hub("mymodel", "model-1")  # Pushing model-weights to hf-hub
        >>> # Downloading weights from hf-hub & model will be initialized from those weights
        >>> model = MyModel.from_pretrained("username/mymodel@main")
    """

    def __init__(self, *args, **kwargs):
        # The mixin itself holds no state; *args/**kwargs are accepted (and
        # ignored) so it stays cooperative in multiple-inheritance chains.
        pass

    def save_pretrained(
        self,
        save_directory: str,
        config: Optional[dict] = None,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Saving weights in local directory.

        Parameters:
            save_directory (:obj:`str`):
                Specify directory in which you want to save weights.
            config (:obj:`dict`, `optional`):
                specify config (must be dict) incase you want to save it.
            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set it to `True` in case you want to push your weights to huggingface_hub
            model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`):
                Repo name in huggingface_hub. If not specified, repo name will be same as `save_directory`
            kwargs (:obj:`Dict`, `optional`):
                kwargs will be passed to `push_to_hub`
        """
        os.makedirs(save_directory, exist_ok=True)

        # saving config (only when a dict is explicitly provided)
        if isinstance(config, dict):
            path = os.path.join(save_directory, CONFIG_NAME)
            with open(path, "w") as f:
                json.dump(config, f)

        # saving model weights
        path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
        self._save_pretrained(path)

        if push_to_hub:
            return self.push_to_hub(save_directory, **kwargs)

    def _save_pretrained(self, path):
        """
        Overwrite this method in case you don't want to save complete model,
        rather some specific layers.
        """
        # DataParallel/DistributedDataParallel wrap the real model in `.module`.
        model_to_save = self.module if hasattr(self, "module") else self
        torch.save(model_to_save.state_dict(), path)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[str],
        strict: bool = True,
        map_location: Optional[str] = "cpu",
        force_download: bool = False,
        resume_download: bool = False,
        proxies: Optional[Dict] = None,
        use_auth_token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        local_files_only: bool = False,
        **model_kwargs,
    ):
        r"""
        Instantiate a pretrained pytorch model from a pre-trained model configuration from huggingface-hub.
        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated). To
        train the model, you should first set it back in training mode with ``model.train()``.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`):
                Can be either:
                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - You can add `revision` by appending `@` at the end of model_id simply like this: ``dbmdz/bert-base-german-cased@main``
                      Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id,
                      since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                      arguments ``config`` and ``state_dict``).
            strict (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Passed to ``load_state_dict``; whether to require an exact key match.
            map_location (:obj:`str`, `optional`, defaults to :obj:`"cpu"`):
                Device onto which the weights are loaded.
            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (:obj:`str` or `bool`, `optional`):
                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
            model_kwargs (:obj:`Dict`, `optional`):
                model_kwargs will be passed to the model during initialization

        .. note::
            Passing :obj:`use_auth_token=True` is required when you want to use a private model.
        """
        model_id = pretrained_model_name_or_path
        map_location = torch.device(map_location)

        # Optional "@revision" suffix on the model id (branch, tag or commit).
        revision = None
        parts = model_id.split("@")
        if len(parts) == 2:
            model_id, revision = parts

        # Prefer a local directory when one exists at the given path.
        # (os.path.isdir works for relative, "./"-prefixed and absolute paths,
        # unlike a membership test against os.listdir() of the CWD.)
        if os.path.isdir(model_id) and CONFIG_NAME in os.listdir(model_id):
            config_file = os.path.join(model_id, CONFIG_NAME)
        else:
            try:
                config_url = hf_hub_url(
                    model_id, filename=CONFIG_NAME, revision=revision
                )
                config_file = cached_download(
                    config_url,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                )
            except requests.exceptions.RequestException:
                # A missing config is not fatal: the model is then built
                # without a `config` kwarg.
                logger.warning("config.json NOT FOUND in HuggingFace Hub")
                config_file = None

        if os.path.isdir(model_id):
            logger.info("LOADING weights from local directory")
            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
        else:
            model_url = hf_hub_url(
                model_id, filename=PYTORCH_WEIGHTS_NAME, revision=revision
            )
            model_file = cached_download(
                model_url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )

        if config_file is not None:
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
            model_kwargs.update({"config": config})

        model = cls(**model_kwargs)

        state_dict = torch.load(model_file, map_location=map_location)
        model.load_state_dict(state_dict, strict=strict)
        # Evaluation mode by default (disables dropout etc.).
        model.eval()

        return model

    @staticmethod
    def push_to_hub(
        save_directory: Optional[str],
        model_id: Optional[str] = None,
        repo_url: Optional[str] = None,
        commit_message: Optional[str] = "add model",
        organization: Optional[str] = None,
        private: Optional[bool] = None,
    ) -> str:
        """
        Push the contents of ``save_directory`` to a repo on the Hub.

        Parameters:
            save_directory (:obj:`Union[str, os.PathLike]`):
                Directory having model weights & config.
            model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`):
                Repo name in huggingface_hub. If not specified, repo name will be same as `save_directory`
            repo_url (:obj:`str`, `optional`):
                Specify this in case you want to push to existing repo in hub.
            organization (:obj:`str`, `optional`):
                Organization in which you want to push your model.
            private (:obj:`bool`, `optional`):
                private: Whether the model repo should be private (requires a paid huggingface.co account)
            commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`):
                Message to commit while pushing

        Returns:
            url to commit on remote repo.
        """
        if model_id is None:
            model_id = save_directory

        # Token stored by `huggingface-cli login`.
        token = HfFolder.get_token()
        if repo_url is None:
            repo_url = HfApi().create_repo(
                token,
                model_id,
                organization=organization,
                private=private,
                repo_type=None,
                exist_ok=True,
            )

        repo = Repository(save_directory, clone_from=repo_url, use_auth_token=token)

        return repo.push_to_hub(commit_message=commit_message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import unittest | ||
|
||
from huggingface_hub.file_download import is_torch_available | ||
from huggingface_hub.hub_mixin import ModelHubMixin | ||
|
||
|
||
if is_torch_available(): | ||
import torch.nn as nn | ||
|
||
|
||
HUGGINGFACE_ID = "vasudevgupta" | ||
DUMMY_REPO_NAME = "dummy" | ||
|
||
|
||
def require_torch(test_case):
    """
    Decorator marking a test that requires PyTorch.

    The decorated test (or test class) is returned unchanged when PyTorch is
    installed, and wrapped in :func:`unittest.skip` otherwise.
    """
    if is_torch_available():
        return test_case
    return unittest.skip("test requires PyTorch")(test_case)
|
||
|
||
@require_torch
class DummyModel(nn.Module, ModelHubMixin):
    """
    Minimal torch model used to exercise ModelHubMixin save/load/push.

    It must inherit from ``nn.Module`` (as the mixin's docstring example
    shows): the mixin's save path calls ``self.state_dict()`` and its load
    path calls ``self.load_state_dict(...)``, both provided by ``nn.Module``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # `config` is optional; `from_pretrained` injects it via kwargs when
        # a config.json is found alongside the weights.
        self.config = kwargs.pop("config", None)
        self.l1 = nn.Linear(2, 2)

    def forward(self, x):
        return self.l1(x)
|
||
|
||
@require_torch
class DummyModelTest(unittest.TestCase):
    """Integration tests — these push/pull against the real huggingface.co Hub."""

    def test_save_pretrained(self):
        model = DummyModel()
        # Local save only (no config, no push).
        model.save_pretrained(DUMMY_REPO_NAME)
        # Save + push, twice with different configs (exercises repeated pushes).
        model.save_pretrained(
            DUMMY_REPO_NAME, config={"num": 12, "act": "gelu"}, push_to_hub=True
        )
        model.save_pretrained(
            DUMMY_REPO_NAME, config={"num": 24, "act": "relu"}, push_to_hub=True
        )
        # Save in one directory but push under a different repo name.
        model.save_pretrained(
            "dummy-wts", config=None, push_to_hub=True, model_id=DUMMY_REPO_NAME
        )

    def test_from_pretrained(self):
        model = DummyModel()
        model.save_pretrained(
            DUMMY_REPO_NAME, config={"num": 7, "act": "gelu_fast"}, push_to_hub=True
        )

        # Round-trip: the config pushed above must come back via the Hub.
        model = DummyModel.from_pretrained(f"{HUGGINGFACE_ID}/{DUMMY_REPO_NAME}@main")
        # assertEqual (not assertTrue on ==) so a mismatch prints the diff.
        self.assertEqual(model.config, {"num": 7, "act": "gelu_fast"})

    def test_push_to_hub(self):
        model = DummyModel()
        model.save_pretrained("dummy-wts", push_to_hub=False)
        model.push_to_hub("dummy-wts", model_id=DUMMY_REPO_NAME)