Commit
reorganise rnabert
ZhiyuanChen committed Apr 1, 2024
1 parent 7adf925 commit d4c9556
Showing 7 changed files with 636 additions and 248 deletions.
18 changes: 16 additions & 2 deletions multimolecule/models/__init__.py
@@ -1,3 +1,17 @@
from .rnabert import RnaBertConfig, RnaBertModel, RnaTokenizer
from .rnabert import (
RnaBertConfig,
RnaBertForMaskedLM,
RnaBertForSequenceClassification,
RnaBertForTokenClassification,
RnaBertModel,
RnaTokenizer,
)

__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
__all__ = [
"RnaBertConfig",
"RnaBertModel",
"RnaBertForMaskedLM",
"RnaBertForSequenceClassification",
"RnaBertForTokenClassification",
"RnaTokenizer",
]
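The wider export list makes the new task heads importable straight from multimolecule.models. A minimal sketch, assuming RnaBertConfig's defaults are usable as-is; the checkpoint path in the comment is a placeholder, not something shipped by this commit:

    from multimolecule.models import RnaBertConfig, RnaBertForMaskedLM, RnaTokenizer

    # build a freshly initialised model from the default configuration
    config = RnaBertConfig()
    model = RnaBertForMaskedLM(config)

    # the tokenizer is re-exported from multimolecule.tokenizers.rna;
    # loading it needs a converted checkpoint directory (placeholder path):
    # tokenizer = RnaTokenizer.from_pretrained("path/to/rnabert")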
230 changes: 230 additions & 0 deletions multimolecule/models/modeling_utils.py
@@ -0,0 +1,230 @@
from typing import Optional, Tuple, Union

import torch
from chanfig import ConfigRegistry
from torch import nn
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput


class MaskedLMHead(nn.Module):
"""Head for masked language modeling."""

def __init__(self, config):
super().__init__()
if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
config.proj_head_mode = "none"
self.transform = PredictionHeadTransform.build(config)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias

def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]
x = self.transform(sequence_output)
prediction_scores = self.decoder(x)

masked_lm_loss = None
if labels is not None:
masked_lm_loss = F.cross_entropy(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class SequenceClassificationHead(nn.Module):
"""Head for sequence-level classification tasks."""

num_labels: int

def __init__(self, config):
super().__init__()
if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
config.proj_head_mode = "none"
self.num_labels = config.num_labels
self.transform = PredictionHeadTransform.build(config)
classifier_dropout = (
config.classifier_dropout
if "classifier_dropout" in dir(config) and config.classifier_dropout is not None
else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=False)

def forward(
self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
) -> Union[Tuple, SequenceClassifierOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
sequence_output = outputs.last_hidden_state if return_dict else outputs[0]
x = self.dropout(sequence_output)
x = self.transform(x)
logits = self.decoder(x)

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss = (
F.mse_loss(logits.squeeze(), labels.squeeze())
if self.num_labels == 1
else F.mse_loss(logits, labels)
)
elif self.config.problem_type == "single_label_classification":
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss = F.binary_cross_entropy_with_logits(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class TokenClassificationHead(nn.Module):
"""Head for token-level classification tasks."""

num_labels: int

def __init__(self, config):
if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
config.proj_head_mode = "none"
super().__init__()
self.num_labels = config.num_labels
self.transform = PredictionHeadTransform.build(config)
classifier_dropout = (
config.classifier_dropout
if "classifier_dropout" in dir(config) and config.classifier_dropout is not None
else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=False)

def forward(
self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
) -> Union[Tuple, TokenClassifierOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
token_output = outputs.pooled_output if return_dict else outputs[1]
x = self.dropout(token_output)
x = self.transform(x)
logits = self.decoder(x)

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss = (
F.mse_loss(logits.squeeze(), labels.squeeze())
if self.num_labels == 1
else F.mse_loss(logits, labels)
)
elif self.config.problem_type == "single_label_classification":
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss = F.binary_cross_entropy_with_logits(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


PredictionHeadTransform = ConfigRegistry(key="proj_head_mode")


@PredictionHeadTransform.register("nonlinear")
class NonLinearTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states


@PredictionHeadTransform.register("linear")
class LinearTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states


@PredictionHeadTransform.register("none")
class IdentityTransform(nn.Identity):
def __init__(self, config):
super().__init__()
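PredictionHeadTransform is a chanfig ConfigRegistry keyed on proj_head_mode, so each head above picks its projection by calling PredictionHeadTransform.build(config). A small sketch of that dispatch, assuming the module is importable under the path shown in the file header:

    import torch

    from multimolecule.models import RnaBertConfig
    from multimolecule.models.modeling_utils import PredictionHeadTransform

    config = RnaBertConfig(proj_head_mode="nonlinear")  # "linear" and "none" are also registered
    transform = PredictionHeadTransform.build(config)   # dispatches on config.proj_head_mode

    hidden_states = torch.randn(2, 16, config.hidden_size)
    projected = transform(hidden_states)  # shape is preserved: (2, 16, config.hidden_size)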
30 changes: 27 additions & 3 deletions multimolecule/models/rnabert/__init__.py
@@ -1,12 +1,36 @@
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import (
AutoConfig,
AutoModel,
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
)

from multimolecule.tokenizers.rna import RnaTokenizer

from .configuration_rnabert import RnaBertConfig
from .modeling_rnabert import RnaBertModel
from .modeling_rnabert import (
RnaBertForMaskedLM,
RnaBertForSequenceClassification,
RnaBertForTokenClassification,
RnaBertModel,
)

__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
__all__ = [
"RnaBertConfig",
"RnaBertModel",
"RnaTokenizer",
"RnaBertForMaskedLM",
"RnaBertForSequenceClassification",
"RnaBertForTokenClassification",
]

AutoConfig.register("rnabert", RnaBertConfig)
AutoModel.register(RnaBertConfig, RnaBertModel)
AutoModelForMaskedLM.register(RnaBertConfig, RnaBertForMaskedLM)
AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification)
AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification)
AutoModelWithLMHead.register(RnaBertConfig, RnaBertForTokenClassification)
AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
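With the Auto* registrations above, the "rnabert" model type resolves through the standard transformers factories. A sketch; the from_pretrained path is a placeholder, not a checkpoint shipped by this commit:

    from transformers import AutoConfig, AutoModelForMaskedLM

    import multimolecule.models.rnabert  # noqa: F401  (importing runs the registrations above)

    config = AutoConfig.for_model("rnabert")          # resolves to RnaBertConfig
    model = AutoModelForMaskedLM.from_config(config)  # resolves to RnaBertForMaskedLM

    # loading a converted checkpoint would look like (placeholder path):
    # model = AutoModelForMaskedLM.from_pretrained("path/to/rnabert")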
9 changes: 7 additions & 2 deletions multimolecule/models/rnabert/configuration_rnabert.py
@@ -9,7 +9,7 @@ class RnaBertConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a
RnaBert model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the RnaBert
[mana438/RNABERT](https://github.com/mana438/RNABERT/blob/master/RNA_bert_config.json) architecture.
[mana438/RNABERT](https://github.com/mana438/RNABERT) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
@@ -55,7 +55,8 @@ class RnaBertConfig(PretrainedConfig):
>>> # Initializing a model from the configuration
>>> model = RnaBertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
```
"""

model_type = "rnabert"

@@ -77,6 +78,8 @@ def __init__(
pad_token_id=0,
position_embedding_type="absolute",
use_cache=True,
classifier_dropout=None,
proj_head_mode="nonlinear",
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -97,3 +100,5 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
self.proj_head_mode = proj_head_mode
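The two new constructor arguments feed the heads defined in modeling_utils.py. A brief sketch with illustrative values rather than recommended ones:

    from multimolecule.models import RnaBertConfig, RnaBertForSequenceClassification

    config = RnaBertConfig(
        classifier_dropout=0.1,   # when None, the heads fall back to hidden_dropout_prob
        proj_head_mode="linear",  # selects LinearTransform; "nonlinear" and "none" also work
        num_labels=2,
    )
    model = RnaBertForSequenceClassification(config)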
31 changes: 22 additions & 9 deletions multimolecule/models/rnabert/convert_checkpoint.py
@@ -6,7 +6,7 @@
import torch
from torch import nn

from multimolecule.models import RnaBertConfig, RnaBertModel
from multimolecule.models import RnaBertConfig, RnaBertForMaskedLM
from multimolecule.tokenizers.rna.config import get_special_tokens_map, get_tokenizer_config, get_vocab_list

CONFIG = {
@@ -19,7 +19,6 @@
"max_position_embeddings": 440,
"num_attention_heads": 12,
"num_hidden_layers": 6,
"vocab_size": 25,
"ss_vocab_size": 8,
"type_vocab_size": 2,
"pad_token_id": 0,
@@ -33,27 +32,41 @@ def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None):
if output_path is None:
output_path = "rnabert"
config = RnaBertConfig.from_dict(chanfig.FlatDict(CONFIG))
config.vocab_size = len(vocab_list)
ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu"))
bert_state_dict = ckpt
state_dict = {}

model = RnaBertModel(config)
model = RnaBertForMaskedLM(config)

for key, value in bert_state_dict.items():
if key.startswith("module.cls"):
continue
key = key[12:]
key = key[7:]
key = key.replace("gamma", "weight")
key = key.replace("beta", "bias")
state_dict[key] = value
if key.startswith("bert"):
state_dict["rna" + key] = value
continue
if key.startswith("cls"):
# import ipdb; ipdb.set_trace()
key = "lm_head." + key[4:]
# key = key[4:]
state_dict[key] = value
continue

word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
word_embed_weight = word_embed.weight.data
predictions_bias = torch.zeros(config.vocab_size)
predictions_decoder_weight = torch.zeros((config.vocab_size, config.hidden_size))
# nn.init.normal_(pos_embed.weight, std=0.02)
for original_token, new_token in zip(original_vocab_list, vocab_list):
original_index = original_vocab_list.index(original_token)
new_index = vocab_list.index(new_token)
word_embed.weight.data[new_index] = state_dict["embeddings.word_embeddings.weight"][original_index]
state_dict["embeddings.word_embeddings.weight"] = word_embed.weight.data
word_embed_weight[new_index] = state_dict["rnabert.embeddings.word_embeddings.weight"][original_index]
predictions_bias[new_index] = state_dict["lm_head.predictions.bias"][original_index]
predictions_decoder_weight[new_index] = state_dict["lm_head.predictions.decoder.weight"][original_index]
state_dict["rnabert.embeddings.word_embeddings.weight"] = word_embed_weight
state_dict["lm_head.predictions.bias"] = predictions_bias
state_dict["lm_head.predictions.decoder.weight"] = predictions_decoder_weight

model.load_state_dict(state_dict)
model.save_pretrained(output_path, safe_serialization=True)
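The loop above carries embedding rows, decoder rows, and bias entries from the original RNABERT vocabulary order over to the new vocabulary. A simplified toy sketch of that remapping; the vocabularies and sizes here are placeholders, and the copy is written per token rather than by zipped position as in the script:

    import torch

    original_vocab = ["<pad>", "<mask>", "A", "C", "G", "U"]  # placeholder for RNABERT's order
    new_vocab = ["<pad>", "A", "C", "G", "U", "<mask>"]       # placeholder for get_vocab_list()
    hidden_size = 8

    old_embed = torch.randn(len(original_vocab), hidden_size)
    new_embed = torch.zeros(len(new_vocab), hidden_size)

    for token in original_vocab:
        if token in new_vocab:
            new_embed[new_vocab.index(token)] = old_embed[original_vocab.index(token)]

    # each shared token keeps its learned row under its new index
    assert torch.equal(new_embed[new_vocab.index("A")], old_embed[original_vocab.index("A")])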