From 0b0248f164f7ed08450a30cde4dac776c7770f1a Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Mon, 22 Apr 2024 23:57:39 +0800 Subject: [PATCH] support delete_existing in convert_checkpoint Signed-off-by: Zhiyuan Chen --- multimolecule/models/rnabert/convert_checkpoint.py | 13 ++++++++++++- multimolecule/models/rnafm/convert_checkpoint.py | 13 ++++++++++++- multimolecule/models/rnamsm/convert_checkpoint.py | 13 ++++++++++++- .../models/splicebert/convert_checkpoint.py | 9 ++++++++- multimolecule/models/utrbert/convert_checkpoint.py | 9 ++++++++- multimolecule/models/utrlm/convert_checkpoint.py | 13 ++++++++++++- 6 files changed, 64 insertions(+), 6 deletions(-) diff --git a/multimolecule/models/rnabert/convert_checkpoint.py b/multimolecule/models/rnabert/convert_checkpoint.py index 1783a5ce..2d3bee5e 100644 --- a/multimolecule/models/rnabert/convert_checkpoint.py +++ b/multimolecule/models/rnabert/convert_checkpoint.py @@ -76,6 +76,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -91,9 +97,14 @@ class ConvertConfig: checkpoint_path: str output_path: str = Config.model_type push_to_hub: bool = False - repo_id: str = f"multimolecule/{output_path}" + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None + def post(self): + if self.repo_id is None: + self.repo_id = f"multimolecule/{self.output_path}" + if __name__ == "__main__": config = ConvertConfig() diff --git a/multimolecule/models/rnafm/convert_checkpoint.py b/multimolecule/models/rnafm/convert_checkpoint.py index 25499c57..083c11f6 100644 --- a/multimolecule/models/rnafm/convert_checkpoint.py +++ b/multimolecule/models/rnafm/convert_checkpoint.py @@ -113,6 +113,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -128,9 +134,14 @@ class ConvertConfig: checkpoint_path: str output_path: str = Config.model_type push_to_hub: bool = False - repo_id: str = f"multimolecule/{output_path}" + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None + def post(self): + if self.repo_id is None: + self.repo_id = f"multimolecule/{self.output_path}" + if __name__ == "__main__": config = ConvertConfig() diff --git a/multimolecule/models/rnamsm/convert_checkpoint.py b/multimolecule/models/rnamsm/convert_checkpoint.py index 0d181763..b328f4a1 100644 --- a/multimolecule/models/rnamsm/convert_checkpoint.py +++ b/multimolecule/models/rnamsm/convert_checkpoint.py @@ -76,6 +76,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -91,9 +97,14 @@ class ConvertConfig: checkpoint_path: str output_path: str = Config.model_type push_to_hub: bool = False - repo_id: str = f"multimolecule/{output_path}" + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None + def post(self): + if self.repo_id is None: + self.repo_id = f"multimolecule/{self.output_path}" + if __name__ == "__main__": config = ConvertConfig() diff --git a/multimolecule/models/splicebert/convert_checkpoint.py b/multimolecule/models/splicebert/convert_checkpoint.py index ae0d8e3b..88b1d4bd 100644 --- a/multimolecule/models/splicebert/convert_checkpoint.py +++ b/multimolecule/models/splicebert/convert_checkpoint.py @@ -89,6 +89,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -104,7 +110,8 @@ class ConvertConfig: checkpoint_path: str output_path: Optional[str] = None push_to_hub: bool = False - repo_id: Optional[str] = output_path + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None def post(self): diff --git a/multimolecule/models/utrbert/convert_checkpoint.py b/multimolecule/models/utrbert/convert_checkpoint.py index 0e5d451e..7a0ad68b 100644 --- a/multimolecule/models/utrbert/convert_checkpoint.py +++ b/multimolecule/models/utrbert/convert_checkpoint.py @@ -92,6 +92,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -107,7 +113,8 @@ class ConvertConfig: checkpoint_path: str output_path: Optional[str] = None push_to_hub: bool = False - repo_id: Optional[str] = output_path + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None def post(self): diff --git a/multimolecule/models/utrlm/convert_checkpoint.py b/multimolecule/models/utrlm/convert_checkpoint.py index f46ee645..1c2ceb93 100644 --- a/multimolecule/models/utrlm/convert_checkpoint.py +++ b/multimolecule/models/utrlm/convert_checkpoint.py @@ -95,6 +95,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -110,9 +116,14 @@ class ConvertConfig: checkpoint_path: str output_path: str = Config.model_type push_to_hub: bool = False - repo_id: str = f"multimolecule/{output_path}" + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None + def post(self): + if self.repo_id is None: + self.repo_id = f"multimolecule/{self.output_path}" + if __name__ == "__main__": config = ConvertConfig()