Skip to content

Commit

Permalink
add RnaFm
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Apr 16, 2024
1 parent 6ab3514 commit 663833d
Show file tree
Hide file tree
Showing 5 changed files with 1,352 additions and 0 deletions.
12 changes: 12 additions & 0 deletions multimolecule/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
RnaBertForTokenClassification,
RnaBertModel,
)
from .rnafm import (
RnaFmConfig,
RnaFmForMaskedLM,
RnaFmForSequenceClassification,
RnaFmForTokenClassification,
RnaFmModel,
)
from .rnamsm import (
RnaMsmConfig,
RnaMsmForMaskedLM,
Expand All @@ -27,6 +34,11 @@
"RnaBertForMaskedLM",
"RnaBertForSequenceClassification",
"RnaBertForTokenClassification",
"RnaFmConfig",
"RnaFmForMaskedLM",
"RnaFmForSequenceClassification",
"RnaFmForTokenClassification",
"RnaFmModel",
"RnaMsmConfig",
"RnaMsmModel",
"RnaMsmForMaskedLM",
Expand Down
31 changes: 31 additions & 0 deletions multimolecule/models/rnafm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from transformers import (
AutoConfig,
AutoModel,
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
)

from multimolecule.tokenizers.rna import RnaTokenizer

from .configuration_rnafm import RnaFmConfig
from .modeling_rnafm import RnaFmForMaskedLM, RnaFmForSequenceClassification, RnaFmForTokenClassification, RnaFmModel

__all__ = [
"RnaFmConfig",
"RnaFmModel",
"RnaTokenizer",
"RnaFmForMaskedLM",
"RnaFmForSequenceClassification",
"RnaFmForTokenClassification",
]

AutoConfig.register("rnafm", RnaFmConfig)
AutoModel.register(RnaFmConfig, RnaFmModel)
AutoModelForMaskedLM.register(RnaFmConfig, RnaFmForMaskedLM)
AutoModelForSequenceClassification.register(RnaFmConfig, RnaFmForSequenceClassification)
AutoModelForTokenClassification.register(RnaFmConfig, RnaFmForTokenClassification)
AutoModelWithLMHead.register(RnaFmConfig, RnaFmForTokenClassification)
AutoTokenizer.register(RnaFmConfig, RnaTokenizer)
130 changes: 130 additions & 0 deletions multimolecule/models/rnafm/configuration_rnafm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from transformers.utils import logging

from ..configuration_utils import HeadConfig, MaskedLMHeadConfig, PretrainedConfig

logger = logging.get_logger(__name__)


class RnaFmConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`RnaFmModel`]. It is used to instantiate a RNA-FM
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the RNA-FM
[ml4bio/RNA-FM](https://github.com/ml4bio/RNA-FM) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*):
Vocabulary size of the RNA-FM model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`RnaFmModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (`int`, *optional*, defaults to 1026):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
pad_token_id (`int`, *optional*, defaults to 0):
The index of the padding token in the vocabulary. This must be included in the config because certain parts
of the RnaBert code use this instead of the attention mask.
bos_token_id (`int`, *optional*, defaults to 1):
The index of the bos token in the vocabulary. This must be included in the config because of the
contact and other prediction heads removes the bos and padding token when predicting outputs.
mask_token_id (`int`, *optional*, defaults to 4):
The index of the mask token in the vocabulary. This must be included in the config because of the
"mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens.
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query", "rotary"`.
For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
is_decoder (`bool`, *optional*, defaults to `False`):
Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
emb_layer_norm_before (`bool`, *optional*):
Whether to apply layer normalization after embeddings but before the main stem of the network.
token_dropout (`bool`, defaults to `False`):
When this is enabled, masked tokens are treated as if they had been dropped out by input dropout.
Examples:
```python
>>> from multimolecule import RnaFmModel, RnaFmConfig
>>> # Initializing a RNA-FM style configuration >>> configuration = RnaFmConfig()
>>> # Initializing a model from the configuration >>> model = RnaFmModel(configuration)
>>> # Accessing the model configuration >>> configuration = model.config
```
"""

model_type = "rnafm"

def __init__(
self,
vocab_size=25,
hidden_size=640,
num_hidden_layers=12,
num_attention_heads=20,
intermediate_size=5120,
hidden_act="gelu",
hidden_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=1026,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
bos_token_id=1,
mask_token_id=4,
position_embedding_type="absolute",
use_cache=True,
emb_layer_norm_before=True,
token_dropout=True,
head=None,
lm_head=None,
**kwargs,
):
if head is None:
head = {}
if lm_head is None:
lm_head = {}
head.setdefault("hidden_size", hidden_size)
lm_head.setdefault("hidden_size", hidden_size)
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, mask_token_id=mask_token_id, **kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.head = HeadConfig(**head)
self.lm_head = MaskedLMHeadConfig(**lm_head)
152 changes: 152 additions & 0 deletions multimolecule/models/rnafm/convert_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import os
from typing import Optional

import chanfig
import torch
from torch import nn

from multimolecule.models import RnaFmConfig as Config
from multimolecule.models import RnaFmForMaskedLM as Model
from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list

try:
from huggingface_hub import HfApi
except ImportError:
HfApi = None


torch.manual_seed(1013)

CONFIG = {
"architectures": ["RnaFmModel"],
"attention_dropout": 0.1,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 640,
"intermediate_size": 5120,
"max_position_embeddings": 1026,
"num_attention_heads": 20,
"num_hidden_layers": 12,
"max_tokens_per_msa": 2**14,
"num_labels": 1,
}

original_vocab_list = [
"<cls>",
"<pad>",
"<eos>",
"<unk>",
"A",
"C",
"G",
"U",
"R",
"Y",
"K",
"M",
"S",
"W",
"B",
"D",
"H",
"V",
"N",
"-",
"<null>",
"<null>",
"<null>",
"<null>",
"<mask>",
]
vocab_list = get_vocab_list()


def _convert_checkpoint(config, original_state_dict):
state_dict = {}
for key, value in original_state_dict.items():
key = "rnafm" + key[7:]
key = key.replace("LayerNorm", "layer_norm")
key = key.replace("gamma", "weight")
key = key.replace("beta", "bias")
key = key.replace("rnafm.encoder.emb_layer_norm_before", "rnafm.embeddings.layer_norm")
key = key.replace("rnafm.encoder.embed_tokens", "rnafm.embeddings.word_embeddings")
key = key.replace("rnafm.encoder.embed_positions", "rnafm.embeddings.position_embeddings")
key = key.replace("layers", "layer")
key = key.replace("self_attn", "attention.self")
key = key.replace("q_proj", "query")
key = key.replace("k_proj", "key")
key = key.replace("v_proj", "value")
key = key.replace("self.out_proj", "output.dense")
key = key.replace("fc1", "intermediate.dense")
key = key.replace("fc2", "output.dense")
key = key.replace("rnafm.encoder.lm_head", "lm_head")
key = key.replace("lm_head.dense", "lm_head.transform.dense")
key = key.replace("lm_head.layer_norm", "lm_head.transform.layer_norm")
key = key.replace("lm_head.weight", "lm_head.decoder.weight")
key = key.replace("rnafm.encoder.contact_head", "rnafm.contact_head")
key = key.replace("self_layer_norm", "layer_norm")
key = key.replace("final_layer_norm", "layer_norm")
key = key.replace("regression", "decoder")
state_dict[key] = value

word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
word_embed_weight = word_embed.weight.data
predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
predictions_bias = torch.zeros(config.vocab_size)
# nn.init.normal_(pos_embed.weight, std=0.02)
for original_index, original_token in enumerate(original_vocab_list):
new_index = vocab_list.index(original_token)
word_embed_weight[new_index] = state_dict["rnafm.embeddings.word_embeddings.weight"][original_index]
predictions_decoder_weight[new_index] = state_dict["lm_head.decoder.weight"][original_index]
predictions_bias[new_index] = state_dict["lm_head.bias"][original_index]
state_dict["rnafm.embeddings.word_embeddings.weight"] = word_embed_weight
state_dict["lm_head.decoder.weight"] = predictions_decoder_weight
state_dict["lm_head.bias"] = predictions_bias
state_dict["lm_head.decoder.bias"] = state_dict["lm_head.bias"]
return state_dict


def convert_checkpoint(convert_config):
config = Config.from_dict(chanfig.FlatDict(CONFIG))
config.vocab_size = len(vocab_list)

model = Model(config)

ckpt = torch.load(convert_config.checkpoint_path, map_location=torch.device("cpu"))
state_dict = _convert_checkpoint(config, ckpt)

model.load_state_dict(state_dict)
model.save_pretrained(convert_config.output_path, safe_serialization=True)
model.save_pretrained(convert_config.output_path, safe_serialization=False)
chanfig.NestedDict(get_special_tokens_map()).json(
os.path.join(convert_config.output_path, "special_tokens_map.json")
)
chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(convert_config.output_path, "tokenizer_config.json"))

if convert_config.push_to_hub:
if HfApi is None:
raise ImportError("Please install huggingface_hub to push to the hub.")
api = HfApi()
api.create_repo(
convert_config.repo_id,
token=convert_config.token,
exist_ok=True,
)
api.upload_folder(
repo_id=convert_config.repo_id, folder_path=convert_config.output_path, token=convert_config.token
)


@chanfig.configclass
class ConvertConfig:
checkpoint_path: str
output_path: str = Config.model_type
push_to_hub: bool = False
repo_id: str = f"multimolecule/{output_path}"
token: Optional[str] = None


if __name__ == "__main__":
config = ConvertConfig()
config.parse() # type: ignore[attr-defined]
convert_checkpoint(config)
Loading

0 comments on commit 663833d

Please sign in to comment.