inspect if input_ids is NestedTensor
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed May 2, 2024
1 parent cc25cbe commit 183a11b
Showing 6 changed files with 62 additions and 44 deletions.
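
The commit applies the same guard at the top of every changed forward: if input_ids arrives as a danling NestedTensor, unpack it into its padded tensor and mask before the existing attention-mask logic runs. A minimal sketch of that pattern follows; the helper name unpack_input_ids is illustrative and not part of the commit, and only the .tensor and .mask attributes that appear in the diff are assumed of NestedTensor.

from __future__ import annotations

from typing import Optional, Tuple

import torch
from danling import NestedTensor
from torch import Tensor


def unpack_input_ids(
    input_ids: Tensor | NestedTensor,
    attention_mask: Optional[Tensor] = None,
    pad_token_id: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    # Mirrors the guard added in this commit: a NestedTensor already carries
    # its own padding mask, so unwrap it; a plain Tensor falls back to deriving
    # the mask from pad_token_id (or an all-ones mask when there is no pad id).
    if isinstance(input_ids, NestedTensor):
        return input_ids.tensor, input_ids.mask
    if attention_mask is None:
        attention_mask = (
            input_ids.ne(pad_token_id) if pad_token_id is not None else torch.ones_like(input_ids)
        )
    return input_ids, attention_mask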
15 changes: 9 additions & 6 deletions multimolecule/models/rnabert/modeling_rnabert.py
@@ -5,6 +5,7 @@
 from typing import Optional, Tuple

 import torch
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -79,12 +80,14 @@ def __init__(self, config: RnaBertConfig, add_pooling_layer: bool = True):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
     ) -> Tuple[Tensor, ...] | BaseModelOutputWithPooling:
+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if attention_mask is None:
             attention_mask = (
                 input_ids.ne(self.pad_token_id) if self.pad_token_id is not None else torch.ones_like(input_ids)
@@ -135,7 +138,7 @@ def __init__(self, config: RnaBertConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
@@ -188,7 +191,7 @@ def __init__(self, config: RnaBertConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         labels_ss: Optional[torch.Tensor] = None,
@@ -253,7 +256,7 @@ def __init__(self, config: RnaBertConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -332,7 +335,7 @@ def __init__(self, config: RnaBertConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -411,7 +414,7 @@ def __init__(self, config: RnaBertConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
19 changes: 11 additions & 8 deletions multimolecule/models/rnafm/modeling_rnafm.py
@@ -5,6 +5,7 @@

 import torch
 import torch.utils.checkpoint
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers.activations import ACT2FN
@@ -98,7 +99,7 @@ class PreTrainedModel

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -143,9 +144,11 @@ def forward(
         else:
             use_cache = False

+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
+        if input_ids is not None:
             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -251,7 +254,7 @@ def __init__(self, config: RnaFmConfig):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -336,7 +339,7 @@ def set_output_embeddings(self, embeddings):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -417,7 +420,7 @@ def __init__(self, config: RnaFmConfig):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -501,7 +504,7 @@ def __init__(self, config: RnaFmConfig):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -583,7 +586,7 @@ def __init__(self, config: RnaFmConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -1197,7 +1200,7 @@ def forward(
         self,
         outputs: BaseModelOutputWithPastAndCrossAttentions | Tuple[Tensor, ...],
         attention_mask: Optional[Tensor] = None,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor | None = None,
     ) -> Tuple[Tensor, Tensor]:
         logits = self.predictions(outputs)
         contact_map = self.contact(torch.stack(outputs[-1], 1), attention_mask, input_ids)
19 changes: 11 additions & 8 deletions multimolecule/models/rnamsm/modeling_rnamsm.py
@@ -7,6 +7,7 @@

 import torch
 from chanfig import ConfigRegistry
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -75,12 +76,14 @@ def __init__(self, config: RnaMsmConfig, add_pooling_layer: bool = True):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
     ) -> Tuple[Tensor, ...] | RnaMsmModelOutputWithPooling:
+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if attention_mask is None:
             attention_mask = (
                 input_ids.ne(self.pad_token_id) if self.pad_token_id is not None else torch.ones_like(input_ids)
@@ -134,7 +137,7 @@ def __init__(self, config: RnaMsmConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
@@ -188,7 +191,7 @@ def __init__(self, config: RnaMsmConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         labels_contact: Optional[torch.Tensor] = None,
@@ -251,7 +254,7 @@ def __init__(self, config: RnaMsmConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -331,7 +334,7 @@ def __init__(self, config: RnaMsmConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -411,7 +414,7 @@ def __init__(self, config: RnaMsmConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
@@ -489,7 +492,7 @@ def __init__(self, config: RnaMsmConfig):
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout)

-    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None) -> Tensor:
+    def forward(self, input_ids: Tensor | NestedTensor, attention_mask: Optional[Tensor] = None) -> Tensor:
         assert input_ids.ndim == 3
         if attention_mask is None:
             attention_mask = input_ids.ne(self.pad_token_id)
@@ -1259,7 +1262,7 @@ def forward(
         self,
         outputs: RnaMsmModelOutput | Tuple[Tensor, ...],
         attention_mask: Optional[Tensor] = None,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor | None = None,
     ) -> Tuple[Tensor, Tensor]:
         sequence_output, row_attentions = outputs[0], torch.stack(outputs[-1], 1)
         prediction_scores = self.predictions(sequence_output)
17 changes: 10 additions & 7 deletions multimolecule/models/splicebert/modeling_splicebert.py
@@ -5,6 +5,7 @@

 import torch
 import torch.utils.checkpoint
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers.activations import ACT2FN
@@ -106,7 +107,7 @@ class PreTrainedModel

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -151,9 +152,11 @@ def forward(
         else:
             use_cache = False

+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
+        if input_ids is not None:
             # self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -264,7 +267,7 @@ def set_output_embeddings(self, new_embeddings):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -359,7 +362,7 @@ def __init__(self, config: SpliceBertConfig):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -433,7 +436,7 @@ def __init__(self, config: SpliceBertConfig):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -518,7 +521,7 @@ def __init__(self, config: SpliceBertConfig):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -601,7 +604,7 @@ def __init__(self, config: SpliceBertConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
17 changes: 10 additions & 7 deletions multimolecule/models/utrbert/modeling_utrbert.py
@@ -5,6 +5,7 @@

 import torch
 import torch.utils.checkpoint
+from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
 from transformers.activations import ACT2FN
@@ -91,7 +92,7 @@ class PreTrainedModel

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -136,9 +137,11 @@ def forward(
         else:
             use_cache = False

+        if isinstance(input_ids, NestedTensor):
+            input_ids, attention_mask = input_ids.tensor, input_ids.mask
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
+        if input_ids is not None:
             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -249,7 +252,7 @@ def set_output_embeddings(self, new_embeddings):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -344,7 +347,7 @@ def __init__(self, config: UtrBertConfig):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -414,7 +417,7 @@ def __init__(self, config: UtrBertConfig):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -499,7 +502,7 @@ def __init__(self, config: UtrBertConfig):

     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         position_ids: Optional[Tensor] = None,
         head_mask: Optional[Tensor] = None,
@@ -582,7 +585,7 @@ def __init__(self, config: UtrBertConfig):

     def forward(
         self,
-        input_ids: Tensor,
+        input_ids: Tensor | NestedTensor,
         attention_mask: Optional[Tensor] = None,
         labels: Optional[Tensor] = None,
         output_attentions: bool = False,
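
For callers, the change means a batch of variable-length sequences can be passed as a single NestedTensor without an explicit attention_mask; the model derives the mask from the NestedTensor itself. A usage sketch under stated assumptions: the diff only confirms the .tensor and .mask attributes, so constructing a NestedTensor from a list of variable-length tensors is an assumption about the danling API, and unpack_input_ids is the illustrative helper sketched near the top of this page.

import torch
from danling import NestedTensor

# Two tokenized sequences of different lengths (the token ids are made up).
sequences = [torch.tensor([5, 6, 7, 8]), torch.tensor([5, 6])]

# Assumption: danling's NestedTensor can be built from such a list.
input_ids = NestedTensor(sequences)

# Equivalent of what the changed forward methods now do internally:
ids, mask = unpack_input_ids(input_ids)
print(ids.shape)  # padded to the longest sequence, e.g. torch.Size([2, 4])
print(mask)       # marks which positions are real tokens and which are padding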