diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py
index d98468e2..3a81ea15 100644
--- a/multimolecule/models/rnabert/configuration_rnabert.py
+++ b/multimolecule/models/rnabert/configuration_rnabert.py
@@ -9,7 +9,7 @@ class RnaBertConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a
     RnaBert model according to the specified arguments, defining the model architecture. Instantiating a
     configuration with the defaults will yield a similar configuration to that of the RnaBert
-    [mana438/RNABERT](https://github.com/mana438/RNABERT/blob/master/RNA_bert_config.json) architecture.
+    [mana438/RNABERT](https://github.com/mana438/RNABERT) architecture.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -55,7 +55,8 @@ class RnaBertConfig(PretrainedConfig):
     >>> # Initializing a model from the configuration
     >>> model = RnaBertModel(configuration)
     >>> # Accessing the model configuration
     >>> configuration = model.config
-    ```"""
+    ```
+    """
 
     model_type = "rnabert"
diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py
index b9698589..2ae90100 100644
--- a/multimolecule/models/rnabert/modeling_rnabert.py
+++ b/multimolecule/models/rnabert/modeling_rnabert.py
@@ -1,21 +1,133 @@
 import math
+from typing import Optional
 
 import torch
-from torch import nn
+from torch import Tensor, nn
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
 from .configuration_rnabert import RnaBertConfig
 
 
+class RnaBertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = RnaBertConfig
+    base_model_prefix = "rnabert"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["RnaBertLayer", "RnaBertFoldTriangularSelfAttentionBlock", "RnaBertEmbeddings"]
+
+    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
+    def _init_weights(self, module: nn.Module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class RnaBertModel(RnaBertPreTrainedModel):
+    def __init__(self, config: RnaBertConfig):
+        super().__init__(config)
+        self.embeddings = RnaBertEmbeddings(config)
+        self.encoder = RnaBertEncoder(config)
+        self.pooler = RnaBertPooler(config)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        token_type_ids: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = False,
+    ):
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
+
+        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+        extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            # attention_mask=attention_mask,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class RnaBertForMaskedLM(nn.Module):
+    def __init__(self, config: RnaBertConfig):
+        super().__init__()
+        self.rnabert = RnaBertModel(config)
+        self.lm_head = RnaBertLMHead(config)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        token_type_ids: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = False,
+    ):
+        outputs = self.rnabert(
+            input_ids,
+            token_type_ids,
+            attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        prediction_scores, prediction_scores_ss, seq_relationship_score = self.lm_head(
+            outputs.last_hidden_state, outputs.pooler_output
+        )
+        return prediction_scores, prediction_scores_ss, outputs
+
+
 class RnaBertLayerNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-12):
+    def __init__(self, hidden_size: int, eps: float = 1e-12):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))  # weightのこと
         self.bias = nn.Parameter(torch.zeros(hidden_size))  # biasのこと
         self.variance_epsilon = eps
 
-    def forward(self, x):
+    def forward(self, x: Tensor):
         u = x.mean(-1, keepdim=True)
         s = (x - u).pow(2).mean(-1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
@@ -23,15 +135,15 @@ def forward(self, x):
 
 
 class RnaBertEmbeddings(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
         self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, input_ids, token_type_ids=None):
+    def forward(self, input_ids: Tensor, token_type_ids: Optional[Tensor] = None):
         words_embeddings = self.word_embeddings(input_ids)
 
         if token_type_ids is None:
@@ -52,13 +164,13 @@ def forward(self, input_ids, token_type_ids=None):
 
 
 class RnaBertLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
         self.attention = RnaBertAttention(config)
         self.intermediate = RnaBertIntermediate(config)
         self.output = RnaBertOutput(config)
 
-    def forward(self, hidden_states, attention_mask, output_attentions=False):
+    def forward(self, hidden_states: Tensor, attention_mask: Tensor, output_attentions: bool = False):
         self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
         attention_output, outputs = self_attention_outputs[0], self_attention_outputs[1:]
         intermediate_output = self.intermediate(attention_output)
@@ -68,12 +180,12 @@ def forward(self, hidden_states, attention_mask, output_attentions=False):
 
 
 class RnaBertAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
         self.selfattn = RnaBertSelfAttention(config)
         self.output = RnaBertSelfOutput(config)
 
-    def forward(self, hidden_states, attention_mask, output_attentions=False):
+    def forward(self, hidden_states: Tensor, attention_mask: Tensor, output_attentions: bool = False):
         self_outputs = self.selfattn(hidden_states, attention_mask, output_attentions=output_attentions)
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -81,7 +193,7 @@ def forward(self, hidden_states, attention_mask, output_attentions=False):
 
 
 class RnaBertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
 
         self.num_attention_heads = config.num_attention_heads
@@ -96,7 +208,7 @@ def __init__(self, config):
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: Tensor):
         new_x_shape = x.size()[:-1] + (
             self.num_attention_heads,
             self.attention_head_size,
@@ -104,7 +216,7 @@ def __init__(self, config):
         x = x.view(*new_x_shape)
         return x.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, attention_mask, output_attentions=False):
+    def forward(self, hidden_states: Tensor, attention_mask: Tensor, output_attentions: bool = False):
         mixed_query_layer = self.query(hidden_states)
         mixed_key_layer = self.key(hidden_states)
         mixed_value_layer = self.value(hidden_states)
@@ -131,39 +243,35 @@ def forward(self, hidden_states, attention_mask, output_attentions=False):
 
 
 class RnaBertSelfOutput(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: Tensor, input_tensor: Tensor):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
         return hidden_states
 
 
-def gelu(x):
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
 class RnaBertIntermediate(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
 
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
 
         self.intermediate_act_fn = gelu
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: Tensor):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
 
 
 class RnaBertOutput(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
 
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -172,7 +280,7 @@ def __init__(self, config):
 
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: Tensor, input_tensor: Tensor):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -180,7 +288,7 @@ def forward(self, hidden_states, input_tensor):
 
 
 class RnaBertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
         self.layer = nn.ModuleList([RnaBertLayer(config) for _ in range(config.num_hidden_layers)])
         # self.layer = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size)
@@ -188,26 +296,26 @@ def __init__(self, config):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=False,
+        hidden_states: Tensor,
+        attention_mask: Tensor,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = False,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
         for layer in self.layer:
             if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
+                all_hidden_states = all_hidden_states + (hidden_states,)  # type: ignore[operator]
 
             layer_outputs = layer(hidden_states, attention_mask, output_attentions)
             hidden_states = layer_outputs[0]
 
             if output_attentions:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)  # type: ignore[operator]
 
         if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
+            all_hidden_states = all_hidden_states + (hidden_states,)  # type: ignore[operator]
 
         if not return_dict:
             return tuple(
@@ -227,13 +335,13 @@ def forward(
 
 
 class RnaBertPooler(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
 
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: Tensor):
         first_token_tensor = hidden_states[:, 0]
 
         pooled_output = self.dense(first_token_tensor)
@@ -243,89 +351,8 @@ def forward(self, hidden_states):
         return pooled_output
 
 
-class RnaBertPreTrainedModel(PreTrainedModel):
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
-    config_class = RnaBertConfig
-    base_model_prefix = "rnabert"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["RnaBertLayer", "RnaBertFoldTriangularSelfAttentionBlock", "RnaBertEmbeddings"]
-
-    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
-    def _init_weights(self, module):
-        """Initialize the weights"""
-        if isinstance(module, nn.Linear):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-
-class RnaBertModel(RnaBertPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-        self.embeddings = RnaBertEmbeddings(config)
-        self.encoder = RnaBertEncoder(config)
-        self.pooler = RnaBertPooler(config)
-
-    def forward(
-        self,
-        input_ids,
-        token_type_ids=None,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-        if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
-
-        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-        extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            token_type_ids=token_type_ids,
-            # attention_mask=attention_mask,
-        )
-        encoder_outputs = self.encoder(
-            embedding_output,
-            attention_mask=extended_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-
-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        return BaseModelOutputWithPooling(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
-
-
 class RnaBertLMHead(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
 
         self.predictions = MaskedWordPredictions(config, config.vocab_size)
@@ -333,7 +360,7 @@ def __init__(self, config):
 
         self.seq_relationship = nn.Linear(config.hidden_size, 2)
 
-    def forward(self, sequence_output, pooled_output):
+    def forward(self, sequence_output: Tensor, pooled_output: Tensor):
         prediction_scores = self.predictions(sequence_output)
         prediction_scores_ss = self.predictions_ss(sequence_output)
 
@@ -351,7 +378,7 @@ def __init__(self, config, vocab_size):
         self.decoder = nn.Linear(in_features=config.hidden_size, out_features=vocab_size, bias=False)
         self.bias = nn.Parameter(torch.zeros(vocab_size))
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: Tensor):
         hidden_states = self.transform(hidden_states)
 
         hidden_states = self.decoder(hidden_states) + self.bias
@@ -359,7 +386,7 @@ def forward(self, hidden_states):
 
 
 class RnaBertPredictionHeadTransform(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: RnaBertConfig):
         super().__init__()
 
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -368,38 +395,12 @@ def __init__(self, config):
 
         self.LayerNorm = RnaBertLayerNorm(config.hidden_size, eps=1e-12)
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: Tensor):
         hidden_states = self.dense(hidden_states)
         # hidden_states = self.transform_act_fn(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
         return hidden_states
 
 
-class RnaBertForMaskedLM(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.bert = RnaBertModel(config)
-        self.lm_head = RnaBertLMHead(config)
-
-    def forward(
-        self,
-        input_ids,
-        token_type_ids=None,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=False,
-    ):
-        outputs = self.bert(
-            input_ids,
-            token_type_ids,
-            attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        prediction_scores, prediction_scores_ss, seq_relationship_score = self.lm_head(
-            outputs.last_hidden_state, outputs.pooler_output
-        )
-        return prediction_scores, prediction_scores_ss, outputs
+def gelu(x):
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
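# ---------------------------------------------------------------------------
# Illustrative usage of the refactored classes: a minimal sketch, not part of
# the patch above. It assumes RnaBertConfig() builds from its defaults (as in
# the configuration docstring) and imports directly from the two files touched
# by this diff; the dummy batch shape below is hypothetical.
# ---------------------------------------------------------------------------
import torch

from multimolecule.models.rnabert.configuration_rnabert import RnaBertConfig
from multimolecule.models.rnabert.modeling_rnabert import RnaBertForMaskedLM, RnaBertModel

config = RnaBertConfig()      # defaults approximate the mana438/RNABERT architecture
model = RnaBertModel(config)  # backbone: embeddings -> encoder -> pooler

input_ids = torch.randint(0, config.vocab_size, (1, 32))  # dummy tokenized batch
attention_mask = torch.ones_like(input_ids)

# With return_dict=False (the new default) the backbone returns a plain tuple.
sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask)

# The masked-LM wrapper now delegates to `self.rnabert` (renamed from `self.bert`)
# and returns per-token vocabulary scores plus secondary-structure scores.
# return_dict=True is passed so the LM head can read `outputs.last_hidden_state`.
mlm = RnaBertForMaskedLM(config)
prediction_scores, prediction_scores_ss, outputs = mlm(
    input_ids, attention_mask=attention_mask, return_dict=True
)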