From a07debbb4b2404743ff2bd8f99dbe76b04172656 Mon Sep 17 00:00:00 2001
From: Sean Lee
Date: Sat, 1 Jun 2024 11:27:58 +0800
Subject: [PATCH] new function save_pretrained(); check repo exists when
 pushing to hub

---
 angle_emb/angle.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/angle_emb/angle.py b/angle_emb/angle.py
index 2c7e3a2..6be0462 100644
--- a/angle_emb/angle.py
+++ b/angle_emb/angle.py
@@ -27,6 +27,7 @@
 )
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy
+from huggingface_hub import repo_exists
 from peft import (
     get_peft_model, LoraConfig, TaskType, PeftModel,
     prepare_model_for_kbit_training,
@@ -1636,12 +1637,27 @@ def encode(self,
             return output.float().detach().cpu().numpy()
         return output

-    def push_to_hub(self, hub_model_id: str, private: bool = True, **kwargs):
+    def push_to_hub(self, hub_model_id: str, private: bool = True, exist_ok: bool = False, **kwargs):
         """ push model to hub

         :param hub_model_id: str, hub model id.
         :param private: bool, whether push to private repo. Default True.
+        :param exist_ok: bool, whether to allow overwriting an existing repo. Default False.
         :param kwargs: other kwargs for `push_to_hub` method.
         """
+        if not exist_ok and repo_exists(hub_model_id):
+            raise ValueError(f"Model {hub_model_id} already exists on the hub. Set `exist_ok=True` to overwrite.")
         self.tokenizer.push_to_hub(hub_model_id, private=private, **kwargs)
         self.backbone.push_to_hub(hub_model_id, private=private, **kwargs)
+
+    def save_pretrained(self, output_dir: str, exist_ok: bool = True):
+        """ save model and tokenizer to a local directory
+
+        :param output_dir: str, output dir.
+        :param exist_ok: bool, whether to allow overwriting an existing directory. Default True.
+        """
+        if not exist_ok and os.path.exists(output_dir):
+            raise ValueError(f"Output directory ({output_dir}) already exists. Set `exist_ok=True` to overwrite.")
+        os.makedirs(output_dir, exist_ok=True)
+        self.tokenizer.save_pretrained(output_dir)
+        self.backbone.save_pretrained(output_dir)
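
For reference, a minimal usage sketch of the two methods this patch touches. The loading call follows angle_emb's documented `AnglE.from_pretrained` entry point; the output directory and hub repo id are hypothetical placeholders, and pushing assumes Hugging Face credentials are already configured (e.g. via `huggingface-cli login`):

    from angle_emb import AnglE

    # Load a pretrained AnglE model (model id shown is illustrative).
    angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls')

    # save_pretrained(): writes the tokenizer and backbone weights locally.
    # exist_ok defaults to True, so an existing directory is reused.
    angle.save_pretrained('./my-angle-model')

    # push_to_hub(): uploads the tokenizer and backbone to the Hub.
    # With the new check, this raises ValueError if the repo already
    # exists, unless exist_ok=True is passed ('my-org/my-angle-model'
    # is a placeholder repo id).
    angle.push_to_hub('my-org/my-angle-model', private=True, exist_ok=False)

Note that `repo_exists` is queried before either upload starts, so a failed overwrite check leaves an existing repo untouched rather than partially overwritten.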