Skip to content

Commit

Permalink
new function save_pretrained(); check repo exist when pushing to hub
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed Jun 1, 2024
1 parent d049fd4 commit a07debb
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from huggingface_hub import repo_exists
from peft import (
get_peft_model, LoraConfig, TaskType, PeftModel,
prepare_model_for_kbit_training,
Expand Down Expand Up @@ -1636,12 +1637,27 @@ def encode(self,
return output.float().detach().cpu().numpy()
return output

def push_to_hub(self, hub_model_id: str, private: bool = True, **kwargs):
def push_to_hub(self, hub_model_id: str, private: bool = True, exist_ok: bool = False, **kwargs):
""" push model to hub
:param hub_model_id: str, hub model id.
:param private: bool, whether push to private repo. Default True.
:param exist_ok: bool, whether allow overwrite. Default False.
:param kwargs: other kwargs for `push_to_hub` method.
"""
if not exist_ok and repo_exists(hub_model_id):
raise ValueError(f"Model {hub_model_id} already exists on the hub. Set `exist_ok=True` to overwrite.")
self.tokenizer.push_to_hub(hub_model_id, private=private, **kwargs)
self.backbone.push_to_hub(hub_model_id, private=private, **kwargs)

def save_pretrained(self, output_dir: str, exist_ok: bool = True):
""" save model and tokenizer
:param output_dir: str, output dir.
:param exist_ok: bool, whether allow overwrite. Default True.
"""
if not exist_ok and os.path.exists(output_dir):
raise ValueError(f"Output directory ({output_dir}) already exists and is not empty.")
os.makedirs(output_dir)
self.tokenizer.save_pretrained(output_dir)
self.backbone.save_pretrained(output_dir)

0 comments on commit a07debb

Please sign in to comment.