From 183a11baf7f8a37ad52b793e38f2f8e4a7c7090a Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Thu, 2 May 2024 21:38:53 +0800
Subject: [PATCH] inspect if input_ids is NestedTensor

Signed-off-by: Zhiyuan Chen
---
 .../models/rnabert/modeling_rnabert.py        | 15 +++++++++------
 multimolecule/models/rnafm/modeling_rnafm.py  | 19 +++++++++++--------
 .../models/rnamsm/modeling_rnamsm.py          | 19 +++++++++++--------
 .../models/splicebert/modeling_splicebert.py  | 17 ++++++++++-------
 .../models/utrbert/modeling_utrbert.py        | 17 ++++++++++-------
 multimolecule/models/utrlm/modeling_utrlm.py  | 19 +++++++++++--------
 6 files changed, 62 insertions(+), 44 deletions(-)

diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py
index a06d7c3e..39763f1b 100644
--- a/multimolecule/models/rnabert/modeling_rnabert.py
+++ b/multimolecule/models/rnabert/modeling_rnabert.py
@@ -5,6 +5,7 @@
 from typing import Optional, Tuple
 
 import torch
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -79,12 +80,14 @@ def __init__(self, config: RnaBertConfig, add_pooling_layer: bool = True):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
     ) -> Tuple[Tensor, ...] | BaseModelOutputWithPooling:
+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if attention_mask is None:
             attention_mask = (
                 input_ids.ne(self.pad_token_id) if self.pad_token_id is not None else torch.ones_like(input_ids)
@@ -135,7 +138,7 @@ def __init__(self, config: RnaBertConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
@@ -188,7 +191,7 @@ def __init__(self, config: RnaBertConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         labels_ss: Optional[torch.Tensor] = None,
@@ -253,7 +256,7 @@ def __init__(self, config: RnaBertConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -332,7 +335,7 @@ def __init__(self, config: RnaBertConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -411,7 +414,7 @@ def __init__(self, config: RnaBertConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
diff --git a/multimolecule/models/rnafm/modeling_rnafm.py b/multimolecule/models/rnafm/modeling_rnafm.py
index 02f0b12e..659ecdb4 100755
--- a/multimolecule/models/rnafm/modeling_rnafm.py
+++ b/multimolecule/models/rnafm/modeling_rnafm.py
@@ -5,6 +5,7 @@
 
 import torch
 import torch.utils.checkpoint
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers.activations import ACT2FN
@@ -98,7 +99,7 @@ class PreTrainedModel
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -143,9 +144,11 @@ def forward(
         else:
             use_cache = False
 
+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
+        if input_ids is not None:
             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -251,7 +254,7 @@ def __init__(self, config: RnaFmConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -336,7 +339,7 @@ def set_output_embeddings(self, embeddings):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -417,7 +420,7 @@ def __init__(self, config: RnaFmConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -501,7 +504,7 @@ def __init__(self, config: RnaFmConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -583,7 +586,7 @@ def __init__(self, config: RnaFmConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -1197,7 +1200,7 @@ def forward(
         self,
         outputs: BaseModelOutputWithPastAndCrossAttentions | Tuple[Tensor, ...],
         attention_mask: Optional[Tensor] = None,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor | None = None,
     ) -> Tuple[Tensor, Tensor]:
         logits = self.predictions(outputs)
         contact_map = self.contact(torch.stack(outputs[-1], 1), attention_mask, input_ids)
diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py
index 7f079725..0e4e355a 100644
--- a/multimolecule/models/rnamsm/modeling_rnamsm.py
+++ b/multimolecule/models/rnamsm/modeling_rnamsm.py
@@ -7,6 +7,7 @@
 
 import torch
 from chanfig import ConfigRegistry
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -75,12 +76,14 @@ def __init__(self, config: RnaMsmConfig, add_pooling_layer: bool = True):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
     ) -> Tuple[Tensor, ...] | RnaMsmModelOutputWithPooling:
+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if attention_mask is None:
             attention_mask = (
                 input_ids.ne(self.pad_token_id) if self.pad_token_id is not None else torch.ones_like(input_ids)
@@ -134,7 +137,7 @@ def __init__(self, config: RnaMsmConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
@@ -188,7 +191,7 @@ def __init__(self, config: RnaMsmConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         labels_contact: Optional[torch.Tensor] = None,
@@ -251,7 +254,7 @@ def __init__(self, config: RnaMsmConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -331,7 +334,7 @@ def __init__(self, config: RnaMsmConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -411,7 +414,7 @@ def __init__(self, config: RnaMsmConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -489,7 +492,7 @@ def __init__(self, config: RnaMsmConfig):
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout)
 
-    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None) -> Tensor:
+    def forward(self, input_ids: Tensor | NestedTensor, attention_mask: Optional[Tensor] = None) -> Tensor:
         assert input_ids.ndim == 3
         if attention_mask is None:
             attention_mask = input_ids.ne(self.pad_token_id)
@@ -1259,7 +1262,7 @@ def forward(
         self,
         outputs: RnaMsmModelOutput | Tuple[Tensor, ...],
         attention_mask: Optional[Tensor] = None,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor | None = None,
     ) -> Tuple[Tensor, Tensor]:
         sequence_output, row_attentions = outputs[0], torch.stack(outputs[-1], 1)
         prediction_scores = self.predictions(sequence_output)
diff --git a/multimolecule/models/splicebert/modeling_splicebert.py b/multimolecule/models/splicebert/modeling_splicebert.py
index 9e5a5cbc..876fdb9b 100644
--- a/multimolecule/models/splicebert/modeling_splicebert.py
+++ b/multimolecule/models/splicebert/modeling_splicebert.py
@@ -5,6 +5,7 @@
 
 import torch
 import torch.utils.checkpoint
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers.activations import ACT2FN
@@ -106,7 +107,7 @@ class PreTrainedModel
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -151,9 +152,11 @@ def forward(
         else:
             use_cache = False
 
+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
+        if input_ids is not None:
             # self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -264,7 +267,7 @@ def set_output_embeddings(self, new_embeddings):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -359,7 +362,7 @@ def __init__(self, config: SpliceBertConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -433,7 +436,7 @@ def __init__(self, config: SpliceBertConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -518,7 +521,7 @@ def __init__(self, config: SpliceBertConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -601,7 +604,7 @@ def __init__(self, config: SpliceBertConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py
index 67dee76b..4b3d27fd 100644
--- a/multimolecule/models/utrbert/modeling_utrbert.py
+++ b/multimolecule/models/utrbert/modeling_utrbert.py
@@ -5,6 +5,7 @@
 
 import torch
 import torch.utils.checkpoint
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers.activations import ACT2FN
@@ -91,7 +92,7 @@ class PreTrainedModel
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -136,9 +137,11 @@ def forward(
         else:
             use_cache = False
 
+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
+        if input_ids is not None:
             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -249,7 +252,7 @@ def set_output_embeddings(self, new_embeddings):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -344,7 +347,7 @@ def __init__(self, config: UtrBertConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -414,7 +417,7 @@ def __init__(self, config: UtrBertConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -499,7 +502,7 @@ def __init__(self, config: UtrBertConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -582,7 +585,7 @@ def __init__(self, config: UtrBertConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
diff --git a/multimolecule/models/utrlm/modeling_utrlm.py b/multimolecule/models/utrlm/modeling_utrlm.py
index e445ebfc..15aba14c 100755
--- a/multimolecule/models/utrlm/modeling_utrlm.py
+++ b/multimolecule/models/utrlm/modeling_utrlm.py
@@ -5,6 +5,7 @@
 
 import torch
 import torch.utils.checkpoint
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers.activations import ACT2FN
@@ -98,7 +99,7 @@ class PreTrainedModel
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -143,9 +144,11 @@ def forward(
         else:
             use_cache = False
 
+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
+        if input_ids is not None:
             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -251,7 +254,7 @@ def __init__(self, config: UtrLmConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -336,7 +339,7 @@ def set_output_embeddings(self, embeddings):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -428,7 +431,7 @@ def __init__(self, config: UtrLmConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -512,7 +515,7 @@ def __init__(self, config: UtrLmConfig):
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -594,7 +597,7 @@ def __init__(self, config: UtrLmConfig):
 
     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -1218,7 +1221,7 @@ def forward(
         self,
         outputs: BaseModelOutputWithPastAndCrossAttentions | Tuple[Tensor, ...],
         attention_mask: Optional[Tensor] = None,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor | None = None,
     ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
         logits = self.predictions(outputs)
         contact_map = self.contact(torch.stack(outputs[-1], 1), attention_mask, input_ids)
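
Usage note (not part of the patch): a minimal sketch of what this change enables. It assumes danling's NestedTensor can be built from a list of variable-length id tensors and exposes the .tensor/.mask pair used above; the RnaBertModel class name, the import path, and the dummy token ids are illustrative assumptions, not confirmed by this diff.

    import torch
    from danling import NestedTensor

    # Hypothetical import path and model class; the actual exports may differ.
    from multimolecule.models.rnabert import RnaBertConfig, RnaBertModel

    model = RnaBertModel(RnaBertConfig())

    # Two dummy sequences of unequal length; NestedTensor keeps them together
    # and tracks which positions are padding.
    input_ids = NestedTensor([torch.tensor([5, 6, 7, 2]), torch.tensor([5, 7, 2])])

    # With this patch, forward() detects the NestedTensor and unpacks it into
    # (input_ids.tensor, input_ids.mask), so no explicit attention_mask is needed.
    outputs = model(input_ids)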