improve RnaTokenizer
add type hints for RnaTokenizer
rename tokenizers.rna.config -> tokenizers.rna.utils
point <bos> token to <cls> token
ZhiyuanChen committed Apr 2, 2024
1 parent 9935bcd commit 18503fa
Showing 4 changed files with 245 additions and 13 deletions.
230 changes: 230 additions & 0 deletions multimolecule/models/modeling_utils.py
@@ -0,0 +1,230 @@
from typing import Optional, Tuple, Union

import torch
from chanfig import ConfigRegistry
from torch import nn
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput


class MaskedLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, config):
        super().__init__()
        if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
            config.proj_head_mode = "none"
        self.config = config
        self.transform = PredictionHeadTransform.build(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(
        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
    ) -> Union[Tuple, MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # `outputs` is the backbone output (a ModelOutput or a tuple); only the hidden states are used here.
        sequence_output = outputs.last_hidden_state if return_dict else outputs[0]
        x = self.transform(sequence_output)
        prediction_scores = self.decoder(x)

        masked_lm_loss = None
        if labels is not None:
            masked_lm_loss = F.cross_entropy(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class SequenceClassificationHead(nn.Module):
    """Head for sequence-level classification tasks."""

    num_labels: int

    def __init__(self, config):
        super().__init__()
        if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
            config.proj_head_mode = "none"
        self.config = config
        self.num_labels = config.num_labels
        self.transform = PredictionHeadTransform.build(config)
        classifier_dropout = (
            config.classifier_dropout
            if "classifier_dropout" in dir(config) and config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=False)

    def forward(
        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
    ) -> Union[Tuple, SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        sequence_output = outputs.last_hidden_state if return_dict else outputs[0]
        x = self.dropout(sequence_output)
        x = self.transform(x)
        logits = self.decoder(x)

        loss = None
        if labels is not None:
            # Infer the problem type from num_labels and the label dtype on first use.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss = (
                    F.mse_loss(logits.squeeze(), labels.squeeze())
                    if self.num_labels == 1
                    else F.mse_loss(logits, labels)
                )
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class TokenClassificationHead(nn.Module):
    """Head for token-level classification tasks."""

    num_labels: int

    def __init__(self, config):
        super().__init__()
        if "proj_head_mode" not in dir(config) or config.proj_head_mode is None:
            config.proj_head_mode = "none"
        self.config = config
        self.num_labels = config.num_labels
        self.transform = PredictionHeadTransform.build(config)
        classifier_dropout = (
            config.classifier_dropout
            if "classifier_dropout" in dir(config) and config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=False)

    def forward(
        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
    ) -> Union[Tuple, TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        token_output = outputs.pooled_output if return_dict else outputs[1]
        x = self.dropout(token_output)
        x = self.transform(x)
        logits = self.decoder(x)

        loss = None
        if labels is not None:
            # Infer the problem type from num_labels and the label dtype on first use.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss = (
                    F.mse_loss(logits.squeeze(), labels.squeeze())
                    if self.num_labels == 1
                    else F.mse_loss(logits, labels)
                )
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


PredictionHeadTransform = ConfigRegistry(key="proj_head_mode")


@PredictionHeadTransform.register("nonlinear")
class NonLinearTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


@PredictionHeadTransform.register("linear")
class LinearTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


@PredictionHeadTransform.register("none")
class IdentityTransform(nn.Identity):
    def __init__(self, config):
        super().__init__()
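
For orientation, separate from the committed file itself: the heads above route their projection through the PredictionHeadTransform registry, keyed on config.proj_head_mode. Below is a minimal sketch of that wiring. It assumes it runs alongside the module above (so PredictionHeadTransform and MaskedLMHead are in scope), that ConfigRegistry.build instantiates the registered transform from the config as the heads rely on, and it uses a bare transformers PretrainedConfig as a stand-in config carrying only the attributes these classes read; the values are illustrative.

import torch
from transformers import PretrainedConfig

# Hypothetical stand-in for a model config; only the attributes read above are set.
config = PretrainedConfig(
    hidden_size=64,
    vocab_size=26,
    hidden_act="gelu",
    layer_norm_eps=1e-12,
    proj_head_mode="nonlinear",  # registry key: "nonlinear" | "linear" | "none"
)

# The registry resolves the key to a transform class and builds it from the config.
transform = PredictionHeadTransform.build(config)
hidden = torch.randn(2, 10, config.hidden_size)
assert transform(hidden).shape == hidden.shape  # transforms are shape-preserving

# Heads consume backbone outputs; a bare tuple stands in for a ModelOutput here.
lm_head = MaskedLMHead(config)
logits = lm_head((hidden,), return_dict=False)[0]  # (2, 10, vocab_size)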
2 changes: 1 addition & 1 deletion multimolecule/models/rnabert/convert_checkpoint.py
@@ -7,7 +7,7 @@
from torch import nn

from multimolecule.models import RnaBertConfig, RnaBertModel
-from multimolecule.tokenizers.rna.config import get_special_tokens_map, get_tokenizer_config, get_vocab_list
+from multimolecule.tokenizers.rna.utils import get_special_tokens_map, get_tokenizer_config, get_vocab_list

CONFIG = {
"architectures": ["RnaBertModel"],
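
After the rename, callers such as this conversion script only change their import path; the helpers themselves are untouched (the file is renamed without changes, see the last entry below). As a small sketch, using only get_vocab_list, which the diff shows being called with no arguments:

from multimolecule.tokenizers.rna.utils import get_vocab_list

vocab = get_vocab_list()
# RnaTokenizer (next file) enumerates this same list to build its
# id <-> token maps, so indices here line up with the tokenizer's ids.
print(len(vocab), vocab[:5])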
26 changes: 14 additions & 12 deletions multimolecule/tokenizers/rna/tokenization_rna.py
@@ -4,7 +4,7 @@
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging

-from .config import VOCAB_LIST
+from .utils import get_vocab_list

logger = logging.get_logger(__name__)

@@ -18,22 +18,24 @@ class RnaTokenizer(PreTrainedTokenizer):

    def __init__(
        self,
-        cls_token="<cls>",
-        pad_token="<pad>",
-        eos_token="<eos>",
-        sep_token="<eos>",
-        unk_token="<unk>",
-        mask_token="<mask>",
-        convert_to_uppercase=True,
-        convert_T_to_U=True,
+        bos_token: str = "<cls>",
+        cls_token: str = "<cls>",
+        pad_token: str = "<pad>",
+        eos_token: str = "<eos>",
+        sep_token: str = "<eos>",
+        unk_token: str = "<unk>",
+        mask_token: str = "<mask>",
+        convert_to_uppercase: bool = True,
+        convert_T_to_U: bool = True,
        **kwargs,
    ):
-        self.all_tokens = VOCAB_LIST
+        self.all_tokens = get_vocab_list()
        self._id_to_token = dict(enumerate(self.all_tokens))
        self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
        self.convert_to_uppercase = convert_to_uppercase
        self.convert_T_to_U = convert_T_to_U
        super().__init__(
+            bos_token=bos_token,
            cls_token=cls_token,
            pad_token=pad_token,
            eos_token=eos_token,
@@ -53,7 +55,7 @@ def _convert_id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, self.unk_token)

    def _convert_token_to_id(self, token: str) -> int:
-        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
+        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))  # type: ignore[arg-type]

    def _tokenize(self, text: str, **kwargs):
        if self.convert_to_uppercase:
@@ -68,7 +70,7 @@ def get_vocab(self):
        return base_vocab

    def token_to_id(self, token: str) -> int:
-        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
+        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))  # type: ignore[arg-type]

    def id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, self.unk_token)
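
Taken together, the tokenizer now exposes <bos> as an alias of <cls>, carries type hints on its constructor, and draws its vocabulary from get_vocab_list(). A rough usage sketch follows, assuming RnaTokenizer can be constructed directly from this module with its defaults; exact ids depend on the vocabulary order, which this diff does not show.

from multimolecule.tokenizers.rna.tokenization_rna import RnaTokenizer

tokenizer = RnaTokenizer()

# <bos> now points at the same token as <cls>.
assert tokenizer.bos_token == tokenizer.cls_token == "<cls>"

# With convert_to_uppercase and convert_T_to_U enabled (the defaults),
# DNA-style input such as "acgt" is normalized toward "ACGU" before lookup.
ids = tokenizer.convert_tokens_to_ids(["A", "C", "G", "U"])
print(ids)  # any symbol missing from the vocabulary would fall back to the <unk> id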
multimolecule/tokenizers/rna/{config.py → utils.py}: file renamed without changes.
