Skip to content

Commit

Permalink
Add support for nucleotide-level tasks
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Apr 29, 2024
1 parent 71ea23c commit e033f8f
Show file tree
Hide file tree
Showing 14 changed files with 540 additions and 9 deletions.
12 changes: 12 additions & 0 deletions multimolecule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,42 @@
from .models import (
RnaBertConfig,
RnaBertForMaskedLM,
RnaBertForNucleotideClassification,
RnaBertForPretraining,
RnaBertForSequenceClassification,
RnaBertForTokenClassification,
RnaBertModel,
RnaFmConfig,
RnaFmForMaskedLM,
RnaFmForNucleotideClassification,
RnaFmForPretraining,
RnaFmForSequenceClassification,
RnaFmForTokenClassification,
RnaFmModel,
RnaMsmConfig,
RnaMsmForMaskedLM,
RnaMsmForNucleotideClassification,
RnaMsmForPretraining,
RnaMsmForSequenceClassification,
RnaMsmForTokenClassification,
RnaMsmModel,
SpliceBertConfig,
SpliceBertForMaskedLM,
SpliceBertForNucleotideClassification,
SpliceBertForPretraining,
SpliceBertForSequenceClassification,
SpliceBertForTokenClassification,
SpliceBertModel,
UtrBertConfig,
UtrBertForMaskedLM,
UtrBertForNucleotideClassification,
UtrBertForPretraining,
UtrBertForSequenceClassification,
UtrBertForTokenClassification,
UtrBertModel,
UtrLmConfig,
UtrLmForMaskedLM,
UtrLmForNucleotideClassification,
UtrLmForPretraining,
UtrLmForSequenceClassification,
UtrLmForTokenClassification,
Expand All @@ -57,36 +63,42 @@
"RnaBertForPretraining",
"RnaBertForSequenceClassification",
"RnaBertForTokenClassification",
"RnaBertForNucleotideClassification",
"RnaFmConfig",
"RnaFmModel",
"RnaFmForMaskedLM",
"RnaFmForPretraining",
"RnaFmForSequenceClassification",
"RnaFmForTokenClassification",
"RnaFmForNucleotideClassification",
"RnaMsmConfig",
"RnaMsmModel",
"RnaMsmForMaskedLM",
"RnaMsmForPretraining",
"RnaMsmForSequenceClassification",
"RnaMsmForTokenClassification",
"RnaMsmForNucleotideClassification",
"SpliceBertConfig",
"SpliceBertModel",
"SpliceBertForMaskedLM",
"SpliceBertForPretraining",
"SpliceBertForSequenceClassification",
"SpliceBertForTokenClassification",
"SpliceBertForNucleotideClassification",
"UtrBertConfig",
"UtrBertModel",
"UtrBertForMaskedLM",
"UtrBertForPretraining",
"UtrBertForSequenceClassification",
"UtrBertForTokenClassification",
"UtrBertForNucleotideClassification",
"UtrLmConfig",
"UtrLmModel",
"UtrLmForMaskedLM",
"UtrLmForPretraining",
"UtrLmForSequenceClassification",
"UtrLmForTokenClassification",
"UtrLmForNucleotideClassification",
"RnaBertForCrisprOffTarget",
"RnaFmForCrisprOffTarget",
"RnaMsmForCrisprOffTarget",
Expand Down
12 changes: 12 additions & 0 deletions multimolecule/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .rnabert import (
RnaBertConfig,
RnaBertForMaskedLM,
RnaBertForNucleotideClassification,
RnaBertForPretraining,
RnaBertForSequenceClassification,
RnaBertForTokenClassification,
Expand All @@ -10,6 +11,7 @@
from .rnafm import (
RnaFmConfig,
RnaFmForMaskedLM,
RnaFmForNucleotideClassification,
RnaFmForPretraining,
RnaFmForSequenceClassification,
RnaFmForTokenClassification,
Expand All @@ -18,6 +20,7 @@
from .rnamsm import (
RnaMsmConfig,
RnaMsmForMaskedLM,
RnaMsmForNucleotideClassification,
RnaMsmForPretraining,
RnaMsmForSequenceClassification,
RnaMsmForTokenClassification,
Expand All @@ -26,6 +29,7 @@
from .splicebert import (
SpliceBertConfig,
SpliceBertForMaskedLM,
SpliceBertForNucleotideClassification,
SpliceBertForPretraining,
SpliceBertForSequenceClassification,
SpliceBertForTokenClassification,
Expand All @@ -34,6 +38,7 @@
from .utrbert import (
UtrBertConfig,
UtrBertForMaskedLM,
UtrBertForNucleotideClassification,
UtrBertForPretraining,
UtrBertForSequenceClassification,
UtrBertForTokenClassification,
Expand All @@ -42,6 +47,7 @@
from .utrlm import (
UtrLmConfig,
UtrLmForMaskedLM,
UtrLmForNucleotideClassification,
UtrLmForPretraining,
UtrLmForSequenceClassification,
UtrLmForTokenClassification,
Expand All @@ -56,34 +62,40 @@
"RnaBertForPretraining",
"RnaBertForSequenceClassification",
"RnaBertForTokenClassification",
"RnaBertForNucleotideClassification",
"RnaFmConfig",
"RnaFmModel",
"RnaFmForMaskedLM",
"RnaFmForPretraining",
"RnaFmForSequenceClassification",
"RnaFmForTokenClassification",
"RnaFmForNucleotideClassification",
"RnaMsmConfig",
"RnaMsmModel",
"RnaMsmForMaskedLM",
"RnaMsmForPretraining",
"RnaMsmForSequenceClassification",
"RnaMsmForTokenClassification",
"RnaMsmForNucleotideClassification",
"SpliceBertConfig",
"SpliceBertModel",
"SpliceBertForMaskedLM",
"SpliceBertForPretraining",
"SpliceBertForSequenceClassification",
"SpliceBertForTokenClassification",
"SpliceBertForNucleotideClassification",
"UtrBertConfig",
"UtrBertModel",
"UtrBertForMaskedLM",
"UtrBertForPretraining",
"UtrBertForSequenceClassification",
"UtrBertForTokenClassification",
"UtrBertForNucleotideClassification",
"UtrLmConfig",
"UtrLmModel",
"UtrLmForMaskedLM",
"UtrLmForPretraining",
"UtrLmForSequenceClassification",
"UtrLmForTokenClassification",
"UtrLmForNucleotideClassification",
]
5 changes: 4 additions & 1 deletion multimolecule/models/rnabert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .configuration_rnabert import RnaBertConfig
from .modeling_rnabert import (
RnaBertForMaskedLM,
RnaBertForNucleotideClassification,
RnaBertForPretraining,
RnaBertForSequenceClassification,
RnaBertForTokenClassification,
Expand All @@ -21,14 +22,15 @@
)

# Public names exported by this package. Each name is listed exactly once
# (the tokenizer entry previously appeared twice).
__all__ = [
    "RnaTokenizer",
    "RnaBertConfig",
    "RnaBertModel",
    "RnaBertPreTrainedModel",
    "RnaBertForMaskedLM",
    "RnaBertForPretraining",
    "RnaBertForSequenceClassification",
    "RnaBertForTokenClassification",
    "RnaBertForNucleotideClassification",
]

AutoConfig.register("rnabert", RnaBertConfig)
Expand All @@ -37,4 +39,5 @@
AutoModelForPreTraining.register(RnaBertConfig, RnaBertForPretraining)
AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification)
AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification)
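# NOTE(review): registration left disabled — presumably no AutoModelForNucleotideClassification
# auto class is available yet; enable the line below once it is (TODO confirm).
# AutoModelForNucleotideClassification.register(RnaBertConfig, RnaBertForNucleotideClassification)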
AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
86 changes: 85 additions & 1 deletion multimolecule/models/rnabert/modeling_rnabert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
TokenClassifierOutput,
)

from ..modeling_utils import MaskedLMHead, SequenceClassificationHead, TokenClassificationHead
from ..modeling_utils import (
MaskedLMHead,
NucleotideClassificationHead,
SequenceClassificationHead,
TokenClassificationHead,
)
from .configuration_rnabert import RnaBertConfig


Expand Down Expand Up @@ -383,6 +388,85 @@ def forward(
)


class RnaBertForNucleotideClassification(RnaBertPreTrainedModel):
    """
    RnaBert backbone (built without its pooling layer) followed by a `NucleotideClassificationHead`.

    Examples:
        >>> from multimolecule import RnaBertConfig, RnaBertForNucleotideClassification, RnaTokenizer
        >>> config = RnaBertConfig()
        >>> model = RnaBertForNucleotideClassification(config)
        >>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/rna")
        >>> input = tokenizer("ACGUN", return_tensors="pt")
        >>> output = model(**input)
    """

    def __init__(self, config: RnaBertConfig):
        super().__init__(config)
        self.num_labels = config.head.num_labels
        # The head consumes the backbone outputs directly, so no pooled output is requested.
        self.rnabert = RnaBertModel(config, add_pooling_layer=False)
        self.nucleotide_head = NucleotideClassificationHead(config)
        # Kept as an attribute because `forward` reads (and may lazily set) `problem_type` on it.
        self.head_config = self.nucleotide_head.config

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Tuple[Tensor, ...] | TokenClassifierOutput:
        r"""
        Run the backbone and the nucleotide head, optionally computing a loss.

        labels (`torch.Tensor`, *optional*):
            Targets for the nucleotide-level loss. They must line up with the logits returned by the
            nucleotide head: for single-label classification both are flattened, i.e. one class index per
            row of `logits.view(-1, num_labels)`. NOTE(review): the exact logits shape is decided by
            `NucleotideClassificationHead`, which is defined elsewhere -- confirm against it.

            When `head_config.problem_type` is unset it is inferred on the first call with labels and stored
            on the head config:

            - `num_labels == 1` -> `"regression"` (mean-squared error);
            - `num_labels > 1` with integer labels (`torch.long` / `torch.int`) ->
              `"single_label_classification"` (cross-entropy);
            - otherwise -> `"multi_label_classification"` (binary cross-entropy with logits).

        Returns:
            A `TokenClassifierOutput` when `return_dict` is true; otherwise a tuple of `(logits,)` followed by
            `outputs[2:]` from the backbone, with the loss prepended when `labels` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.rnabert(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.nucleotide_head(outputs, attention_mask, input_ids)

        loss = None
        if labels is not None:
            # Infer the problem type once and cache it on the head config.
            if self.head_config.problem_type is None:
                if self.num_labels == 1:
                    self.head_config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.head_config.problem_type = "single_label_classification"
                else:
                    self.head_config.problem_type = "multi_label_classification"

            problem_type = self.head_config.problem_type
            if problem_type == "regression":
                if self.num_labels == 1:
                    # Drop singleton dimensions on both sides before comparing.
                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(logits, labels)
            elif problem_type == "single_label_classification":
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)
            # Any other problem_type leaves `loss` as None, matching the previous behaviour.

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class RnaBertEmbeddings(nn.Module):
def __init__(self, config: RnaBertConfig):
super().__init__()
Expand Down
4 changes: 3 additions & 1 deletion multimolecule/models/rnafm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .configuration_rnafm import RnaFmConfig
from .modeling_rnafm import (
RnaFmForMaskedLM,
RnaFmForNucleotideClassification,
RnaFmForPretraining,
RnaFmForSequenceClassification,
RnaFmForTokenClassification,
Expand All @@ -21,14 +22,15 @@
)

# Public names exported by this package. Each name is listed exactly once
# (the tokenizer entry previously appeared twice).
__all__ = [
    "RnaTokenizer",
    "RnaFmConfig",
    "RnaFmModel",
    "RnaFmPreTrainedModel",
    "RnaFmForMaskedLM",
    "RnaFmForPretraining",
    "RnaFmForSequenceClassification",
    "RnaFmForTokenClassification",
    "RnaFmForNucleotideClassification",
]

AutoConfig.register("rnafm", RnaFmConfig)
Expand Down
Loading

0 comments on commit e033f8f

Please sign in to comment.