From 5513302ada6c7053fe457563fc824b59c1fd6797 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Tue, 2 Apr 2024 22:31:51 +0800 Subject: [PATCH] add support of RnaMsm Signed-off-by: Zhiyuan Chen --- multimolecule/models/__init__.py | 14 +- multimolecule/models/rnamsm/__init__.py | 36 + .../models/rnamsm/configuration_rnamsm.py | 106 ++ .../models/rnamsm/convert_checkpoint.py | 106 ++ .../models/rnamsm/modeling_rnamsm.py | 1230 +++++++++++++++++ tox.ini | 1 + 6 files changed, 1492 insertions(+), 1 deletion(-) create mode 100644 multimolecule/models/rnamsm/__init__.py create mode 100644 multimolecule/models/rnamsm/configuration_rnamsm.py create mode 100644 multimolecule/models/rnamsm/convert_checkpoint.py create mode 100644 multimolecule/models/rnamsm/modeling_rnamsm.py diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py index 1f07452e..48c90187 100644 --- a/multimolecule/models/__init__.py +++ b/multimolecule/models/__init__.py @@ -1,10 +1,17 @@ +from ..tokenizers.rna import RnaTokenizer from .rnabert import ( RnaBertConfig, RnaBertForMaskedLM, RnaBertForSequenceClassification, RnaBertForTokenClassification, RnaBertModel, - RnaTokenizer, +) +from .rnamsm import ( + RnaMsmConfig, + RnaMsmForMaskedLM, + RnaMsmForSequenceClassification, + RnaMsmForTokenClassification, + RnaMsmModel, ) __all__ = [ @@ -13,5 +20,10 @@ "RnaBertForMaskedLM", "RnaBertForSequenceClassification", "RnaBertForTokenClassification", + "RnaMsmConfig", + "RnaMsmModel", + "RnaMsmForMaskedLM", + "RnaMsmForSequenceClassification", + "RnaMsmForTokenClassification", "RnaTokenizer", ] diff --git a/multimolecule/models/rnamsm/__init__.py b/multimolecule/models/rnamsm/__init__.py new file mode 100644 index 00000000..6405ac1d --- /dev/null +++ b/multimolecule/models/rnamsm/__init__.py @@ -0,0 +1,36 @@ +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForMaskedLM, + AutoModelForSequenceClassification, + AutoModelForTokenClassification, + AutoModelWithLMHead, + AutoTokenizer, +) + +from multimolecule.tokenizers.rna import RnaTokenizer + +from .configuration_rnamsm import RnaMsmConfig +from .modeling_rnamsm import ( + RnaMsmForMaskedLM, + RnaMsmForSequenceClassification, + RnaMsmForTokenClassification, + RnaMsmModel, +) + +__all__ = [ + "RnaMsmConfig", + "RnaMsmModel", + "RnaTokenizer", + "RnaMsmForMaskedLM", + "RnaMsmForSequenceClassification", + "RnaMsmForTokenClassification", +] + +AutoConfig.register("rnamsm", RnaMsmConfig) +AutoModel.register(RnaMsmConfig, RnaMsmModel) +AutoModelForMaskedLM.register(RnaMsmConfig, RnaMsmForMaskedLM) +AutoModelForSequenceClassification.register(RnaMsmConfig, RnaMsmForSequenceClassification) +AutoModelForTokenClassification.register(RnaMsmConfig, RnaMsmForTokenClassification) +AutoModelWithLMHead.register(RnaMsmConfig, RnaMsmForTokenClassification) +AutoTokenizer.register(RnaMsmConfig, RnaTokenizer) diff --git a/multimolecule/models/rnamsm/configuration_rnamsm.py b/multimolecule/models/rnamsm/configuration_rnamsm.py new file mode 100644 index 00000000..5351949b --- /dev/null +++ b/multimolecule/models/rnamsm/configuration_rnamsm.py @@ -0,0 +1,106 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class RnaMsmConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RnaMsmModel`]. It is used to instantiate a + RnaMsm model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RnaMsm + [yikunpku/RNA-MSM](https://github.com/yikunpku/RNA-MSM) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*): + Vocabulary size of the RnaMsm model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RnaMsmModel`]. + mask_token_id (`int`, *optional*): + The index of the mask token in the vocabulary. This must be included in the config because of the + "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens. + pad_token_id (`int`, *optional*): + The index of the padding token in the vocabulary. This must be included in the config because certain parts + of the RnaMsm code use this instead of the attention mask. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 1026): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + + Examples: + + ```python + >>> from multimolecule import RnaMsmModel, RnaMsmConfig + + >>> # Initializing a RnaMsm style configuration >>> configuration = RnaMsmConfig() + + >>> # Initializing a model from the configuration >>> model = RnaMsmModel(configuration) + + >>> # Accessing the model configuration >>> configuration = model.config + ``` + """ + + model_type = "rnamsm" + + def __init__( + self, + vocab_size=25, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1024, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + bos_token_id=1, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + max_tokens_per_msa=2**14, + attention_type="standard", + embed_positions_msa=True, + attention_bias=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + self.max_tokens_per_msa = max_tokens_per_msa + self.attention_type = attention_type + self.embed_positions_msa = embed_positions_msa + self.attention_bias = attention_bias diff --git a/multimolecule/models/rnamsm/convert_checkpoint.py b/multimolecule/models/rnamsm/convert_checkpoint.py new file mode 100644 index 00000000..1a32a50a --- /dev/null +++ b/multimolecule/models/rnamsm/convert_checkpoint.py @@ -0,0 +1,106 @@ +import os +from typing import Optional + +import chanfig +import torch +from torch import nn + +from multimolecule.models import RnaMsmConfig as Config +from multimolecule.models import RnaMsmForMaskedLM as Model +from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list + +try: + from huggingface_hub import HfApi +except ImportError: + HfApi = None + + +torch.manual_seed(1013) + +CONFIG = { + "architectures": ["RnaMsmModel"], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "intermediate_size": 3072, + "max_position_embeddings": 1024, + "num_attention_heads": 12, + "num_hidden_layers": 10, + "vocab_size": 25, + "pad_token_id": 0, + "embed_positions_msa": True, +} + +original_vocab_list = ["", "", "", "", "A", "G", "C", "U", "X", "N", "-", ""] +vocab_list = get_vocab_list() + + +def _convert_checkpoint(config, original_state_dict): + state_dict = {} + for key, value in original_state_dict.items(): + key = key.replace("layers", "rnamsm.encoder.layer") + key = key.replace("msa_position_embedding", "rnamsm.embeddings.msa_embeddings") + key = key.replace("embed_tokens", "rnamsm.embeddings.word_embeddings") + key = key.replace("embed_positions", "rnamsm.embeddings.position_embeddings") + key = key.replace("emb_layer_norm_before", "rnamsm.embeddings.layer_norm") + key = key.replace("emb_layer_norm_after", "rnamsm.encoder.layer_norm") + state_dict[key] = value + + word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + predictions_bias = torch.zeros(config.vocab_size) + # nn.init.normal_(pos_embed.weight, std=0.02) + for original_index, original_token in enumerate(original_vocab_list): + new_index = vocab_list.index(original_token) + word_embed.weight.data[new_index] = state_dict["rnamsm.embeddings.word_embeddings.weight"][original_index] + predictions_bias[new_index] = state_dict["lm_head.bias"][original_index] + state_dict["rnamsm.embeddings.word_embeddings.weight"] = word_embed.weight.data + state_dict["lm_head.weight"] = word_embed.weight.data + state_dict["lm_head.bias"] = predictions_bias + return state_dict + + +def convert_checkpoint(convert_config): + config = Config.from_dict(chanfig.FlatDict(CONFIG)) + config.vocab_size = len(vocab_list) + + model = Model(config) + + ckpt = torch.load(convert_config.checkpoint_path, map_location=torch.device("cpu")) + state_dict = _convert_checkpoint(config, ckpt) + + model.load_state_dict(state_dict) + model.save_pretrained(convert_config.output_path, safe_serialization=True) + model.save_pretrained(convert_config.output_path, safe_serialization=False) + chanfig.NestedDict(get_special_tokens_map()).json( + os.path.join(convert_config.output_path, "special_tokens_map.json") + ) + chanfig.NestedDict(get_tokenizer_config()).json(os.path.join(convert_config.output_path, "tokenizer_config.json")) + + if convert_config.push_to_hub: + if HfApi is None: + raise ImportError("Please install huggingface_hub to push to the hub.") + api = HfApi() + api.create_repo( + convert_config.repo_id, + token=convert_config.token, + exist_ok=True, + ) + api.upload_folder( + repo_id=convert_config.repo_id, folder_path=convert_config.output_path, token=convert_config.token + ) + + +@chanfig.configclass +class ConvertConfig: + checkpoint_path: str + output_path: str = Config.model_type + push_to_hub: bool = False + repo_id: str = "ZhiyuanChen/" + output_path + token: Optional[str] = None + + +if __name__ == "__main__": + config = ConvertConfig() + config.parse() # type: ignore[attr-defined] + convert_checkpoint(config) diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py new file mode 100644 index 00000000..d028aba5 --- /dev/null +++ b/multimolecule/models/rnamsm/modeling_rnamsm.py @@ -0,0 +1,1230 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from functools import partial, wraps +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +from chanfig import ConfigRegistry +from torch import Tensor, nn +from torch.nn import functional as F +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ModelOutput, SequenceClassifierOutput, TokenClassifierOutput + +from ..modeling_utils import SequenceClassificationHead, TokenClassificationHead +from .configuration_rnamsm import RnaMsmConfig + + +class RnaMsmPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RnaMsmConfig + base_model_prefix = "rnamsm" + supports_gradient_checkpointing = True + _no_split_modules = ["RnaMsmLayer", "RnaMsmAxialLayer", "RnaMsmPkmLayer", "RnaMsmEmbeddings"] + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm) and module.elementwise_affine: + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class RnaMsmModel(RnaMsmPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + self.pad_token_id = config.pad_token_id + self.embeddings = RnaMsmEmbeddings(config) + self.encoder = RnaMsmEncoder(config) + self.pooler = RnaMsmPooler(config) if add_pooling_layer else None + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Tensor, + attention_mask: Optional[Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ) -> Union[Tuple[torch.Tensor], RnaMsmModelOutputWithPooling]: + if attention_mask is None: + attention_mask = input_ids.eq(self.pad_token_id) # B, R, C + if not attention_mask.any(): + attention_mask = None + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return RnaMsmModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + col_attentions=encoder_outputs.col_attentions, + row_attentions=encoder_outputs.row_attentions, + ) + + +class RnaMsmForMaskedLM(RnaMsmPreTrainedModel): + def __init__(self, config: RnaMsmConfig): + super().__init__(config) + self.rnamsm = RnaMsmModel(config) + self.lm_head = RnaMsmLMHead(config, weight=self.rnamsm.embeddings.word_embeddings.weight) + self.contact_head = RnaMsmContactHead(config) + + def forward( + self, + input_ids: Tensor, + attention_mask: Optional[Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ): + outputs = self.rnamsm( + input_ids, + attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + if not return_dict: + return (prediction_scores,) + outputs[2:] + + return RnaMsmMaskedLMOutput( + logits=prediction_scores, + hidden_states=outputs.hidden_states, + row_attentions=outputs.row_attentions, + col_attentions=outputs.col_attentions, + ) + + +class RnaMsmForSequenceClassification(RnaMsmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.rnamsm = RnaMsmModel(config, add_pooling_layer=False) + self.classifier = SequenceClassificationHead(config) + + self.init_weights() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rnamsm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return self.classifier(outputs, labels) + + +class RnaMsmForTokenClassification(RnaMsmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.rnamsm = RnaMsmModel(config, add_pooling_layer=False) + self.classifier = TokenClassificationHead(config) + + self.init_weights() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.rnamsm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return self.classifier(outputs, labels) + + +class RnaMsmEncoder(nn.Module): + def __init__(self, config: RnaMsmConfig): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RnaMsmAxialLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + output_contacts: bool = False, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor], RnaMsmModelOutput]: + if output_contacts: + output_attentions = True + all_hidden_states = () if output_hidden_states else None + all_col_attentions = () if output_attentions else None + all_row_attentions = () if output_attentions else None + + # repr_layers = set(repr_layers) + # hidden_representations = {} + # if 0 in repr_layers: + # hidden_representations[0] = hidden_states + + # B x R x C x D -> R x C x B x D + hidden_states = hidden_states.permute(1, 2, 0, 3) + + for layer_module in self.layer: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore + layer_outputs = layer_module( + hidden_states, + self_attention_padding_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions: + # H x C x B x R x R -> B x H x C x R x R + all_col_attentions = all_col_attentions + (layer_outputs[1].permute(2, 0, 1, 3, 4),) # type: ignore + # H x B x C x C -> B x H x C x C + all_row_attentions = all_row_attentions + (layer_outputs[2].permute(1, 0, 2, 3),) # type: ignore + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.permute(2, 0, 1, 3) # R x C x B x D -> B x R x C x D + + # last hidden representation should have layer norm applied + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + all_hidden_states, + all_col_attentions, + all_row_attentions, + ] + if v is not None + ) + return RnaMsmModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + col_attentions=all_col_attentions, + row_attentions=all_row_attentions, + ) + + +class RnaMsmAxialLayer(nn.Module): + """Implements an Axial MSA Transformer block.""" + + def __init__(self, config) -> None: + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + row_self_attention = RowSelfAttention(config) + column_self_attention = ColumnSelfAttention(config) + feed_forward_layer = FeedForwardNetwork(config) + + self.row_self_attention = NormalizedResidualBlock(config, row_self_attention) + self.column_self_attention = NormalizedResidualBlock(config, column_self_attention) + self.feed_forward_layer = NormalizedResidualBlock(config, feed_forward_layer) + + def forward( + self, + hidden_states: Tensor, + self_attention_mask: Optional[Tensor] = None, + self_attention_padding_mask: Optional[Tensor] = None, + output_attentions: bool = False, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. + """ + row_attention_outputs = self.row_self_attention( + hidden_states, + self_attention_mask=self_attention_mask, + self_attention_padding_mask=self_attention_padding_mask, + output_attentions=output_attentions, + ) + row_attention_output, row_outputs = row_attention_outputs[0], row_attention_outputs[1:] + col_attention_outputs = self.column_self_attention( + row_attention_output, + self_attention_mask=self_attention_mask, + self_attention_padding_mask=self_attention_padding_mask, + output_attentions=output_attentions, + ) + col_attention_output, col_outputs = col_attention_outputs[0], col_attention_outputs[1:] + context_layer = self.feed_forward_layer(col_attention_output) + + outputs = (context_layer,) + col_outputs + row_outputs + return outputs + + +class RnaMsmLayer(nn.Module): + """Transformer layer block.""" + + def __init__(self, config: RnaMsmConfig): + super().__init__() + self.dropout = nn.Dropout(config.dropout) + self.self_attention = attention_registry.build(config) + self.self_attention_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn = FeedForwardNetwork(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states, + self_attention_mask=None, + self_attention_padding_mask=None, + output_attentions: bool = False, + ): + residual = hidden_states + hidden_states = self.self_attention_layer_norm(hidden_states) + hidden_states, attention_probs = self.self_attention( + hidden_states, + key_padding_mask=self_attention_padding_mask, + output_attentions=output_attentions, + attention_mask=self_attention_mask, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.ffn(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, attention_probs + + +class RnaMsmPkmLayer(nn.Module): + """Transformer layer block.""" + + def __init__(self, config: RnaMsmConfig): + from product_key_memory import PKM + + super().__init__() + self.self_attention = attention_registry.build(config) + self.self_attention_layer_norm = nn.LayerNorm(config.hidden_size) + + self.pkm = PKM( + config.hidden_size, + config.pkm_attention_heads, + config.num_product_keys, + config.pkm_topk, + config.pkm_dim_head, + ) + + self.final_layer_norm = nn.LayerNorm(self.hidden_size) + + def forward( + self, + hidden_states, + self_attention_mask=None, + self_attention_padding_mask=None, + output_attentions: bool = False, + ): + residual = hidden_states + hidden_states = self.self_attention_layer_norm(hidden_states) + hidden_states, attention_probs = self.self_attention( + query=hidden_states, + key=hidden_states, + value=hidden_states, + key_padding_mask=self_attention_padding_mask, + output_attentions=output_attentions, + attention_mask=self_attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.pkm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, attention_probs + + +class RnaMsmLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + Padding ids are ignored by either offsetting based on padding_idx + or by setting padding_idx to None and ensuring that the appropriate + position ids are passed to the forward function. + """ + + def __init__(self, num_embeddings: int, *args, **kwargs): + num_embeddings += 2 + super().__init__(num_embeddings, *args, **kwargs) + self.max_positions = num_embeddings + + def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None): + """Input is expected to be of size [bsz x seqlen].""" + if attention_mask is None: + attention_mask = input_ids.ne(self.padding_idx).int() + # This is a bug in the original implementation + positions = (torch.cumsum(attention_mask, dim=1, dtype=attention_mask.dtype) * attention_mask).long() + 1 + return F.embedding( + positions, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + +class RowSelfAttention(nn.Module): + """Compute self-attention over rows of a 2D input.""" + + def __init__(self, config: RnaMsmConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.scaling = self.attention_head_size**-0.5 + self.max_tokens_per_msa = config.max_tokens_per_msa + self.attention_shape = "hnij" + + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def align_scaling(self, q): + num_rows = q.size(0) + return self.scaling / math.sqrt(num_rows) + + def compute_attention_weights( + self, + hidden_states, + scaling: float, + self_attention_mask=None, + self_attention_padding_mask=None, + ): + num_rows, num_cols, batch_size, _ = hidden_states.size() + q = self.q_proj(hidden_states).view( + num_rows, num_cols, batch_size, self.num_attention_heads, self.attention_head_size + ) + k = self.k_proj(hidden_states).view( + num_rows, num_cols, batch_size, self.num_attention_heads, self.attention_head_size + ) + q *= scaling + if self_attention_padding_mask is not None: + # Zero out any padded aligned positions - this is important since + # we take a sum across the alignment axis. + q *= 1 - self_attention_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q) + # q *= 1 - self_attention_padding_mask.permute(3, 4, 0, 1, 2).to(q) + + attention_scores = torch.einsum(f"rinhd,rjnhd->{self.attention_shape}", q, k) + + if self_attention_mask is not None: + raise NotImplementedError + # Mask Size: [B x R x C], Weights Size: [H x B x C x C] + + if self_attention_padding_mask is not None: + attention_scores = attention_scores.masked_fill( + self_attention_padding_mask[:, 0].unsqueeze(0).unsqueeze(2), -10000 + ) + + return attention_scores + + def compute_attention_update( + self, + hidden_states, + attention_probs, + ): + num_rows, num_cols, batch_size, hidden_size = hidden_states.size() + v = self.v_proj(hidden_states).view( + num_rows, num_cols, batch_size, self.num_attention_heads, self.attention_head_size + ) + context_layer = torch.einsum(f"{self.attention_shape},rjnhd->rinhd", attention_probs, v) + context_layer = context_layer.reshape(num_rows, num_cols, batch_size, hidden_size) + output = self.out_proj(context_layer) + return output + + def forward( + self, + hidden_states, + self_attention_mask=None, + self_attention_padding_mask=None, + output_attentions: bool = False, + ): + num_rows, num_cols, _, _ = hidden_states.size() + if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled(): + return self._batched_forward(hidden_states, self_attention_mask, self_attention_padding_mask) + scaling = self.align_scaling(hidden_states) + attention_scores = self.compute_attention_weights( + hidden_states, scaling, self_attention_mask, self_attention_padding_mask + ) + attention_probs = attention_scores.softmax(-1) + attention_probs = self.dropout(attention_probs) + context_layer = self.compute_attention_update(hidden_states, attention_probs) + outputs = (context_layer, attention_probs) if output_attentions else (context_layer, None) + return outputs + + def _batched_forward( + self, + hidden_states, + self_attention_mask=None, + self_attention_padding_mask=None, + output_attentions: bool = False, + ): + num_rows, num_cols, _, _ = hidden_states.size() + max_rows = max(1, self.max_tokens_per_msa // num_cols) + scaling = self.align_scaling(hidden_states) + attention_scores = 0 + for start in range(0, num_rows, max_rows): + attention_scores += self.compute_attention_weights( + hidden_states[start : start + max_rows], + scaling, + self_attention_mask=self_attention_mask, + self_attention_padding_mask=( + self_attention_padding_mask[:, start : start + max_rows] + if self_attention_padding_mask is not None + else None + ), + ) + attention_probs = attention_scores.softmax(-1) # type: ignore[attr-defined] + attention_probs = self.dropout(attention_probs) + context_layer = torch.cat( + [ + self.compute_attention_update(hidden_states[start : start + max_rows], attention_probs) + for start in range(0, num_rows, max_rows) + ], + 0, + ) + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +class ColumnSelfAttention(nn.Module): + """Compute self-attention over columns of a 2D input.""" + + def __init__(self, config: RnaMsmConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.scaling = self.attention_head_size**-0.5 + self.max_tokens_per_msa = config.max_tokens_per_msa + + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def compute_attention_update( + self, + hidden_states, + self_attention_mask=None, + self_attention_padding_mask=None, + output_attentions: bool = False, + ): + num_rows, num_cols, batch_size, hidden_size = hidden_states.size() + if num_rows == 1: + # if there is only 1 position, this is equivalent and doesn't break with + # padding + attention_probs = torch.ones( + self.num_attention_heads, + num_cols, + batch_size, + num_rows, + num_rows, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + context_layer = self.out_proj(self.v_proj(hidden_states)) + else: + q = self.q_proj(hidden_states).view( + num_rows, num_cols, batch_size, self.num_attention_heads, self.attention_head_size + ) + k = self.k_proj(hidden_states).view( + num_rows, num_cols, batch_size, self.num_attention_heads, self.attention_head_size + ) + v = self.v_proj(hidden_states).view( + num_rows, num_cols, batch_size, self.num_attention_heads, self.attention_head_size + ) + q *= self.scaling + + attention_scores = torch.einsum("icnhd,jcnhd->hcnij", q, k) + + if self_attention_mask is not None: + raise NotImplementedError + if self_attention_padding_mask is not None: + attention_scores = attention_scores.masked_fill( + self_attention_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3), + -10000, + ) + + attention_probs = attention_scores.softmax(-1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.einsum("hcnij,jcnhd->icnhd", attention_probs, v) + context_layer = context_layer.reshape(num_rows, num_cols, batch_size, hidden_size) + context_layer = self.out_proj(context_layer) + outputs = (context_layer, attention_probs) if output_attentions else (context_layer, None) + return outputs + + def forward( + self, + hidden_states, + self_attention_mask=None, + self_attention_padding_mask=None, + output_attentions: bool = False, + ): + num_rows, num_cols, _, _ = hidden_states.size() + # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled(): + if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled(): + return self._batched_forward( + hidden_states, + self_attention_mask, + self_attention_padding_mask, + ) + else: + return self.compute_attention_update( + hidden_states, self_attention_mask, self_attention_padding_mask, output_attentions + ) + + def _batched_forward( + self, + hidden_states, + self_attention_mask=None, + self_attention_padding_mask=None, + output_attentions: bool = False, + ): + num_rows, num_cols, _, _ = hidden_states.size() + max_cols = max(1, self.max_tokens_per_msa // num_rows) + contexts, attentions = [], [] + for start in range(0, num_cols, max_cols): + output, attention = self( + hidden_states[:, start : start + max_cols], + self_attention_mask=self_attention_mask, + self_attention_padding_mask=( + self_attention_padding_mask[:, :, start : start + max_cols] + if self_attention_padding_mask is not None + else None + ), + ) + contexts.append(output) + attentions.append(attention) + context_layer = torch.cat(contexts, 1) + attention_probs = torch.cat(attentions, 1) + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +attention_registry = ConfigRegistry(key="attention_type") + + +@attention_registry.register("standard") +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__(self, config: RnaMsmConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.scaling = self.attention_head_size**-0.5 + self.dropout_prob = config.attention_probs_dropout_prob + + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, config.attention_bias) + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, config.attention_bias) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, config.attention_bias) + + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, config.attention_bias) + + self.reset_parameters() + self.enable_torch_version = hasattr(F, "multi_head_attention_forward") + if self.enable_torch_version: + self._attention_fn = partial( + F.multi_head_attention_forward, # type: ignore + embed_dim_to_check=self.hidden_size, + num_heads=self.num_attention_heads, + in_proj_weight=torch.empty([0]), + bias_k=None, + bias_v=None, + add_zero_attention=False, + dropout_p=self.dropout_prob, + use_separate_proj_weight=True, + ) + + def attention_fn( + self, + query, + key, + value, + key_padding_mask: Optional[Tensor] = None, + output_attentions: bool = False, + attention_mask: Optional[Tensor] = None, + ): + return self._attention_fn( + query, + key, + value, + in_proj_bias=torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + attention_mask=attention_mask, + key_padding_mask=key_padding_mask, + training=self.training, + need_weights=output_attentions, + out_proj_weight=self.out_proj.weight, + out_proj_bias=self.out_proj.bias, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + def reset_parameters(self): + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + + def forward( + self, + hidden_states: Tensor, + key_padding_mask: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + attention_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + pretrained and values before the attention softmax. + """ + + tgt_len, bsz, hidden_size = hidden_states.size() + assert hidden_size == self.hidden_size + + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + # and not output_attentions + if self.enable_torch_version and not torch.jit.is_scripting(): + return self.attention_fn( + query=hidden_states, + key=hidden_states, + value=hidden_states, + key_padding_mask=key_padding_mask, + output_attentions=output_attentions, + attention_mask=attention_mask, + ) + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + q *= self.scaling + + q = q.reshape(tgt_len, bsz * self.num_attention_heads, self.attention_head_size).transpose(0, 1) + k = k.reshape(-1, bsz * self.num_attention_heads, self.attention_head_size).transpose(0, 1) + v = v.reshape(-1, bsz * self.num_attention_heads, self.attention_head_size).transpose(0, 1) + + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + attention_scores = torch.bmm(q, k.transpose(1, 2)) + + assert list(attention_scores.size()) == [bsz * self.num_attention_heads, tgt_len, src_len] + + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(0) + if self.onnx_trace: + attention_mask = attention_mask.repeat(attention_scores.size(0), 1, 1) + attention_scores += attention_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attention_scores = attention_scores.view(bsz, self.num_attention_heads, tgt_len, src_len) + attention_scores = attention_scores.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attention_scores = attention_scores.view(bsz * self.num_attention_heads, tgt_len, src_len) + + # attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32) + # attention_probs = attention_probs.type_as(attention_scores) + attention_probs = attention_scores.softmax(-1) + attention_probs = F.dropout( + attention_probs.type_as(attention_scores), + p=self.dropout_prob, + training=self.training, + ) + + context_layer = torch.bmm(attention_probs, v) + assert list(context_layer.size()) == [bsz * self.num_attention_heads, tgt_len, self.attention_head_size] + context_layer = context_layer.transpose(0, 1).reshape(tgt_len, bsz, hidden_size) + context_layer = self.out_proj(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +@attention_registry.register("performer") +class PerformerAttention(MultiheadAttention): + def __init__(self, config: RnaMsmConfig): + from performer_pytorch import FastAttention + + super().__init__(config) + self._attention_fn = FastAttention(dim_heads=self.attention_head_size, nb_features=config.num_features) + + def attention_fn(self, query, key, value): + return self._attention_fn(query, key, value) + + def forward( + self, + hidden_states: Tensor, + key_padding_mask: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + from einops import rearrange + + q = self.q_proj(hidden_states) # [T x B x D] + k = self.k_proj(hidden_states) # [...] + v = self.v_proj(hidden_states) # [...] + + q, k, v = (rearrange(t, "t b (h d) -> b h t d", h=self.num_attention_heads) for t in (q, k, v)) + + if key_padding_mask is not None: + mask = key_padding_mask[:, None, :, None] + v.masked_fill_(mask, 0) + if attention_mask is not None: + raise NotImplementedError + + attention_probs = self.attention_fn(q, k, v) + context_layer = rearrange(attention_probs, "b h t d -> t b (h d)") + context_layer = self.out_proj(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +class NormalizedResidualBlock(nn.Module): + def __init__( + self, + config: RnaMsmConfig, + layer: nn.Module, + ): + super().__init__() + self.layer = layer + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward(self, hidden_states, *args, **kwargs): + residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + outputs = self.layer(hidden_states, *args, **kwargs) + if isinstance(outputs, tuple): + hidden_states, *out = outputs + else: + hidden_states = outputs + out = None + + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + + if out is not None: + return (hidden_states,) + tuple(out) + else: + return hidden_states + + +class FeedForwardNetwork(nn.Module): + def __init__(self, config: RnaMsmConfig): + super().__init__() + self.activation = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states): + hidden_states = self.activation(self.fc1(hidden_states)) + hidden_states = self.dropout(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class RnaMsmPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class RnaMsmLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, config, weight): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.weight = weight + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.layer_norm(hidden_states) + # project back to size of vocabulary with bias + hidden_states = F.linear(hidden_states, self.weight) + self.bias + return hidden_states + + +class RnaMsmContactHead(nn.Module): + """Performs symmetrization, apc, and computes a logistic regression on the output + features + """ + + def __init__(self, config): + super().__init__() + self.in_features = config.num_attention_heads * config.num_hidden_layers + self.bos_token_id = config.bos_token_id + self.eos_token_id = config.eos_token_id + self.regression = nn.Linear(self.in_features, 1) + self.activation = nn.Sigmoid() + + def forward(self, input_ids, row_attentions): + # remove cls token attentions + if self.bos_token_id: + row_attentions = row_attentions[..., 1:, 1:] + # remove eos token attentions + if self.eos_token_id: + eos_mask = input_ids.ne(self.eos_token_id).to(row_attentions) + eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) + row_attentions = row_attentions * eos_mask[:, None, None, :, :] + row_attentions = row_attentions[..., :-1, :-1] + batch_size, layers, heads, seqlen, _ = row_attentions.size() + row_attentions = row_attentions.view(batch_size, layers * heads, seqlen, seqlen) + + # features: B x C x T x T + row_attentions = row_attentions.to( + next(self.parameters()) + ) # row_attentions always float32, may need to convert to float16 + row_attentions = apc(symmetrize(row_attentions)) + row_attentions = row_attentions.permute(0, 2, 3, 1) + return self.activation(self.regression(row_attentions).squeeze(3)) + + +class RnaMsmEmbeddings(nn.Module): + def __init__(self, config: RnaMsmConfig): + super().__init__() + self.max_position_embeddings = config.max_position_embeddings + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = RnaMsmLearnedPositionalEmbedding( + self.max_position_embeddings, config.hidden_size, padding_idx=config.pad_token_id + ) + if config.embed_positions_msa: + self.msa_embeddings = nn.Parameter( + 0.01 * torch.randn(1, self.max_position_embeddings, 1, 1), requires_grad=True + ) + else: + self.register_parameter("msa_embeddings", None) # type: ignore + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None): + assert input_ids.ndim == 3 + batch_size, num_alignments, seqlen = input_ids.size() + words_embeddings = self.word_embeddings(input_ids.long()) + # words_embeddings = self.word_embeddings(tokens) + position_embeddings = self.position_embeddings(input_ids.view(batch_size * num_alignments, seqlen)).view( + words_embeddings.size() + ) + msa_embeddings = 0 + if self.msa_embeddings is not None: + if input_ids.size(1) > self.max_position_embeddings: + raise RuntimeError( + "Using model with MSA position embedding trained on maximum MSA " + f"depth of {self.max_position_embeddings}, but received {position_embeddings.size(1)} alignments." + ) + msa_embeddings += self.msa_embeddings[:, :num_alignments] + + embeddings = words_embeddings + position_embeddings + msa_embeddings + embeddings = self.layer_norm(embeddings) + + embeddings = self.dropout(embeddings) + + if attention_mask is not None: + embeddings = embeddings * (1 - attention_mask.unsqueeze(-1).type_as(embeddings)) + + return embeddings + + +@dataclass +class RnaMsmMaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or + when`config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class RnaMsmModelOutputWithPooling(ModelOutput): + """ + Base class for axial model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + col_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or + when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + row_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or + when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class RnaMsmModelOutput(ModelOutput): + """ + Base class for axial model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + col_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or + when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + row_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or + when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +def symmetrize(x): + "Make layer symmetric in final two dimensions, used for contact prediction." + return x + x.transpose(-1, -2) + + +def coerce_numpy(func: Callable) -> Callable: + @wraps(func) + def make_torch_args(*args, **kwargs): + is_numpy = False + update_args = [] + for arg in args: + if isinstance(arg, np.ndarray): + arg = torch.from_numpy(arg) + is_numpy = True + update_args.append(arg) + update_kwargs = {} + for kw, arg in kwargs.items(): + if isinstance(args, np.ndarray): + arg = torch.from_numpy(arg) + is_numpy = True + update_kwargs[kw] = arg + + output = func(*update_args, **update_kwargs) + + if is_numpy: + output = recursive_make_numpy(output) + + return output + + return make_torch_args + + +def recursive_make_numpy(item): + if isinstance(item, torch.Tensor): + return item.detach().cpu().numpy() + elif isinstance(item, (tuple, list)): + return type(item)(recursive_make_numpy(el) for el in item) + elif isinstance(item, dict): + return {kw: recursive_make_numpy(arg) for kw, arg in item.items()} + else: + return item + + +@coerce_numpy +def apc(x): + "Perform average product correct, used for contact prediction." + a1 = x.sum(-1, keepdims=True) + a2 = x.sum(-2, keepdims=True) + a12 = x.sum((-1, -2), keepdims=True) + + avg = a1 * a2 + avg.div_(a12) # in-place to reduce memory + normalized = x - avg + return normalized diff --git a/tox.ini b/tox.ini index fb02b3fb..1e2e6fb6 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,6 @@ [flake8] max-line-length = 120 +ignore = E203 [pycodestyle] count = True