Skip to content

Commit

Permalink
support delete_existing in convert_checkpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Apr 22, 2024
1 parent 8200437 commit 1d6f318
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 6 deletions.
13 changes: 12 additions & 1 deletion multimolecule/models/rnabert/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def convert_checkpoint(convert_config):
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
if convert_config.delete_existing:
api.delete_repo(
convert_config.repo_id,
token=convert_config.token,
missing_ok=True,
)
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
Expand All @@ -91,9 +97,14 @@ class ConvertConfig:
checkpoint_path: str
output_path: str = Config.model_type
push_to_hub: bool = False
repo_id: str = f"multimolecule/{output_path}"
delete_existing: bool = False
repo_id: Optional[str] = None
token: Optional[str] = None

def post(self):
if self.repo_id is None:
self.repo_id = f"multimolecule/{self.output_path}"


if __name__ == "__main__":
config = ConvertConfig()
Expand Down
13 changes: 12 additions & 1 deletion multimolecule/models/rnafm/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def convert_checkpoint(convert_config):
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
if convert_config.delete_existing:
api.delete_repo(
convert_config.repo_id,
token=convert_config.token,
missing_ok=True,
)
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
Expand All @@ -128,9 +134,14 @@ class ConvertConfig:
checkpoint_path: str
output_path: str = Config.model_type
push_to_hub: bool = False
repo_id: str = f"multimolecule/{output_path}"
delete_existing: bool = False
repo_id: Optional[str] = None
token: Optional[str] = None

def post(self):
if self.repo_id is None:
self.repo_id = f"multimolecule/{self.output_path}"


if __name__ == "__main__":
config = ConvertConfig()
Expand Down
13 changes: 12 additions & 1 deletion multimolecule/models/rnamsm/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def convert_checkpoint(convert_config):
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
if convert_config.delete_existing:
api.delete_repo(
convert_config.repo_id,
token=convert_config.token,
missing_ok=True,
)
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
Expand All @@ -91,9 +97,14 @@ class ConvertConfig:
checkpoint_path: str
output_path: str = Config.model_type
push_to_hub: bool = False
repo_id: str = f"multimolecule/{output_path}"
delete_existing: bool = False
repo_id: Optional[str] = None
token: Optional[str] = None

def post(self):
if self.repo_id is None:
self.repo_id = f"multimolecule/{self.output_path}"


if __name__ == "__main__":
config = ConvertConfig()
Expand Down
9 changes: 8 additions & 1 deletion multimolecule/models/splicebert/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ def convert_checkpoint(convert_config):
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
if convert_config.delete_existing:
api.delete_repo(
convert_config.repo_id,
token=convert_config.token,
missing_ok=True,
)
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
Expand All @@ -104,7 +110,8 @@ class ConvertConfig:
checkpoint_path: str
output_path: Optional[str] = None
push_to_hub: bool = False
repo_id: Optional[str] = output_path
delete_existing: bool = False
repo_id: Optional[str] = None
token: Optional[str] = None

def post(self):
Expand Down
9 changes: 8 additions & 1 deletion multimolecule/models/utrbert/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def convert_checkpoint(convert_config):
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
if convert_config.delete_existing:
api.delete_repo(
convert_config.repo_id,
token=convert_config.token,
missing_ok=True,
)
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
Expand All @@ -107,7 +113,8 @@ class ConvertConfig:
checkpoint_path: str
output_path: Optional[str] = None
push_to_hub: bool = False
repo_id: Optional[str] = output_path
delete_existing: bool = False
repo_id: Optional[str] = None
token: Optional[str] = None

def post(self):
Expand Down
13 changes: 12 additions & 1 deletion multimolecule/models/utrlm/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def convert_checkpoint(convert_config):
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
if convert_config.delete_existing:
api.delete_repo(
convert_config.repo_id,
token=convert_config.token,
missing_ok=True,
)
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
Expand All @@ -110,9 +116,14 @@ class ConvertConfig:
checkpoint_path: str
output_path: str = Config.model_type
push_to_hub: bool = False
repo_id: str = f"multimolecule/{output_path}"
delete_existing: bool = False
repo_id: Optional[str] = None
token: Optional[str] = None

def post(self):
if self.repo_id is None:
self.repo_id = f"multimolecule/{self.output_path}"


if __name__ == "__main__":
config = ConvertConfig()
Expand Down

0 comments on commit 1d6f318

Please sign in to comment.