Commit 531782c: reorganise rnabert

ZhiyuanChen committed Apr 2, 2024 · 1 parent 18503fa

Showing 6 changed files with 452 additions and 265 deletions.
18 changes: 16 additions & 2 deletions multimolecule/models/__init__.py

@@ -1,3 +1,17 @@
-from .rnabert import RnaBertConfig, RnaBertModel, RnaTokenizer
+from .rnabert import (
+    RnaBertConfig,
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+    RnaTokenizer,
+)
 
-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+    "RnaTokenizer",
+]
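
With this reorganisation, the task-specific heads become part of the package's public API. A minimal sketch of the new import surface (assuming the package is installed as `multimolecule`; the models here are randomly initialised, not pretrained):

```python
# Sketch only: assumes multimolecule is installed and exports the names
# added by this commit.
from multimolecule.models import (
    RnaBertConfig,
    RnaBertForMaskedLM,
    RnaTokenizer,
)

config = RnaBertConfig()            # default RNABERT-style configuration
model = RnaBertForMaskedLM(config)  # randomly initialised model with LM head
```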
30 changes: 27 additions & 3 deletions multimolecule/models/rnabert/__init__.py

@@ -1,12 +1,36 @@
-from transformers import AutoConfig, AutoModel, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForMaskedLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoModelWithLMHead,
+    AutoTokenizer,
+)
 
 from multimolecule.tokenizers.rna import RnaTokenizer
 
 from .configuration_rnabert import RnaBertConfig
-from .modeling_rnabert import RnaBertModel
+from .modeling_rnabert import (
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+)
 
-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaTokenizer",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+]
 
 AutoConfig.register("rnabert", RnaBertConfig)
 AutoModel.register(RnaBertConfig, RnaBertModel)
+AutoModelForMaskedLM.register(RnaBertConfig, RnaBertForMaskedLM)
+AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification)
+AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification)
+AutoModelWithLMHead.register(RnaBertConfig, RnaBertForMaskedLM)
 AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
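
Once the registrations above have run, the stock transformers Auto classes can resolve the `rnabert` model type. A sketch under that assumption; the checkpoint id is hypothetical, modelled on the converter's default repo naming:

```python
# Sketch only: importing the models subpackage executes the registrations above;
# "ZhiyuanChen/rnabert" is a hypothetical repo id.
import multimolecule.models  # noqa: F401
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ZhiyuanChen/rnabert")
model = AutoModelForMaskedLM.from_pretrained("ZhiyuanChen/rnabert")

inputs = tokenizer("UAGCUUAUCAGACUGAUGUUGA", return_tensors="pt")
outputs = model(**inputs)  # MaskedLMOutput with per-token vocabulary logits
```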
9 changes: 7 additions & 2 deletions multimolecule/models/rnabert/configuration_rnabert.py

@@ -9,7 +9,7 @@ class RnaBertConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a
     RnaBert model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of the RnaBert
-    [mana438/RNABERT](https://github.com/mana438/RNABERT/blob/master/RNA_bert_config.json) architecture.
+    [mana438/RNABERT](https://github.com/mana438/RNABERT) architecture.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -55,7 +55,8 @@ class RnaBertConfig(PretrainedConfig):
     >>> # Initializing a model from the configuration
     >>> model = RnaBertModel(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
-    ```"""
+    ```
+    """
 
     model_type = "rnabert"
@@ -77,6 +78,8 @@ def __init__(
         pad_token_id=0,
         position_embedding_type="absolute",
         use_cache=True,
+        classifier_dropout=None,
+        proj_head_mode="nonlinear",
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -97,3 +100,5 @@ def __init__(
         self.layer_norm_eps = layer_norm_eps
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
+        self.classifier_dropout = classifier_dropout
+        self.proj_head_mode = proj_head_mode
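
The two new fields are consumed by the classification heads introduced in this commit. A brief sketch of constructing a config with them (the values are illustrative; `num_labels` is a generic `PretrainedConfig` kwarg, not a field added here):

```python
from multimolecule.models import RnaBertConfig, RnaBertForSequenceClassification

# classifier_dropout and proj_head_mode are the fields added in this commit;
# everything else falls back to the RNABERT-style defaults.
config = RnaBertConfig(
    classifier_dropout=0.1,      # dropout applied in the classification head
    proj_head_mode="nonlinear",  # default value, per the diff above
    num_labels=2,                # standard PretrainedConfig field via **kwargs
)
model = RnaBertForSequenceClassification(config)
```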
82 changes: 62 additions & 20 deletions multimolecule/models/rnabert/convert_checkpoint.py

@@ -6,9 +6,15 @@
 import torch
 from torch import nn
 
-from multimolecule.models import RnaBertConfig, RnaBertModel
+from multimolecule.models import RnaBertConfig, RnaBertForMaskedLM
 from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list
 
+try:
+    from huggingface_hub import HfApi
+except ImportError:
+    HfApi = None
+
+
 CONFIG = {
     "architectures": ["RnaBertModel"],
     "attention_probs_dropout_prob": 0.0,
@@ -19,7 +25,6 @@
     "max_position_embeddings": 440,
     "num_attention_heads": 12,
     "num_hidden_layers": 6,
-    "vocab_size": 25,
     "ss_vocab_size": 8,
     "type_vocab_size": 2,
     "pad_token_id": 0,
@@ -29,38 +34,75 @@
 vocab_list = get_vocab_list()
 
 
-def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None):
-    if output_path is None:
-        output_path = "rnabert"
+def convert_checkpoint(convert_config):
     config = RnaBertConfig.from_dict(chanfig.FlatDict(CONFIG))
-    ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu"))
-    bert_state_dict = ckpt
+    config.vocab_size = len(vocab_list)
+    ckpt = torch.load(convert_config.checkpoint_path, map_location=torch.device("cpu"))
+    original_state_dict = ckpt
     state_dict = {}
 
-    model = RnaBertModel(config)
+    model = RnaBertForMaskedLM(config)
 
-    for key, value in bert_state_dict.items():
-        if key.startswith("module.cls"):
-            continue
-        key = key[12:]
+    for key, value in original_state_dict.items():
+        key = key[7:]
         key = key.replace("LayerNorm", "layer_norm")
         key = key.replace("gamma", "weight")
         key = key.replace("beta", "bias")
-        state_dict[key] = value
+        if key.startswith("bert"):
+            state_dict["rna" + key] = value
+            continue
+        if key.startswith("cls"):
+            key = "lm_head." + key[4:]
+            state_dict[key] = value
+            continue
 
     word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+    word_embed_weight = word_embed.weight.data
+    predictions_bias = torch.zeros(config.vocab_size)
+    predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
     for original_token, new_token in zip(original_vocab_list, vocab_list):
         original_index = original_vocab_list.index(original_token)
         new_index = vocab_list.index(new_token)
-        word_embed.weight.data[new_index] = state_dict["embeddings.word_embeddings.weight"][original_index]
-    state_dict["embeddings.word_embeddings.weight"] = word_embed.weight.data
+        word_embed_weight[new_index] = state_dict["rnabert.embeddings.word_embeddings.weight"][original_index]
+        predictions_bias[new_index] = state_dict["lm_head.predictions.bias"][original_index]
+        predictions_decoder_weight[new_index] = state_dict["lm_head.predictions.decoder.weight"][original_index]
+    state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight
+    state_dict["lm_head.predictions.bias"] = predictions_bias
+    state_dict["lm_head.predictions.decoder.weight"] = predictions_decoder_weight
 
     model.load_state_dict(state_dict)
-    model.save_pretrained(output_path, safe_serialization=True)
-    model.save_pretrained(output_path, safe_serialization=False)
-    chanfig.NestedDict(get_special_tokens_map()).json(os.path.join(output_path, "special_tokens_map.json"))
-    chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(output_path, "tokenizer_config.json"))
+    model.save_pretrained(convert_config.output_path, safe_serialization=True)
+    model.save_pretrained(convert_config.output_path, safe_serialization=False)
+    chanfig.NestedDict(get_special_tokens_map()).json(
+        os.path.join(convert_config.output_path, "special_tokens_map.json")
+    )
+    chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(convert_config.output_path, "tokenizer_config.json"))
+
+    if convert_config.push_to_hub:
+        if HfApi is None:
+            raise ImportError("Please install huggingface_hub to push to the hub.")
+        api = HfApi()
+        api.create_repo(
+            convert_config.repo_id,
+            token=convert_config.token,
+            exist_ok=True,
+        )
+        api.upload_folder(repo_id=convert_config.repo_id, path=convert_config.output_path, token=convert_config.token)
+
+
+@chanfig.configclass
+class ConvertConfig:
+    checkpoint_path: str
+    output_path: str = RnaBertConfig.model_type
+    push_to_hub: bool = False
+    repo_id: str = "ZhiyuanChen/" + output_path
+    token: Optional[str] = None
 
 
 if __name__ == "__main__":
-    convert_checkpoint(sys.argv[1], sys.argv[2] if len(sys.argv) > 2 else None)
+    config = ConvertConfig()
+    config.parse()
+    convert_checkpoint(config)
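
For reference, a sketch of driving the reorganised converter programmatically rather than through the `__main__` CLI path; it assumes the script is importable as `convert_checkpoint` and that chanfig config classes accept keyword construction and attribute assignment:

```python
# Sketch only: the import path and ConvertConfig construction style are
# assumptions; checkpoint_path should point at the original RNABERT weights.
from convert_checkpoint import ConvertConfig, convert_checkpoint

config = ConvertConfig(checkpoint_path="path/to/original/rnabert.pt")
config.output_path = "rnabert"  # directory for the converted model
config.push_to_hub = False      # set True (with a token) to upload the folder
convert_checkpoint(config)
```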