Commit

reorganise rnabert
ZhiyuanChen committed Apr 2, 2024
1 parent 18503fa commit 186eed0
Showing 6 changed files with 458 additions and 269 deletions.
18 changes: 16 additions & 2 deletions multimolecule/models/__init__.py
@@ -1,3 +1,17 @@
-from .rnabert import RnaBertConfig, RnaBertModel, RnaTokenizer
+from .rnabert import (
+    RnaBertConfig,
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+    RnaTokenizer,
+)
 
-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+    "RnaTokenizer",
+]
30 changes: 27 additions & 3 deletions multimolecule/models/rnabert/__init__.py
@@ -1,12 +1,36 @@
-from transformers import AutoConfig, AutoModel, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForMaskedLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoModelWithLMHead,
+    AutoTokenizer,
+)
 
 from multimolecule.tokenizers.rna import RnaTokenizer
 
 from .configuration_rnabert import RnaBertConfig
-from .modeling_rnabert import RnaBertModel
+from .modeling_rnabert import (
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+)
 
-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaTokenizer",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+]
 
 AutoConfig.register("rnabert", RnaBertConfig)
 AutoModel.register(RnaBertConfig, RnaBertModel)
+AutoModelForMaskedLM.register(RnaBertConfig, RnaBertForMaskedLM)
+AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification)
+AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification)
+AutoModelWithLMHead.register(RnaBertConfig, RnaBertForMaskedLM)
 AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
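
These registrations are what let the generic `Auto*` factories resolve the `rnabert` model type. A minimal sketch of the effect, assuming only that the package is importable (configs here are default-initialised; no pretrained weights are involved):

```python
# Sketch: once multimolecule.models is imported, the Auto* factories
# dispatch the "rnabert" model type to the RnaBert classes registered above.
from transformers import AutoModel, AutoModelForMaskedLM

from multimolecule.models import RnaBertConfig

config = RnaBertConfig()  # default, randomly initialised configuration

model = AutoModel.from_config(config)
assert type(model).__name__ == "RnaBertModel"

mlm = AutoModelForMaskedLM.from_config(config)
assert type(mlm).__name__ == "RnaBertForMaskedLM"
```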
9 changes: 7 additions & 2 deletions multimolecule/models/rnabert/configuration_rnabert.py
@@ -9,7 +9,7 @@ class RnaBertConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a
     RnaBert model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of the RnaBert
-    [mana438/RNABERT](https://github.com/mana438/RNABERT/blob/master/RNA_bert_config.json) architecture.
+    [mana438/RNABERT](https://github.com/mana438/RNABERT) architecture.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -55,7 +55,8 @@ class RnaBertConfig(PretrainedConfig):
     >>> # Initializing a model from the configuration
     >>> model = RnaBertModel(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
-    ```"""
+    ```
+    """
 
     model_type = "rnabert"

@@ -77,6 +78,8 @@ def __init__(
         pad_token_id=0,
         position_embedding_type="absolute",
         use_cache=True,
+        classifier_dropout=None,
+        proj_head_mode="nonlinear",
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -97,3 +100,5 @@ def __init__(
         self.layer_norm_eps = layer_norm_eps
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
+        self.classifier_dropout = classifier_dropout
+        self.proj_head_mode = proj_head_mode
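
The two new fields behave like any other configuration attribute; `proj_head_mode` presumably selects among projection-head variants for the new prediction heads (the consuming code is not part of this diff). A minimal sketch:

```python
# Sketch: the fields added above are plain config attributes.
from multimolecule.models import RnaBertConfig

config = RnaBertConfig(classifier_dropout=0.1, proj_head_mode="nonlinear")
print(config.classifier_dropout)  # 0.1
print(config.proj_head_mode)      # nonlinear
```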
92 changes: 68 additions & 24 deletions multimolecule/models/rnabert/convert_checkpoint.py
@@ -1,14 +1,20 @@
 import os
-import sys
+from typing import Optional
 
 import chanfig
 import torch
 from torch import nn
 
-from multimolecule.models import RnaBertConfig, RnaBertModel
+from multimolecule.models import RnaBertConfig as Config
+from multimolecule.models import RnaBertForMaskedLM as Model
 from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
 
+try:
+    from huggingface_hub import HfApi
+except ImportError:
+    HfApi = None
+
 
 CONFIG = {
     "architectures": ["RnaBertModel"],
     "attention_probs_dropout_prob": 0.0,
@@ -19,7 +25,6 @@
     "max_position_embeddings": 440,
     "num_attention_heads": 12,
     "num_hidden_layers": 6,
-    "vocab_size": 25,
     "ss_vocab_size": 8,
     "type_vocab_size": 2,
     "pad_token_id": 0,
@@ -29,38 +34,77 @@
 vocab_list = get_vocab_list()
 
 
-def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None):
-    if output_path is None:
-        output_path = "rnabert"
-    config = RnaBertConfig.from_dict(chanfig.FlatDict(CONFIG))
-    ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu"))
-    bert_state_dict = ckpt
+def _convert_checkpoint(config, original_state_dict):
     state_dict = {}
 
-    model = RnaBertModel(config)
-
-    for key, value in bert_state_dict.items():
-        if key.startswith("module.cls"):
-            continue
-        key = key[12:]
+    for key, value in original_state_dict.items():
+        key = key[7:]
         key = key.replace("LayerNorm", "layer_norm")
         key = key.replace("gamma", "weight")
         key = key.replace("beta", "bias")
-        state_dict[key] = value
+        if key.startswith("bert"):
+            state_dict["rna" + key] = value
+            continue
+        if key.startswith("cls"):
+            key = "lm_head." + key[4:]
+            state_dict[key] = value
+            continue
 
     word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+    word_embed_weight = word_embed.weight.data
+    predictions_bias = torch.zeros(config.vocab_size)
+    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
     # nn.init.normal_(pos_embed.weight, std=0.02)
     for original_token, new_token in zip(original_vocab_list, vocab_list):
         original_index = original_vocab_list.index(original_token)
         new_index = vocab_list.index(new_token)
-        word_embed.weight.data[new_index] = state_dict["embeddings.word_embeddings.weight"][original_index]
-    state_dict["embeddings.word_embeddings.weight"] = word_embed.weight.data
+        word_embed_weight[new_index] = state_dict["rnabert.embeddings.word_embeddings.weight"][original_index]
+        predictions_bias[new_index] = state_dict["lm_head.predictions.bias"][original_index]
+        predictions_decoder_weight[new_index] = state_dict["lm_head.predictions.decoder.weight"][original_index]
+    state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight
+    state_dict["lm_head.predictions.bias"] = predictions_bias
+    state_dict["lm_head.predictions.decoder.weight"] = predictions_decoder_weight
+    return state_dict
+
+
+def convert_checkpoint(convert_config):
+    config = Config.from_dict(chanfig.FlatDict(CONFIG))
+    config.vocab_size = len(vocab_list)
+
+    model = Model(config)
+
+    ckpt = torch.load(convert_config.checkpoint_path, map_location=torch.device("cpu"))
+    state_dict = _convert_checkpoint(config, ckpt)
 
     model.load_state_dict(state_dict)
-    model.save_pretrained(output_path, safe_serialization=True)
-    model.save_pretrained(output_path, safe_serialization=False)
-    chanfig.NestedDict(get_special_tokens_map()).json(os.path.join(output_path, "special_tokens_map.json"))
-    chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(output_path, "tokenizer_config.json"))
+    model.save_pretrained(convert_config.output_path, safe_serialization=True)
+    model.save_pretrained(convert_config.output_path, safe_serialization=False)
+    chanfig.NestedDict(get_special_tokens_map()).json(
+        os.path.join(convert_config.output_path, "special_tokens_map.json")
+    )
+    chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(convert_config.output_path, "tokenizer_config.json"))
+
+    if convert_config.push_to_hub:
+        if HfApi is None:
+            raise ImportError("Please install huggingface_hub to push to the hub.")
+        api = HfApi()
+        api.create_repo(
+            convert_config.repo_id,
+            token=convert_config.token,
+            exist_ok=True,
+        )
+        api.upload_folder(repo_id=convert_config.repo_id, path=convert_config.output_path, token=convert_config.token)
+
+
+@chanfig.configclass
+class ConvertConfig:
+    checkpoint_path: str
+    output_path: str = Config.model_type
+    push_to_hub: bool = False
+    repo_id: str = "ZhiyuanChen/" + output_path
+    token: Optional[str] = None
 
 
 if __name__ == "__main__":
-    convert_checkpoint(sys.argv[1], sys.argv[2] if len(sys.argv) > 2 else None)
+    config = ConvertConfig()
+    config.parse()
+    convert_checkpoint(config)
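
The core of the rewrite is the key translation in `_convert_checkpoint`: the `module.` prefix (as left by `nn.DataParallel`) is stripped, pre-1.0 PyTorch parameter names (`gamma`/`beta`, `LayerNorm`) are modernised, `bert.*` keys are renamed to `rnabert.*`, and `cls.*` keys move under `lm_head.*`. Note also that `vocab_size` disappeared from the hard-coded CONFIG because it is now derived from the new vocabulary via `config.vocab_size = len(vocab_list)`. A standalone sketch of just the renaming rules, applied to toy key names:

```python
# Sketch of the key-renaming rules in _convert_checkpoint, on toy keys.
toy_keys = [
    "module.bert.encoder.layer.0.attention.output.LayerNorm.gamma",
    "module.cls.predictions.decoder.weight",
]

for key in toy_keys:
    key = key[7:]  # strip the "module." prefix left by nn.DataParallel
    key = key.replace("LayerNorm", "layer_norm").replace("gamma", "weight").replace("beta", "bias")
    if key.startswith("bert"):
        key = "rna" + key           # bert.* -> rnabert.*
    elif key.startswith("cls"):
        key = "lm_head." + key[4:]  # cls.*  -> lm_head.*
    print(key)

# rnabert.encoder.layer.0.attention.output.layer_norm.weight
# lm_head.predictions.decoder.weight
```

With the `chanfig` config class, the script is now driven by named flags rather than positional arguments, e.g. something like `python convert_checkpoint.py --checkpoint_path ./bert_mul.pth --push_to_hub` (the checkpoint filename here is illustrative).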
(The diffs for the remaining two changed files did not load.)
