diff --git a/multimolecule/models/rnabert/convert_checkpoint.py b/multimolecule/models/rnabert/convert_checkpoint.py index 1783a5ce..2d3bee5e 100644 --- a/multimolecule/models/rnabert/convert_checkpoint.py +++ b/multimolecule/models/rnabert/convert_checkpoint.py @@ -76,6 +76,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -91,9 +97,14 @@ class ConvertConfig: checkpoint_path: str output_path: str = Config.model_type push_to_hub: bool = False - repo_id: str = f"multimolecule/{output_path}" + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None + def post(self): + if self.repo_id is None: + self.repo_id = f"multimolecule/{self.output_path}" + if __name__ == "__main__": config = ConvertConfig() diff --git a/multimolecule/models/rnafm/convert_checkpoint.py b/multimolecule/models/rnafm/convert_checkpoint.py index 25499c57..083c11f6 100644 --- a/multimolecule/models/rnafm/convert_checkpoint.py +++ b/multimolecule/models/rnafm/convert_checkpoint.py @@ -113,6 +113,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -128,9 +134,14 @@ class ConvertConfig: checkpoint_path: str output_path: str = Config.model_type push_to_hub: bool = False - repo_id: str = f"multimolecule/{output_path}" + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None + def post(self): + if self.repo_id is None: + self.repo_id = f"multimolecule/{self.output_path}" + if __name__ == "__main__": config = ConvertConfig() diff --git a/multimolecule/models/rnamsm/convert_checkpoint.py b/multimolecule/models/rnamsm/convert_checkpoint.py index 0d181763..b328f4a1 100644 --- a/multimolecule/models/rnamsm/convert_checkpoint.py +++ b/multimolecule/models/rnamsm/convert_checkpoint.py @@ -76,6 +76,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -91,9 +97,14 @@ class ConvertConfig: checkpoint_path: str output_path: str = Config.model_type push_to_hub: bool = False - repo_id: str = f"multimolecule/{output_path}" + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None + def post(self): + if self.repo_id is None: + self.repo_id = f"multimolecule/{self.output_path}" + if __name__ == "__main__": config = ConvertConfig() diff --git a/multimolecule/models/splicebert/convert_checkpoint.py b/multimolecule/models/splicebert/convert_checkpoint.py index ae0d8e3b..88b1d4bd 100644 --- a/multimolecule/models/splicebert/convert_checkpoint.py +++ b/multimolecule/models/splicebert/convert_checkpoint.py @@ -89,6 +89,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -104,7 +110,8 @@ class ConvertConfig: checkpoint_path: str output_path: Optional[str] = None push_to_hub: bool = False - repo_id: Optional[str] = output_path + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None def post(self): diff --git a/multimolecule/models/utrbert/convert_checkpoint.py b/multimolecule/models/utrbert/convert_checkpoint.py index 0e5d451e..7a0ad68b 100644 --- a/multimolecule/models/utrbert/convert_checkpoint.py +++ b/multimolecule/models/utrbert/convert_checkpoint.py @@ -92,6 +92,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -107,7 +113,8 @@ class ConvertConfig: checkpoint_path: str output_path: Optional[str] = None push_to_hub: bool = False - repo_id: Optional[str] = output_path + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None def post(self): diff --git a/multimolecule/models/utrlm/convert_checkpoint.py b/multimolecule/models/utrlm/convert_checkpoint.py index f46ee645..1c2ceb93 100644 --- a/multimolecule/models/utrlm/convert_checkpoint.py +++ b/multimolecule/models/utrlm/convert_checkpoint.py @@ -95,6 +95,12 @@ def convert_checkpoint(convert_config): if HfApi is None: raise ImportError("Please install huggingface_hub to push to the hub.") api = HfApi() + if convert_config.delete_existing: + api.delete_repo( + convert_config.repo_id, + token=convert_config.token, + missing_ok=True, + ) api.create_repo( convert_config.repo_id, token=convert_config.token, @@ -110,9 +116,14 @@ class ConvertConfig: checkpoint_path: str output_path: str = Config.model_type push_to_hub: bool = False - repo_id: str = f"multimolecule/{output_path}" + delete_existing: bool = False + repo_id: Optional[str] = None token: Optional[str] = None + def post(self): + if self.repo_id is None: + self.repo_id = f"multimolecule/{self.output_path}" + if __name__ == "__main__": config = ConvertConfig()