From 4c5b26e5f71c4b6b7adee11c154dccea4652e118 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Tue, 26 Mar 2024 12:30:48 +0800 Subject: [PATCH] DNM Signed-off-by: Zhiyuan Chen --- multimolecule/__init__.py | 0 multimolecule/models/__init__.py | 3 + multimolecule/models/rnabert/__init__.py | 5 + multimolecule/models/rnabert/config.json | 25 ++ .../models/rnabert/configuration_rnabert.py | 106 +++++ .../models/rnabert/convert_checkpoint.py | 34 ++ .../models/rnabert/modeling_rnabert.py | 420 ++++++++++++++++++ .../models/rnabert/special_tokens_map.json | 16 + .../models/rnabert/tokenization_rnabert.py | 133 ++++++ .../models/rnabert/tokenizer_config.json | 27 ++ multimolecule/models/rnabert/vocab.txt | 6 + 11 files changed, 775 insertions(+) create mode 100644 multimolecule/__init__.py create mode 100644 multimolecule/models/__init__.py create mode 100644 multimolecule/models/rnabert/__init__.py create mode 100644 multimolecule/models/rnabert/config.json create mode 100644 multimolecule/models/rnabert/configuration_rnabert.py create mode 100644 multimolecule/models/rnabert/convert_checkpoint.py create mode 100644 multimolecule/models/rnabert/modeling_rnabert.py create mode 100644 multimolecule/models/rnabert/special_tokens_map.json create mode 100644 multimolecule/models/rnabert/tokenization_rnabert.py create mode 100644 multimolecule/models/rnabert/tokenizer_config.json create mode 100644 multimolecule/models/rnabert/vocab.txt diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py new file mode 100644 index 00000000..e879c6b3 --- /dev/null +++ b/multimolecule/models/__init__.py @@ -0,0 +1,3 @@ +from .rnabert import RnaBertConfig, RnaBertModel, RnaBertTokenizer + +__all__ = ["RnaBertConfig", "RnaBertModel", "RnaBertTokenizer"] diff --git a/multimolecule/models/rnabert/__init__.py b/multimolecule/models/rnabert/__init__.py new file mode 100644 index 00000000..ac898ed6 --- /dev/null +++ b/multimolecule/models/rnabert/__init__.py @@ -0,0 +1,5 @@ +from .configuration_rnabert import RnaBertConfig +from .modeling_rnabert import RnaBertModel +from .tokenization_rnabert import RnaBertTokenizer + +__all__ = ["RnaBertConfig", "RnaBertModel", "RnaBertTokenizer"] diff --git a/multimolecule/models/rnabert/config.json b/multimolecule/models/rnabert/config.json new file mode 100644 index 00000000..01a47640 --- /dev/null +++ b/multimolecule/models/rnabert/config.json @@ -0,0 +1,25 @@ +{ + "architectures": ["RnaBertModel"], + "attention_probs_dropout_prob": 0.0, + "emb_layer_norm_before": null, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 120, + "initializer_range": 0.02, + "intermediate_size": 40, + "layer_norm_eps": 1e-12, + "mask_token_id": null, + "max_position_embeddings": 440, + "model_type": "rnabert", + "num_attention_heads": 12, + "num_hidden_layers": 6, + "position_embedding_type": "absolute", + "ss_size": 8, + "token_dropout": false, + "torch_dtype": "float32", + "transformers_version": "4.39.1", + "type_vocab_size": 2, + "use_cache": true, + "vocab_list": ["", "", "A", "T", "G", "C"], + "vocab_size": 6 +} diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py new file mode 100644 index 00000000..d7bbaaac --- /dev/null +++ b/multimolecule/models/rnabert/configuration_rnabert.py @@ -0,0 +1,106 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class RnaBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a + RnaBert model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RnaBert + [mana438/RNABERT](https://github.com/mana438/RNABERT/blob/master/RNA_bert_config.json) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*): + Vocabulary size of the RnaBert model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RnaBertModel`]. + mask_token_id (`int`, *optional*): + The index of the mask token in the vocabulary. This must be included in the config because of the + "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens. + pad_token_id (`int`, *optional*): + The index of the padding token in the vocabulary. This must be included in the config because certain parts + of the RnaBert code use this instead of the attention mask. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 1026): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + emb_layer_norm_before (`bool`, *optional*): + Whether to apply layer normalization after embeddings but before the main stem of the network. + token_dropout (`bool`, defaults to `False`): + When this is enabled, masked tokens are treated as if they had been dropped out by input dropout. + + Examples: + + ```python + >>> from transformers import RnaBertModel, RnaBertConfig + + >>> # Initializing a RnaBert style configuration >>> configuration = RnaBertConfig() + + >>> # Initializing a model from the configuration >>> model = RnaBertModel(configuration) + + >>> # Accessing the model configuration >>> configuration = model.config + ```""" + + model_type = "rnabert" + + def __init__( + self, + vocab_size=None, + mask_token_id=None, + pad_token_id=None, + hidden_size=None, + multiple=None, + num_hidden_layers=6, + num_attention_heads=12, + intermediate_size=40, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=440, + initializer_range=0.02, + layer_norm_eps=1e-12, + emb_layer_norm_before=None, + token_dropout=False, + vocab_list=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs) + + self.vocab_size = vocab_size + if hidden_size is None: + hidden_size = num_attention_heads * multiple if multiple is not None else 120 + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.emb_layer_norm_before = emb_layer_norm_before + self.token_dropout = token_dropout + self.vocab_list = vocab_list + + +def get_default_vocab_list(): + return ["", "", "A", "T", "G", "C"] diff --git a/multimolecule/models/rnabert/convert_checkpoint.py b/multimolecule/models/rnabert/convert_checkpoint.py new file mode 100644 index 00000000..99dbc0ff --- /dev/null +++ b/multimolecule/models/rnabert/convert_checkpoint.py @@ -0,0 +1,34 @@ +import sys +from typing import Optional + +import chanfig +import torch +from . import RnaBertConfig, RnaBertModel +from .configuration_rnabert import get_default_vocab_list + + +def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None): + if output_path is None: + output_path = "rnabert" + config = RnaBertConfig.from_dict(chanfig.load("config.json")) + config.vocab_list = get_default_vocab_list() + ckpt = torch.load(checkpoint_path) + bert_state_dict = ckpt + state_dict = {} + + model = RnaBertModel(config) + + for key, value in bert_state_dict.items(): + if key.startswith("module.cls"): + continue + key = key[12:] + key = key.replace("gamma", "weight") + key = key.replace("beta", "bias") + state_dict[key] = value + + model.load_state_dict(state_dict) + model.save_pretrained(output_path) + + +if __name__ == "__main__": + convert_checkpoint(sys.argv[1], sys.argv[2] if len(sys.argv) > 2 else None) diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py new file mode 100644 index 00000000..10926819 --- /dev/null +++ b/multimolecule/models/rnabert/modeling_rnabert.py @@ -0,0 +1,420 @@ +import math + +import torch +from torch import nn +from transformers import PreTrainedModel + +from .configuration_rnabert import RnaBertConfig + + +class RnaBertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) # weightのこと + self.bias = nn.Parameter(torch.zeros(hidden_size)) # biasのこと + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class RnaBertEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + words_embeddings = self.word_embeddings(input_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + +class RnaBertLayer(nn.Module): + def __init__(self, config): + super().__init__() + + self.attention = RnaBertAttention(config) + + self.intermediate = RnaBertIntermediate(config) + + self.output = RnaBertOutput(config) + + def forward(self, hidden_states, attention_mask, attention_show_flg=False): + if attention_show_flg: + attention_output, attention_probs = self.attention(hidden_states, attention_mask, attention_show_flg) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output, attention_probs + else: + attention_output = self.attention(hidden_states, attention_mask, attention_show_flg) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output # [batch_size, seq_length, hidden_size] + + +class RnaBertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.selfattn = RnaBertSelfAttention(config) + self.output = RnaBertSelfOutput(config) + + def forward(self, input_tensor, attention_mask, attention_show_flg=False): + if attention_show_flg: + self_output, attention_probs = self.selfattn(input_tensor, attention_mask, attention_show_flg) + attention_output = self.output(self_output, input_tensor) + return attention_output, attention_probs + else: + self_output = self.selfattn(input_tensor, attention_mask, attention_show_flg) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class RnaBertSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_attention_heads = config.num_attention_heads + # num_attention_heads': 12 + + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, attention_show_flg=False): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + attention_scores = attention_scores + attention_mask + + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + if attention_show_flg: + return context_layer, attention_probs + else: + return context_layer + + +class RnaBertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +def gelu(x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class RnaBertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + + self.intermediate_act_fn = gelu + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class RnaBertOutput(nn.Module): + def __init__(self, config): + super().__init__() + + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + + self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class RnaBertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.layer = nn.ModuleList([RnaBertLayer(config) for _ in range(config.num_hidden_layers)]) + # self.layer = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) + # for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask, + output_all_encoded_layers=True, + attention_show_flg=False, + ): + all_encoder_layers = [] + for layer in self.layer: + if attention_show_flg: + hidden_states, attention_probs = layer(hidden_states, attention_mask, attention_show_flg) + else: + hidden_states = layer(hidden_states, attention_mask, attention_show_flg) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + + if attention_show_flg: + return all_encoder_layers, attention_probs + else: + return all_encoder_layers + + +class RnaBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + + pooled_output = self.dense(first_token_tensor) + + pooled_output = self.activation(pooled_output) + + return pooled_output + + +class RnaBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RnaBertConfig + base_model_prefix = "rnabert" + supports_gradient_checkpointing = True + _no_split_modules = ["RnaBertLayer", "RnaBertFoldTriangularSelfAttentionBlock", "RnaBertEmbeddings"] + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class RnaBertModel(RnaBertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.embeddings = RnaBertEmbeddings(config) + self.encoder = RnaBertEncoder(config) + self.pooler = RnaBertPooler(config) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + output_all_encoded_layers=True, + attention_show_flg=False, + ): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + + if attention_show_flg: + encoded_layers, attention_probs = self.encoder( + embedding_output, + extended_attention_mask, + output_all_encoded_layers, + attention_show_flg, + ) + else: + encoded_layers = self.encoder( + embedding_output, + extended_attention_mask, + output_all_encoded_layers, + attention_show_flg, + ) + + pooled_output = self.pooler(encoded_layers[-1]) + + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + + if attention_show_flg: + return encoded_layers, pooled_output, attention_probs + else: + return encoded_layers, pooled_output + + +class RnaBertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + + self.predictions = MaskedWordPredictions(config) + config.vocab_size = config.ss_size + self.predictions_ss = MaskedWordPredictions(config) + + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + prediction_scores_ss = self.predictions_ss(sequence_output) + + seq_relationship_score = self.seq_relationship(pooled_output) + + return prediction_scores, prediction_scores_ss, seq_relationship_score + + +class MaskedWordPredictions(nn.Module): + def __init__(self, config): + super().__init__() + + self.transform = RnaBertPredictionHeadTransform(config) + + self.decoder = nn.Linear(in_features=config.hidden_size, out_features=config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + + return hidden_states + + +class RnaBertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + + self.transform_act_fn = gelu + + self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + # hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class SeqRelationship(nn.Module): + def __init__(self, config, out_features): + super().__init__() + + self.seq_relationship = nn.Linear(config.hidden_size, out_features) + + def forward(self, pooled_output): + return self.seq_relationship(pooled_output) + + +class RnaBertForMaskedLM(nn.Module): + def __init__(self, config): + super().__init__() + self.bert = RnaBertModel(config) + self.cls = RnaBertPreTrainingHeads(config) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + attention_show_flg=False, + ): + if attention_show_flg: + encoded_layers, pooled_output, attention_probs = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + attention_show_flg=True, + ) + else: + encoded_layers, pooled_output = self.bert( + input_ids, + token_type_ids, + attention_mask, + output_all_encoded_layers=False, + attention_show_flg=False, + ) + + prediction_scores, prediction_scores_ss, seq_relationship_score = self.cls(encoded_layers, pooled_output) + return prediction_scores, prediction_scores_ss, encoded_layers diff --git a/multimolecule/models/rnabert/special_tokens_map.json b/multimolecule/models/rnabert/special_tokens_map.json new file mode 100644 index 00000000..934c898f --- /dev/null +++ b/multimolecule/models/rnabert/special_tokens_map.json @@ -0,0 +1,16 @@ +{ + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "mask_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/multimolecule/models/rnabert/tokenization_rnabert.py b/multimolecule/models/rnabert/tokenization_rnabert.py new file mode 100644 index 00000000..efcecff7 --- /dev/null +++ b/multimolecule/models/rnabert/tokenization_rnabert.py @@ -0,0 +1,133 @@ +import os +from typing import List, Optional + +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "ZhiyuanChen/rnabert": "https://huggingface.co/ZhiyuanChen/rnabert/resolve/main/vocab.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "ZhiyuanChen/rnabert": 1024, +} + + +def load_vocab_file(vocab_file): + with open(vocab_file) as f: + lines = f.read().splitlines() + return [l.strip() for l in lines] # noqa: E741 + + +class RnaBertTokenizer(PreTrainedTokenizer): + """ + Constructs an RnaBert tokenizer. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + pad_token="", + mask_token="", + **kwargs, + ): + self.all_tokens = load_vocab_file(vocab_file) + self._id_to_token = dict(enumerate(self.all_tokens)) + self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)} + super().__init__( + pad_token=pad_token, + mask_token=mask_token, + **kwargs, + ) + + # TODO, all the tokens are added? But they are also part of the vocab... bit strange. + # none of them are special, but they all need special splitting. + + self.unique_no_split_tokens = self.all_tokens + self._update_trie(self.unique_no_split_tokens) + + def _convert_id_to_token(self, index: int) -> str: + return self._id_to_token.get(index, self.unk_token) + + def _convert_token_to_id(self, token: str) -> int: + return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) + + def _tokenize(self, text, **kwargs): + return text.split() + + def get_vocab(self): + base_vocab = self._token_to_id.copy() + base_vocab.update(self.added_tokens_encoder) + return base_vocab + + def token_to_id(self, token: str) -> int: + return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) + + def id_to_token(self, index: int) -> str: + return self._id_to_token.get(index, self.unk_token) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + cls = [self.cls_token_id] + sep = [self.eos_token_id] # No sep token in RnaBert vocabulary + if token_ids_1 is None: + if self.eos_token_id is None: + return cls + token_ids_0 + else: + return cls + token_ids_0 + sep + elif self.eos_token_id is None: + raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!") + return cls + token_ids_0 + sep + token_ids_1 + sep # Multiple inputs always have an EOS token + + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of ids of the first sequence. + token_ids_1 (`List[int]`, *optional*): + List of ids of the second sequence. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + + return [1 if token in self.all_special_ids else 0 for token in token_ids_0] + mask = [1] + ([0] * len(token_ids_0)) + [1] + if token_ids_1 is not None: + mask += [0] * len(token_ids_1) + [1] + return mask + + def save_vocabulary(self, save_directory, filename_prefix): + vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt") + with open(vocab_file, "w") as f: + f.write("\n".join(self.all_tokens)) + return (vocab_file,) + + @property + def vocab_size(self) -> int: + return len(self.all_tokens) diff --git a/multimolecule/models/rnabert/tokenizer_config.json b/multimolecule/models/rnabert/tokenizer_config.json new file mode 100644 index 00000000..f9d62b74 --- /dev/null +++ b/multimolecule/models/rnabert/tokenizer_config.json @@ -0,0 +1,27 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "clean_up_tokenization_spaces": true, + "pad_token": "", + "mask_token": "", + "cls_token": "", + "unk_token": "", + "model_max_length": 440, + "tokenizer_class": "RnaBertTokenizer" +} diff --git a/multimolecule/models/rnabert/vocab.txt b/multimolecule/models/rnabert/vocab.txt new file mode 100644 index 00000000..74f773f5 --- /dev/null +++ b/multimolecule/models/rnabert/vocab.txt @@ -0,0 +1,6 @@ + + +A +U +G +C \ No newline at end of file