From 09ca4d0e8b45989375e8ad95bab82643c782f3b8 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Wed, 27 Mar 2024 23:45:24 +0800
Subject: [PATCH] reorganise rnabert

---
 multimolecule/models/__init__.py              |  18 +-
 multimolecule/models/modeling_utils.py        | 246 ++++++++++
 multimolecule/models/rnabert/__init__.py      |  30 +-
 .../models/rnabert/configuration_rnabert.py   |   9 +-
 .../models/rnabert/convert_checkpoint.py      |  31 +-
 .../models/rnabert/modeling_rnabert.py        | 446 +++++++++++-------
 pyproject.toml                                |   4 +
 7 files changed, 590 insertions(+), 194 deletions(-)
 create mode 100644 multimolecule/models/modeling_utils.py

diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py
index a922cfc5..1f07452e 100644
--- a/multimolecule/models/__init__.py
+++ b/multimolecule/models/__init__.py
@@ -1,3 +1,17 @@
-from .rnabert import RnaBertConfig, RnaBertModel, RnaTokenizer
+from .rnabert import (
+    RnaBertConfig,
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+    RnaTokenizer,
+)
 
-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+    "RnaTokenizer",
+]
diff --git a/multimolecule/models/modeling_utils.py b/multimolecule/models/modeling_utils.py
new file mode 100644
index 00000000..80669928
--- /dev/null
+++ b/multimolecule/models/modeling_utils.py
@@ -0,0 +1,246 @@
+from math import sqrt
+from typing import Optional, Tuple, Union
+
+import torch
+from chanfig import ConfigRegistry
+from torch import nn
+from torch.nn import functional as F
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+
+
+class MaskedLMHead(nn.Module):
+    """Head for masked language modeling."""
+
+    def __init__(self, config):
+        super().__init__()
+        # Keep a reference to the config so forward can resolve use_return_dict and vocab_size.
+        self.config = config
+        if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
+            config.proj_head_mode = "none"
+        self.transform = PredictionHeadTransform.build(config)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+        self.decoder.bias = self.bias
+
+    def forward(
+        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
+    ) -> Union[Tuple, MaskedLMOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        sequence_output = outputs.last_hidden_state if return_dict else outputs[0]
+        x = self.transform(sequence_output)
+        prediction_scores = self.decoder(x)
+
+        masked_lm_loss = None
+        if labels is not None:
+            masked_lm_loss = F.cross_entropy(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class SequenceClassificationHead(nn.Module):
+    """Head for sequence-level classification tasks."""
+
+    num_labels: int
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
+            config.proj_head_mode = "none"
+        self.num_labels = config.num_labels
+        self.transform = PredictionHeadTransform.build(config)
+        classifier_dropout = (
+            config.classifier_dropout
+            if "classifier_dropout" in dir(config) and config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+    def forward(
+        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        sequence_output = outputs.last_hidden_state if return_dict else outputs[0]
+        pooled_output = outputs.pooler_output if return_dict else outputs[1]
+        # The RnaBert backbones are built with add_pooling_layer=False, so fall back to the
+        # first-token representation when no pooled output is available.
+        x = pooled_output if pooled_output is not None else sequence_output[:, 0]
+        x = self.dropout(x)
+        x = self.transform(x)
+        logits = self.decoder(x)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
+                loss = (
+                    F.mse_loss(logits.squeeze(), labels.squeeze())
+                    if self.num_labels == 1
+                    else F.mse_loss(logits, labels)
+                )
+            elif self.config.problem_type == "single_label_classification":
+                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss = F.binary_cross_entropy_with_logits(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class TokenClassificationHead(nn.Module):
+    """Head for token-level classification tasks."""
+
+    num_labels: int
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
+            config.proj_head_mode = "none"
+        self.num_labels = config.num_labels
+        self.transform = PredictionHeadTransform.build(config)
+        classifier_dropout = (
+            config.classifier_dropout
+            if "classifier_dropout" in dir(config) and config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+    def forward(
+        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        # Token-level classification needs per-token features, not the pooled output.
+        token_output = outputs.last_hidden_state if return_dict else outputs[0]
+        x = self.dropout(token_output)
+        x = self.transform(x)
+        logits = self.decoder(x)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+            if self.config.problem_type == "regression":
+                loss = (
+                    F.mse_loss(logits.squeeze(), labels.squeeze())
+                    if self.num_labels == 1
+                    else F.mse_loss(logits, labels)
+                )
+            elif self.config.problem_type == "single_label_classification":
+                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss = F.binary_cross_entropy_with_logits(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+PredictionHeadTransform = ConfigRegistry(key="proj_head_mode")
+
+
+@PredictionHeadTransform.register("nonlinear")
+class NonLinearTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+@PredictionHeadTransform.register("linear")
+class LinearTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+@PredictionHeadTransform.register("none")
+class IdentityTransform(nn.Identity):
+    def __init__(self, config):
+        super().__init__()
+
+
+sqrt_2 = sqrt(2.0)
+
+
+def gelu(x):
+    """Implementation of the gelu activation function.
+
+    For information: OpenAI GPT's gelu is slightly different
+    (and gives slightly different results):
+    0.5 * x * (
+        1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))
+    )
+    """
+    return x * 0.5 * (1.0 + torch.erf(x / sqrt_2))
diff --git a/multimolecule/models/rnabert/__init__.py b/multimolecule/models/rnabert/__init__.py
index c388aeff..36c2b893 100644
--- a/multimolecule/models/rnabert/__init__.py
+++ b/multimolecule/models/rnabert/__init__.py
@@ -1,12 +1,36 @@
-from transformers import AutoConfig, AutoModel, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForMaskedLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoModelWithLMHead,
+    AutoTokenizer,
+)
 
 from multimolecule.tokenizers.rna import RnaTokenizer
 
 from .configuration_rnabert import RnaBertConfig
-from .modeling_rnabert import RnaBertModel
+from .modeling_rnabert import (
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+)
 
-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaTokenizer",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+]
 
 AutoConfig.register("rnabert", RnaBertConfig)
 AutoModel.register(RnaBertConfig, RnaBertModel)
+AutoModelForMaskedLM.register(RnaBertConfig, RnaBertForMaskedLM)
+AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification)
+AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification)
+AutoModelWithLMHead.register(RnaBertConfig, RnaBertForMaskedLM)
 AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py
index d98468e2..c2af1179 100644
--- a/multimolecule/models/rnabert/configuration_rnabert.py
+++ b/multimolecule/models/rnabert/configuration_rnabert.py
@@ -9,7 +9,7 @@ class RnaBertConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a
     RnaBert model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of the RnaBert
-    [mana438/RNABERT](https://github.com/mana438/RNABERT/blob/master/RNA_bert_config.json) architecture.
+    [mana438/RNABERT](https://github.com/mana438/RNABERT) architecture.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -55,7 +55,8 @@ class RnaBertConfig(PretrainedConfig): >>> # Initializing a model from the configuration >>> model = RnaBertModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config - ```""" + ``` + """ model_type = "rnabert" @@ -77,6 +78,8 @@ def __init__( pad_token_id=0, position_embedding_type="absolute", use_cache=True, + classifier_dropout=None, + proj_head_mode="nonlinear", **kwargs, ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -97,3 +100,5 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + self.proj_head_mode = proj_head_mode diff --git a/multimolecule/models/rnabert/convert_checkpoint.py b/multimolecule/models/rnabert/convert_checkpoint.py index f341bbb3..9c281064 100644 --- a/multimolecule/models/rnabert/convert_checkpoint.py +++ b/multimolecule/models/rnabert/convert_checkpoint.py @@ -6,7 +6,7 @@ import torch from torch import nn -from multimolecule.models import RnaBertConfig, RnaBertModel +from multimolecule.models import RnaBertConfig, RnaBertForMaskedLM from multimolecule.tokenizers.rna.config import get_special_tokens_map, get_tokenizer_config, get_vocab_list CONFIG = { @@ -19,7 +19,6 @@ "max_position_embeddings": 440, "num_attention_heads": 12, "num_hidden_layers": 6, - "vocab_size": 25, "ss_vocab_size": 8, "type_vocab_size": 2, "pad_token_id": 0, @@ -33,27 +32,41 @@ def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None): if output_path is None: output_path = "rnabert" config = RnaBertConfig.from_dict(chanfig.FlatDict(CONFIG)) + config.vocab_size = len(vocab_list) ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu")) bert_state_dict = ckpt state_dict = {} - model = RnaBertModel(config) + model = RnaBertForMaskedLM(config) for key, value in bert_state_dict.items(): - if key.startswith("module.cls"): - continue - key = key[12:] + key = key[7:] key = key.replace("gamma", "weight") key = key.replace("beta", "bias") - state_dict[key] = value + if key.startswith("bert"): + state_dict["rna" + key] = value + continue + if key.startswith("cls"): + # import ipdb; ipdb.set_trace() + key = "lm_head." 
+ key[4:] + # key = key[4:] + state_dict[key] = value + continue word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + word_embed_weight = word_embed.weight.data + predictions_bias = torch.zeros(config.vocab_size) + predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size)) # nn.init.normal_(pos_embed.weight, std=0.02) for original_token, new_token in zip(original_vocab_list, vocab_list): original_index = original_vocab_list.index(original_token) new_index = vocab_list.index(new_token) - word_embed.weight.data[new_index] = state_dict["embeddings.word_embeddings.weight"][original_index] - state_dict["embeddings.word_embeddings.weight"] = word_embed.weight.data + word_embed_weight[new_index] = state_dict["rnabert.embeddings.word_embeddings.weight"][original_index] + predictions_bias[new_index] = state_dict["lm_head.predictions.bias"][original_index] + predictions_decoder_weight[new_index] = state_dict["lm_head.predictions.decoder.weight"][original_index] + state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight + state_dict["lm_head.predictions.bias"] = predictions_bias + state_dict["lm_head.predictions.decoder.weight"] = predictions_decoder_weight model.load_state_dict(state_dict) model.save_pretrained(output_path, safe_serialization=True) diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py index b9698589..c5a6bab7 100644 --- a/multimolecule/models/rnabert/modeling_rnabert.py +++ b/multimolecule/models/rnabert/modeling_rnabert.py @@ -1,21 +1,242 @@ import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union import torch -from torch import nn +from torch import Tensor, nn from transformers import PreTrainedModel -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling - +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) + +from ..modeling_utils import SequenceClassificationHead, TokenClassificationHead, gelu from .configuration_rnabert import RnaBertConfig +class RnaBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = RnaBertConfig + base_model_prefix = "rnabert" + supports_gradient_checkpointing = True + _no_split_modules = ["RnaBertLayer", "RnaBertEmbeddings"] + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class RnaBertModel(RnaBertPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = RnaBertEmbeddings(config) + self.encoder = RnaBertEncoder(config) + self.pooler = RnaBertPooler(config) if add_pooling_layer else None + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Tensor, + token_type_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + # attention_mask=attention_mask, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class RnaBertForMaskedLM(RnaBertPreTrainedModel): + def __init__(self, config: RnaBertConfig): + super().__init__(config) + self.rnabert = RnaBertModel(config) + self.lm_head = RnaBertLMHead(config) + + def forward( + self, + input_ids: Tensor, + token_type_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ): + outputs = self.rnabert( + input_ids, + token_type_ids, + attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + prediction_scores, prediction_scores_ss, seq_relationship_score = self.lm_head( + outputs.last_hidden_state, outputs.pooler_output + ) + + if not return_dict: + return (prediction_scores, prediction_scores_ss) + 
outputs[2:] + + return RnaBertMaskedLMOutput( + logits=prediction_scores, + logits_ss=prediction_scores_ss, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RnaBertForSequenceClassification(RnaBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.rnabert = RnaBertModel(config, add_pooling_layer=False) + self.classifier = SequenceClassificationHead(config) + + self.init_weights() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rnabert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return self.classifier(outputs, labels) + + +class RnaBertForTokenClassification(RnaBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.rnabert = RnaBertModel(config, add_pooling_layer=False) + self.classifier = TokenClassificationHead(config) + + self.init_weights() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rnabert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return self.classifier(outputs, labels) + + class RnaBertLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-12): + def __init__(self, hidden_size: int, eps: float = 1e-12): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) # weightのこと self.bias = nn.Parameter(torch.zeros(hidden_size)) # biasのこと self.variance_epsilon = eps - def forward(self, x): + def forward(self, x: Tensor): u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) @@ -23,15 +244,15 @@ def forward(self, x): class RnaBertEmbeddings(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, input_ids, token_type_ids=None): + def forward(self, input_ids: Tensor, token_type_ids: Optional[Tensor] = None): words_embeddings = self.word_embeddings(input_ids) if token_type_ids is None: @@ -52,13 +273,13 @@ def forward(self, input_ids, token_type_ids=None): class RnaBertLayer(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() self.attention = RnaBertAttention(config) self.intermediate = RnaBertIntermediate(config) self.output = RnaBertOutput(config) - def forward(self, hidden_states, attention_mask, output_attentions=False): + def forward(self, hidden_states: Tensor, attention_mask: Tensor, output_attentions: bool = False): self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions) attention_output, outputs = self_attention_outputs[0], self_attention_outputs[1:] intermediate_output = self.intermediate(attention_output) @@ -68,12 +289,12 @@ def forward(self, hidden_states, attention_mask, output_attentions=False): class RnaBertAttention(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() self.selfattn = RnaBertSelfAttention(config) self.output = RnaBertSelfOutput(config) - def forward(self, hidden_states, attention_mask, output_attentions=False): + def forward(self, hidden_states: Tensor, attention_mask: Tensor, output_attentions: bool = False): self_outputs = self.selfattn(hidden_states, attention_mask, output_attentions=output_attentions) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -81,22 +302,17 @@ def forward(self, hidden_states, attention_mask, output_attentions=False): class RnaBertSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.num_attention_heads = config.num_attention_heads 
- # num_attention_heads': 12 - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): + def transpose_for_scores(self, x: Tensor): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, @@ -104,7 +320,7 @@ def transpose_for_scores(self, x): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask, output_attentions=False): + def forward(self, hidden_states: Tensor, attention_mask: Tensor, output_attentions: bool = False): mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) @@ -114,11 +330,8 @@ def forward(self, hidden_states, attention_mask, output_attentions=False): value_layer = self.transpose_for_scores(mixed_value_layer) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) - attention_scores = attention_scores + attention_mask - - attention_probs = nn.Softmax(dim=-1)(attention_scores) - + attention_probs = attention_scores.softmax(-1) attention_probs = self.dropout(attention_probs) context_layer = torch.matmul(attention_probs, value_layer) @@ -131,48 +344,39 @@ def forward(self, hidden_states, attention_mask, output_attentions=False): class RnaBertSelfOutput(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: Tensor, input_tensor: Tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states -def gelu(x): - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) - - class RnaBertIntermediate(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - self.intermediate_act_fn = gelu - def forward(self, hidden_states): + def forward(self, hidden_states: Tensor): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class RnaBertOutput(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: Tensor, input_tensor: Tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) @@ -180,34 +384,35 @@ def forward(self, hidden_states, input_tensor): class RnaBertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config: 
RnaBertConfig): super().__init__() + self.config = config self.layer = nn.ModuleList([RnaBertLayer(config) for _ in range(config.num_hidden_layers)]) # self.layer = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) # for _ in range(config.num_hidden_layers)]) def forward( self, - hidden_states, - attention_mask, - output_attentions=False, - output_hidden_states=False, - return_dict=False, + hidden_states: Tensor, + attention_mask: Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None for layer in self.layer: if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore[operator] layer_outputs = layer(hidden_states, attention_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_self_attentions = all_self_attentions + (layer_outputs[1],) # type: ignore[operator] if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore[operator] if not return_dict: return tuple( @@ -227,179 +432,64 @@ def forward( class RnaBertPooler(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() - def forward(self, hidden_states): + def forward(self, hidden_states: Tensor): first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output -class RnaBertPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. 
- """ - - config_class = RnaBertConfig - base_model_prefix = "rnabert" - supports_gradient_checkpointing = True - _no_split_modules = ["RnaBertLayer", "RnaBertFoldTriangularSelfAttentionBlock", "RnaBertEmbeddings"] - - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -class RnaBertModel(RnaBertPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.embeddings = RnaBertEmbeddings(config) - self.encoder = RnaBertEncoder(config) - self.pooler = RnaBertPooler(config) - - def forward( - self, - input_ids, - token_type_ids=None, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - embedding_output = self.embeddings( - input_ids=input_ids, - token_type_ids=token_type_ids, - # attention_mask=attention_mask, - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - class RnaBertLMHead(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.predictions = MaskedWordPredictions(config, config.vocab_size) self.predictions_ss = MaskedWordPredictions(config, config.ss_vocab_size) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - def forward(self, sequence_output, pooled_output): + def forward(self, sequence_output: Tensor, pooled_output: Tensor): prediction_scores = self.predictions(sequence_output) prediction_scores_ss = self.predictions_ss(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, prediction_scores_ss, seq_relationship_score class MaskedWordPredictions(nn.Module): def __init__(self, config, vocab_size): super().__init__() - self.transform = RnaBertPredictionHeadTransform(config) - self.decoder = nn.Linear(in_features=config.hidden_size, out_features=vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size)) + # self.decoder.bias = self.bias - def 
forward(self, hidden_states): + def forward(self, hidden_states: Tensor): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) + self.bias - return hidden_states class RnaBertPredictionHeadTransform(nn.Module): - def __init__(self, config): + def __init__(self, config: RnaBertConfig): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.transform_act_fn = gelu - self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12) - def forward(self, hidden_states): + def forward(self, hidden_states: Tensor): hidden_states = self.dense(hidden_states) # hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states -class RnaBertForMaskedLM(nn.Module): - def __init__(self, config): - super().__init__() - self.bert = RnaBertModel(config) - self.lm_head = RnaBertLMHead(config) - - def forward( - self, - input_ids, - token_type_ids=None, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=False, - ): - outputs = self.bert( - input_ids, - token_type_ids, - attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - prediction_scores, prediction_scores_ss, seq_relationship_score = self.lm_head( - outputs.last_hidden_state, outputs.pooler_output - ) - return prediction_scores, prediction_scores_ss, outputs +@dataclass +class RnaBertMaskedLMOutput(MaskedLMOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + logits_ss: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None diff --git a/pyproject.toml b/pyproject.toml index 536d2e09..5edeb3ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,10 @@ classifiers = [ dynamic = [ "version", ] +dependencies = [ + "chanfig", + "transformers", +] [project.urls] documentation = "https://multimolecule.danling.org" homepage = "https://multimolecule.danling.org"
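
Usage sketch (not part of the patch): a minimal example of how the newly exported classes fit together, assuming the package is installed with these changes and that RnaBertConfig's defaults (vocab_size, ss_vocab_size, hidden sizes) are self-consistent. The model below is randomly initialised; after running convert_checkpoint.py, RnaBertForMaskedLM.from_pretrained("rnabert") would load the converted weights instead. return_dict=True is required because the LM head reads outputs.last_hidden_state and outputs.pooler_output.

import torch

from multimolecule.models import RnaBertConfig, RnaBertForMaskedLM

config = RnaBertConfig()                   # defaults assumed self-consistent
model = RnaBertForMaskedLM(config).eval()  # randomly initialised weights

# Dummy batch of token ids: shape (batch_size, seq_len), values in [0, config.vocab_size).
input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)

print(outputs.logits.shape)     # (2, 16, config.vocab_size)
print(outputs.logits_ss.shape)  # (2, 16, config.ss_vocab_size)

# The Auto* registrations added in rnabert/__init__.py make the same model reachable via:
#   from transformers import AutoModelForMaskedLM
#   model = AutoModelForMaskedLM.from_config(config)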