diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py index 0d4bc745..650d06f2 100644 --- a/multimolecule/__init__.py +++ b/multimolecule/__init__.py @@ -1,3 +1,10 @@ -from . import models +from transformers import AutoTokenizer -__all__ = ["models"] +from . import models, tokenizers +from .models import RnaBertConfig +from .tokenizers import RnaTokenizer + +AutoTokenizer.register(RnaBertConfig, RnaTokenizer) + + +__all__ = ["models", "tokenizers"] diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py index e879c6b3..9c868664 100644 --- a/multimolecule/models/__init__.py +++ b/multimolecule/models/__init__.py @@ -1,3 +1,3 @@ -from .rnabert import RnaBertConfig, RnaBertModel, RnaBertTokenizer +from .rnabert import RnaBertConfig, RnaBertModel -__all__ = ["RnaBertConfig", "RnaBertModel", "RnaBertTokenizer"] +__all__ = ["RnaBertConfig", "RnaBertModel"] diff --git a/multimolecule/models/rnabert/__init__.py b/multimolecule/models/rnabert/__init__.py index b97dff57..3a9e9ebe 100644 --- a/multimolecule/models/rnabert/__init__.py +++ b/multimolecule/models/rnabert/__init__.py @@ -1,11 +1,12 @@ from transformers import AutoConfig, AutoModel, AutoTokenizer +from multimolecule.tokenizers.rna import RnaTokenizer + from .configuration_rnabert import RnaBertConfig from .modeling_rnabert import RnaBertModel -from .tokenization_rnabert import RnaBertTokenizer -__all__ = ["RnaBertConfig", "RnaBertModel", "RnaBertTokenizer"] +__all__ = ["RnaBertConfig", "RnaBertModel"] AutoConfig.register("rnabert", RnaBertConfig) AutoModel.register(RnaBertConfig, RnaBertModel) -AutoTokenizer.register(RnaBertConfig, RnaBertTokenizer) +AutoTokenizer.register(RnaBertConfig, RnaTokenizer) diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py index 6f9fd4f5..d98468e2 100644 --- a/multimolecule/models/rnabert/configuration_rnabert.py +++ b/multimolecule/models/rnabert/configuration_rnabert.py @@ -4,9 +4,6 @@ logger = logging.get_logger(__name__) -DEFAULT_VOCAB_LIST = ["", "", "A", "T", "G", "C"] - - class RnaBertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a @@ -64,34 +61,39 @@ class RnaBertConfig(PretrainedConfig): def __init__( self, - vocab_size=None, - mask_token_id=None, - pad_token_id=None, + vocab_size=25, + ss_vocab_size=8, hidden_size=None, multiple=None, num_hidden_layers=6, num_attention_heads=12, intermediate_size=40, + hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=440, initializer_range=0.02, layer_norm_eps=1e-12, - vocab_list=None, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, **kwargs, ): - super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs) + super().__init__(pad_token_id=pad_token_id, **kwargs) self.vocab_size = vocab_size + self.ss_vocab_size = ss_vocab_size if hidden_size is None: hidden_size = num_attention_heads * multiple if multiple is not None else 120 self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size + self.hidden_act = hidden_act self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.vocab_list = vocab_list if vocab_list is not None else DEFAULT_VOCAB_LIST + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache diff --git a/multimolecule/models/rnabert/convert_checkpoint.py b/multimolecule/models/rnabert/convert_checkpoint.py index d3a85c5d..f341bbb3 100644 --- a/multimolecule/models/rnabert/convert_checkpoint.py +++ b/multimolecule/models/rnabert/convert_checkpoint.py @@ -1,11 +1,13 @@ +import os import sys from typing import Optional import chanfig import torch +from torch import nn from multimolecule.models import RnaBertConfig, RnaBertModel -from multimolecule.models.rnabert.configuration_rnabert import DEFAULT_VOCAB_LIST +from multimolecule.tokenizers.rna.config import get_special_tokens_map, get_tokenizer_config, get_vocab_list CONFIG = { "architectures": ["RnaBertModel"], @@ -13,28 +15,25 @@ "hidden_act": "gelu", "hidden_dropout_prob": 0.0, "hidden_size": 120, - "initializer_range": 0.02, "intermediate_size": 40, - "layer_norm_eps": 1e-12, - "mask_token_id": 1, "max_position_embeddings": 440, - "model_type": "rnabert", "num_attention_heads": 12, "num_hidden_layers": 6, - "position_embedding_type": "absolute", - "ss_size": 8, - "torch_dtype": "float32", + "vocab_size": 25, + "ss_vocab_size": 8, "type_vocab_size": 2, + "pad_token_id": 0, } +original_vocab_list = ["", "", "A", "U", "G", "C"] +vocab_list = get_vocab_list() + def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None): if output_path is None: output_path = "rnabert" - config = RnaBertConfig.from_dict(chanfig.NestedDict(CONFIG)) - config.vocab_list = DEFAULT_VOCAB_LIST - config.vocab_size = len(config.vocab_list) - ckpt = torch.load(checkpoint_path) + config = RnaBertConfig.from_dict(chanfig.FlatDict(CONFIG)) + ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu")) bert_state_dict = ckpt state_dict = {} @@ -48,8 +47,19 @@ def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None): key = key.replace("beta", "bias") state_dict[key] = value + word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + # nn.init.normal_(pos_embed.weight, std=0.02) + for original_token, new_token in zip(original_vocab_list, vocab_list): + original_index = original_vocab_list.index(original_token) + new_index = vocab_list.index(new_token) + word_embed.weight.data[new_index] = state_dict["embeddings.word_embeddings.weight"][original_index] + state_dict["embeddings.word_embeddings.weight"] = word_embed.weight.data + model.load_state_dict(state_dict) - model.save_pretrained(output_path) + model.save_pretrained(output_path, safe_serialization=True) + model.save_pretrained(output_path, safe_serialization=False) + chanfig.NestedDict(get_special_tokens_map()).json(os.path.join(output_path, "special_tokens_map.json")) + chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(output_path, "tokenizer_config.json")) if __name__ == "__main__": diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py index 7d90e758..b9698589 100644 --- a/multimolecule/models/rnabert/modeling_rnabert.py +++ b/multimolecule/models/rnabert/modeling_rnabert.py @@ -273,7 +273,6 @@ def _init_weights(self, module): class RnaBertModel(RnaBertPreTrainedModel): - def __init__(self, config): super().__init__(config) self.embeddings = RnaBertEmbeddings(config) @@ -329,9 +328,8 @@ class RnaBertLMHead(nn.Module): def __init__(self, config): super().__init__() - self.predictions = MaskedWordPredictions(config) - config.vocab_size = config.ss_size - self.predictions_ss = MaskedWordPredictions(config) + self.predictions = MaskedWordPredictions(config, config.vocab_size) + self.predictions_ss = MaskedWordPredictions(config, config.ss_vocab_size) self.seq_relationship = nn.Linear(config.hidden_size, 2) @@ -345,13 +343,13 @@ def forward(self, sequence_output, pooled_output): class MaskedWordPredictions(nn.Module): - def __init__(self, config): + def __init__(self, config, vocab_size): super().__init__() self.transform = RnaBertPredictionHeadTransform(config) - self.decoder = nn.Linear(in_features=config.hidden_size, out_features=config.vocab_size, bias=False) - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder = nn.Linear(in_features=config.hidden_size, out_features=vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(vocab_size)) def forward(self, hidden_states): hidden_states = self.transform(hidden_states) diff --git a/multimolecule/models/rnabert/special_tokens_map.json b/multimolecule/models/rnabert/special_tokens_map.json deleted file mode 100644 index 934c898f..00000000 --- a/multimolecule/models/rnabert/special_tokens_map.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "pad_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "mask_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - } -} diff --git a/multimolecule/models/rnabert/tokenizer_config.json b/multimolecule/models/rnabert/tokenizer_config.json deleted file mode 100644 index f9d62b74..00000000 --- a/multimolecule/models/rnabert/tokenizer_config.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "added_tokens_decoder": { - "0": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false, - "special": true - }, - "1": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false, - "special": true - } - }, - "clean_up_tokenization_spaces": true, - "pad_token": "", - "mask_token": "", - "cls_token": "", - "unk_token": "", - "model_max_length": 440, - "tokenizer_class": "RnaBertTokenizer" -} diff --git a/multimolecule/models/rnabert/vocab.txt b/multimolecule/models/rnabert/vocab.txt deleted file mode 100644 index 41be3616..00000000 --- a/multimolecule/models/rnabert/vocab.txt +++ /dev/null @@ -1,6 +0,0 @@ - - -A -T -G -C diff --git a/multimolecule/tokenizers/__init__.py b/multimolecule/tokenizers/__init__.py new file mode 100644 index 00000000..1a70ec47 --- /dev/null +++ b/multimolecule/tokenizers/__init__.py @@ -0,0 +1,3 @@ +from .rna import RnaTokenizer + +__all__ = ["RnaTokenizer"] diff --git a/multimolecule/tokenizers/rna/__init__.py b/multimolecule/tokenizers/rna/__init__.py new file mode 100644 index 00000000..531035b2 --- /dev/null +++ b/multimolecule/tokenizers/rna/__init__.py @@ -0,0 +1,3 @@ +from .tokenization_rna import RnaTokenizer + +__all__ = ["RnaTokenizer"] diff --git a/multimolecule/tokenizers/rna/config.py b/multimolecule/tokenizers/rna/config.py new file mode 100755 index 00000000..abe82a51 --- /dev/null +++ b/multimolecule/tokenizers/rna/config.py @@ -0,0 +1,91 @@ +def get_vocab_list(): + return VOCAB_LIST + + +def get_special_tokens_map(): + return SPECIAL_TOKENS_MAP + + +def get_tokenizer_config(): + config = TOKENIZER_CONFIG + config.setdefault("added_tokens_decoder", {}) + for i, v in enumerate(SPECIAL_TOKENS_MAP.values()): + config["added_tokens_decoder"][str(i)] = v + return config + + +VOCAB_LIST = [ + "", + "", + "", + "", + "", + "", + "A", + "C", + "G", + "U", + "N", + "X", + "V", + "H", + "D", + "B", + "M", + "R", + "W", + "S", + "Y", + "K", + ".", + "*", + "-", +] + +SPECIAL_TOKENS_MAP = { + "pad_token": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + "cls_token": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + "eos_token": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + "unk_token": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + "mask_token": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, +} + +TOKENIZER_CONFIG = { + "tokenizer_class": "RnaTokenizer", + "clean_up_tokenization_spaces": True, +} diff --git a/multimolecule/models/rnabert/tokenization_rnabert.py b/multimolecule/tokenizers/rna/tokenization_rna.py old mode 100644 new mode 100755 similarity index 80% rename from multimolecule/models/rnabert/tokenization_rnabert.py rename to multimolecule/tokenizers/rna/tokenization_rna.py index efcecff7..ab38cd7a --- a/multimolecule/models/rnabert/tokenization_rnabert.py +++ b/multimolecule/tokenizers/rna/tokenization_rna.py @@ -4,49 +4,41 @@ from transformers.tokenization_utils import PreTrainedTokenizer from transformers.utils import logging -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} - -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - "ZhiyuanChen/rnabert": "https://huggingface.co/ZhiyuanChen/rnabert/resolve/main/vocab.txt", - }, -} +from .config import VOCAB_LIST -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "ZhiyuanChen/rnabert": 1024, -} - - -def load_vocab_file(vocab_file): - with open(vocab_file) as f: - lines = f.read().splitlines() - return [l.strip() for l in lines] # noqa: E741 +logger = logging.get_logger(__name__) -class RnaBertTokenizer(PreTrainedTokenizer): +class RnaTokenizer(PreTrainedTokenizer): """ Constructs an RnaBert tokenizer. """ - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES model_input_names = ["input_ids", "attention_mask"] def __init__( self, - vocab_file, + cls_token="", pad_token="", + eos_token="", + sep_token="", + unk_token="", mask_token="", + convert_to_uppercase=True, + convert_T_to_U=True, **kwargs, ): - self.all_tokens = load_vocab_file(vocab_file) + self.all_tokens = VOCAB_LIST self._id_to_token = dict(enumerate(self.all_tokens)) self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)} + self.convert_to_uppercase = convert_to_uppercase + self.convert_T_to_U = convert_T_to_U super().__init__( + cls_token=cls_token, pad_token=pad_token, + eos_token=eos_token, + sep_token=sep_token, + unk_token=unk_token, mask_token=mask_token, **kwargs, ) @@ -54,8 +46,8 @@ def __init__( # TODO, all the tokens are added? But they are also part of the vocab... bit strange. # none of them are special, but they all need special splitting. - self.unique_no_split_tokens = self.all_tokens - self._update_trie(self.unique_no_split_tokens) + # self.unique_no_split_tokens = self.all_tokens + # self._update_trie(self.unique_no_split_tokens) def _convert_id_to_token(self, index: int) -> str: return self._id_to_token.get(index, self.unk_token) @@ -63,8 +55,12 @@ def _convert_id_to_token(self, index: int) -> str: def _convert_token_to_id(self, token: str) -> int: return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) - def _tokenize(self, text, **kwargs): - return text.split() + def _tokenize(self, text: str, **kwargs): + if self.convert_to_uppercase: + text = text.upper() + if self.convert_T_to_U: + text = text.replace("T", "U") + return list(text) def get_vocab(self): base_vocab = self._token_to_id.copy() @@ -122,7 +118,7 @@ def get_special_tokens_mask( mask += [0] * len(token_ids_1) + [1] return mask - def save_vocabulary(self, save_directory, filename_prefix): + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None): vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt") with open(vocab_file, "w") as f: f.write("\n".join(self.all_tokens)) diff --git a/multimolecule/tokenizers/rna/vocab.txt b/multimolecule/tokenizers/rna/vocab.txt new file mode 100755 index 00000000..8487a2ca --- /dev/null +++ b/multimolecule/tokenizers/rna/vocab.txt @@ -0,0 +1,25 @@ + + + + + + +A +C +G +U +N +X +V +H +D +B +M +R +W +S +Y +K +. +* +-