reorganise rnabert
ZhiyuanChen committed Apr 16, 2024
1 parent 2bea376 commit a02aa49
Showing 8 changed files with 676 additions and 379 deletions.
18 changes: 16 additions & 2 deletions multimolecule/models/__init__.py
@@ -1,3 +1,17 @@
from .rnabert import RnaBertConfig, RnaBertModel, RnaTokenizer
from .rnabert import (
RnaBertConfig,
RnaBertForMaskedLM,
RnaBertForSequenceClassification,
RnaBertForTokenClassification,
RnaBertModel,
RnaTokenizer,
)

__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
__all__ = [
"RnaBertConfig",
"RnaBertModel",
"RnaBertForMaskedLM",
"RnaBertForSequenceClassification",
"RnaBertForTokenClassification",
"RnaTokenizer",
]
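
For orientation, a minimal sketch of how the expanded exports might be used; the default-constructed config below is illustrative rather than taken from this commit:

    from multimolecule.models import RnaBertConfig, RnaBertForMaskedLM

    config = RnaBertConfig()            # illustrative defaults; real use would load a pretrained checkpoint
    model = RnaBertForMaskedLM(config)  # the masked-LM variant is now importable from the package root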
97 changes: 97 additions & 0 deletions multimolecule/models/configuration_utils.py
@@ -0,0 +1,97 @@
from __future__ import annotations

from dataclasses import asdict, dataclass, is_dataclass
from typing import Optional

from transformers.configuration_utils import PretrainedConfig as _PretrainedConfig


class PretrainedConfig(_PretrainedConfig):
head: HeadConfig

def to_dict(self):
"""
Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = super().to_dict()
for k, v in output.items():
if hasattr(v, "to_dict"):
output[k] = v.to_dict()
if is_dataclass(v):
output[k] = asdict(v)
return output


@dataclass
class HeadConfig:
r"""
This is the configuration class to store the configuration of a prediction head. It is used to instantiate a
prediction head according to the specified arguments, defining the head architecture.
Configuration objects inherit from [`HeadConfig`] and can be used to control the model outputs. Read the
documentation from [`HeadConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to None):
Dimensionality of the prediction head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the hidden states.
transform (`str`, *optional*, defaults to None):
The transform operation applied to hidden states.
transform_act (`str`, *optional*, defaults to "tanh"):
The activation function used in the transform applied to the hidden states.
bias (`bool`, *optional*, defaults to True):
Whether to apply bias to the final prediction layer.
act (`str`, *optional*, defaults to None):
The activation function of the final prediction output.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
"""

hidden_size: Optional[int] = None
dropout: float = 0.0
transform: Optional[str] = None
transform_act: Optional[str] = "tanh"
bias: bool = True
act: Optional[str] = None
layer_norm_eps: float = 1e-12


@dataclass
class MaskedLMHeadConfig:
r"""
This is the configuration class to store the configuration of a masked language modeling head. It is used to
instantiate a prediction head according to the specified arguments, defining the head architecture.
Configuration objects inherit from [`MaskedLMHeadConfig`] and can be used to control the model outputs. Read the
documentation from [`MaskedLMHeadConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to None):
Dimensionality of the prediction head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the hidden states.
transform (`str`, *optional*, defaults to "nonlinear"):
The transform operation applied to hidden states.
transform_act (`str`, *optional*, defaults to "gelu"):
The activation function used in the transform applied to the hidden states.
bias (`bool`, *optional*, defaults to True):
Whether to apply bias to the final prediction layer.
act (`str`, *optional*, defaults to None):
The activation function of the final prediction output.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
"""

hidden_size: Optional[int] = None
dropout: float = 0.0
transform: Optional[str] = "nonlinear"
transform_act: Optional[str] = "gelu"
bias: bool = True
act: Optional[str] = None
layer_norm_eps: float = 1e-12
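
A minimal sketch of how these head configs serialize through the machinery above; the field values are made up for illustration:

    from dataclasses import asdict

    from multimolecule.models.configuration_utils import HeadConfig

    head = HeadConfig(hidden_size=120, dropout=0.1, transform="nonlinear", transform_act="gelu")
    print(asdict(head)["transform_act"])  # -> gelu
    # A PretrainedConfig subclass holding `head` as an attribute emits it as a plain dict via the
    # to_dict() override above, because is_dataclass(head) is True.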
105 changes: 39 additions & 66 deletions multimolecule/models/modeling_utils.py
@@ -1,38 +1,31 @@
from typing import Optional

import torch
from chanfig import Registry
from chanfig import ConfigRegistry
from torch import Tensor, nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ModelOutput

from .configuration_utils import HeadConfig, PretrainedConfig


class ContactPredictionHead(nn.Module):
"""
Head for contact-map-level tasks.
Performs symmetrization and average product correction (APC).
"""

def __init__(
self,
config: PretrainedConfig,
in_features: int,
*,
transform: str = "none",
dropout: float = 0.0,
activation: Optional[str] = None,
bias: bool = False,
):
def __init__(self, config: PretrainedConfig, in_features: int):
super().__init__()
self.config = config.head
self.in_features = in_features
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
self.num_labels = config.num_labels
self.dropout = nn.Dropout(dropout)
self.transform = PredictionHeadTransform.build(transform, config)
self.decoder = nn.Linear(in_features, config.num_labels, bias=bias)
self.activation = ACT2FN[activation] if activation is not None else None
self.dropout = nn.Dropout(self.config.dropout)
self.transform = PredictionHeadTransform.build(self.config)
self.decoder = nn.Linear(in_features, config.num_labels, bias=self.config.bias)
self.activation = ACT2FN[self.config.act] if self.config.act is not None else None

def forward(self, attentions: Tensor, input_ids: Tensor) -> Tensor:
# remove cls token attentions
@@ -61,26 +54,20 @@ def forward(self, attentions: Tensor, input_ids: Tensor) -> Tensor:
class MaskedLMHead(nn.Module):
"""Head for masked language modeling."""

def __init__(
self,
config: PretrainedConfig,
weight: Optional[Tensor] = None,
*,
transform: str = "nonlinear",
dropout: float = 0.0,
activation: Optional[str] = None,
bias: bool = False,
):
def __init__(self, config: PretrainedConfig, weight: Optional[Tensor] = None):
super().__init__()
self.config = config.lm_head or config.head
self.num_labels = config.vocab_size
self.dropout = nn.Dropout(dropout)
self.transform = PredictionHeadTransform.build(transform, config)
self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=bias)
self.bias = nn.Parameter(torch.zeros(self.num_labels))
self.dropout = nn.Dropout(self.config.dropout)
self.transform = PredictionHeadTransform.build(self.config)

self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=False)
if weight is not None:
self.decoder.weight = weight
self.decoder.bias = self.bias
self.activation = ACT2FN[activation] if activation is not None else None
if self.config.bias:
self.bias = nn.Parameter(torch.zeros(self.num_labels))
self.decoder.bias = self.bias
self.activation = ACT2FN[self.config.act] if self.config.act is not None else None

def forward(self, outputs: ModelOutput) -> Tensor:
sequence_output = outputs[0]
@@ -97,21 +84,14 @@ class SequenceClassificationHead(nn.Module):

num_labels: int

def __init__(
self,
config: PretrainedConfig,
*,
transform: str = "none",
dropout: float = 0.0,
activation: Optional[str] = None,
bias: bool = False,
):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config.head
self.num_labels = config.num_labels
self.dropout = nn.Dropout(dropout)
self.transform = PredictionHeadTransform.build(transform, config)
self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=bias)
self.activation = ACT2FN[activation] if activation is not None else None
self.dropout = nn.Dropout(self.config.dropout)
self.transform = PredictionHeadTransform.build(self.config)
self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=self.config.bias)
self.activation = ACT2FN[self.config.act] if self.config.act is not None else None

def forward(self, outputs: ModelOutput) -> Tensor:
sequence_output = outputs[0]
@@ -128,21 +108,14 @@ class TokenClassificationHead(nn.Module):

num_labels: int

def __init__(
self,
config: PretrainedConfig,
*,
transform: str = "none",
dropout: float = 0.0,
activation: Optional[str] = None,
bias: bool = False,
):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config.head
self.num_labels = config.num_labels
self.dropout = nn.Dropout(dropout)
self.transform = PredictionHeadTransform.build(transform, config)
self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=bias)
self.activation = ACT2FN[activation] if activation is not None else None
self.dropout = nn.Dropout(self.config.dropout)
self.transform = PredictionHeadTransform.build(self.config)
self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=self.config.bias)
self.activation = ACT2FN[self.config.act] if self.config.act is not None else None

def forward(self, outputs: ModelOutput) -> Tensor:
token_output = outputs[1]
@@ -154,18 +127,18 @@ def forward(self, outputs: ModelOutput) -> Tensor:
return output


PredictionHeadTransform = Registry()
PredictionHeadTransform = ConfigRegistry(key="transform")


@PredictionHeadTransform.register("nonlinear")
class NonLinearTransform(nn.Module):
def __init__(self, config: PretrainedConfig):
def __init__(self, config: HeadConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
if isinstance(config.transform_act, str):
self.transform_act_fn = ACT2FN[config.transform_act]
else:
self.transform_act_fn = config.hidden_act
self.transform_act_fn = config.transform_act
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

def forward(self, hidden_states: Tensor) -> Tensor:
@@ -177,7 +150,7 @@ def forward(self, hidden_states: Tensor) -> Tensor:

@PredictionHeadTransform.register("linear")
class LinearTransform(nn.Module):
def __init__(self, config: PretrainedConfig):
def __init__(self, config: HeadConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -188,9 +161,9 @@ def forward(self, hidden_states: Tensor) -> Tensor:
return hidden_states


@PredictionHeadTransform.register("none")
@PredictionHeadTransform.register(None)
class IdentityTransform(nn.Identity):
def __init__(self, config: PretrainedConfig): # pylint: disable=unused-argument
def __init__(self, config: HeadConfig): # pylint: disable=unused-argument
super().__init__()
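
For reference, a sketch of how the registry dispatch might be exercised, assuming chanfig's ConfigRegistry(key="transform") resolves the registered class from the config's `transform` attribute; the hidden_size value is illustrative:

    from multimolecule.models.configuration_utils import HeadConfig
    from multimolecule.models.modeling_utils import PredictionHeadTransform

    nonlinear = PredictionHeadTransform.build(HeadConfig(hidden_size=120, transform="nonlinear"))  # NonLinearTransform
    identity = PredictionHeadTransform.build(HeadConfig(hidden_size=120))  # transform=None resolves to IdentityTransform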


30 changes: 27 additions & 3 deletions multimolecule/models/rnabert/__init__.py
@@ -1,12 +1,36 @@
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import (
AutoConfig,
AutoModel,
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
)

from multimolecule.tokenizers.rna import RnaTokenizer

from .configuration_rnabert import RnaBertConfig
from .modeling_rnabert import RnaBertModel
from .modeling_rnabert import (
RnaBertForMaskedLM,
RnaBertForSequenceClassification,
RnaBertForTokenClassification,
RnaBertModel,
)

__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
__all__ = [
"RnaBertConfig",
"RnaBertModel",
"RnaTokenizer",
"RnaBertForMaskedLM",
"RnaBertForSequenceClassification",
"RnaBertForTokenClassification",
]

AutoConfig.register("rnabert", RnaBertConfig)
AutoModel.register(RnaBertConfig, RnaBertModel)
AutoModelForMaskedLM.register(RnaBertConfig, RnaBertForMaskedLM)
AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification)
AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification)
AutoModelWithLMHead.register(RnaBertConfig, RnaBertForMaskedLM)
AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
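
A usage sketch for the registrations above, assuming RnaTokenizer follows the standard tokenizer call interface; the checkpoint path is hypothetical:

    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/rnabert-checkpoint")  # hypothetical checkpoint
    model = AutoModelForMaskedLM.from_pretrained("path/to/rnabert-checkpoint")
    outputs = model(**tokenizer("AUGGCUACGUAGC", return_tensors="pt"))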