diff --git a/multimolecule/downstream/crispr_off_target.py b/multimolecule/downstream/crispr_off_target.py
index cc51bd3f..c527df86 100644
--- a/multimolecule/downstream/crispr_off_target.py
+++ b/multimolecule/downstream/crispr_off_target.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Tuple

 import torch
 from torch import Tensor
@@ -45,10 +45,10 @@ def __init__(self, config: RnaBertConfig):
     def forward(
         self,
         input_ids: Tensor,
-        target_input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        target_attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[Tensor] = None,
+        target_input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        target_attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -142,10 +142,10 @@ def __init__(self, config: RnaFmConfig):
     def forward(
         self,
         input_ids: Tensor,
-        target_input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        target_attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[Tensor] = None,
+        target_input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        target_attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -239,10 +239,10 @@ def __init__(self, config: RnaMsmConfig):
     def forward(
         self,
         input_ids: Tensor,
-        target_input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        target_attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[Tensor] = None,
+        target_input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        target_attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -338,10 +338,10 @@ def __init__(self, config: SpliceBertConfig):
     def forward(
         self,
         input_ids: Tensor,
-        target_input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        target_attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[Tensor] = None,
+        target_input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        target_attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -435,10 +435,10 @@ def __init__(self, config: UtrBertConfig):
     def forward(
         self,
         input_ids: Tensor,
-        target_input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        target_attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[Tensor] = None,
+        target_input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        target_attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -532,10 +532,10 @@ def __init__(self, config: UtrLmConfig):
     def forward(
         self,
         input_ids: Tensor,
-        target_input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        target_attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[Tensor] = None,
+        target_input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        target_attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -603,21 +603,21 @@ def forward(

 @dataclass
 class CrisprOffTargetOutput(ModelOutput):
-    loss: Optional[torch.FloatTensor] = None
+    loss: torch.FloatTensor | None = None
     logits: torch.FloatTensor = None
-    sgrna_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    sgrna_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
-    target_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    target_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    sgrna_hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+    sgrna_attentions: Tuple[torch.FloatTensor, ...] | None = None
+    target_hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+    target_attentions: Tuple[torch.FloatTensor, ...] | None = None


 @dataclass
 class RnaMsmForCrisprOffTargetOutput(ModelOutput):
-    loss: Optional[torch.FloatTensor] = None
+    loss: torch.FloatTensor | None = None
     logits: torch.FloatTensor = None
-    sgrna_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    sgrna_col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
-    sgrna_row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
-    target_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    target_col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
-    target_row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    sgrna_hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+    sgrna_col_attentions: Tuple[torch.FloatTensor, ...] | None = None
+    sgrna_row_attentions: Tuple[torch.FloatTensor, ...] | None = None
+    target_hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+    target_col_attentions: Tuple[torch.FloatTensor, ...] | None = None
+    target_row_attentions: Tuple[torch.FloatTensor, ...] | None = None
diff --git a/multimolecule/models/configuration_utils.py b/multimolecule/models/configuration_utils.py
index c6e33793..75f8f8bf 100644
--- a/multimolecule/models/configuration_utils.py
+++ b/multimolecule/models/configuration_utils.py
@@ -1,7 +1,6 @@
 from __future__ import annotations

 from dataclasses import asdict, dataclass, is_dataclass
-from typing import Optional

 from transformers.configuration_utils import PretrainedConfig as _PretrainedConfig

@@ -70,15 +69,15 @@ class HeadConfig:
            `"single_label_classification"` or `"multi_label_classification"`.
    """

-    hidden_size: Optional[int] = None
+    hidden_size: int | None = None
     dropout: float = 0.0
-    transform: Optional[str] = None
-    transform_act: Optional[str] = "gelu"
+    transform: str | None = None
+    transform_act: str | None = "gelu"
     bias: bool = True
-    act: Optional[str] = None
+    act: str | None = None
     layer_norm_eps: float = 1e-12
     num_labels: int = 1
-    problem_type: Optional[str] = None
+    problem_type: str | None = None


 @dataclass
@@ -108,10 +107,10 @@ class MaskedLMHeadConfig:
            The epsilon used by the layer normalization layers.
    """

-    hidden_size: Optional[int] = None
+    hidden_size: int | None = None
     dropout: float = 0.0
-    transform: Optional[str] = "nonlinear"
-    transform_act: Optional[str] = "gelu"
+    transform: str | None = "nonlinear"
+    transform_act: str | None = "gelu"
     bias: bool = True
-    act: Optional[str] = None
+    act: str | None = None
     layer_norm_eps: float = 1e-12
diff --git a/multimolecule/models/modeling_utils.py b/multimolecule/models/modeling_utils.py
index 5ba06bb9..5c529020 100644
--- a/multimolecule/models/modeling_utils.py
+++ b/multimolecule/models/modeling_utils.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from functools import partial
-from typing import Optional, Tuple
+from typing import Tuple

 import torch
 from chanfig import ConfigRegistry
@@ -42,7 +42,7 @@ def __init__(self, config: PretrainedConfig):
         self.activation = ACT2FN[self.config.act] if self.config.act is not None else None

     def forward(
-        self, attentions: Tensor, attention_mask: Optional[Tensor] = None, input_ids: Optional[Tensor] = None
+        self, attentions: Tensor, attention_mask: Tensor | None = None, input_ids: Tensor | None = None
     ) -> Tensor:
         if attention_mask is None:
             if input_ids is None:
@@ -96,7 +96,7 @@ def forward(
 class MaskedLMHead(nn.Module):
     """Head for masked language modeling."""

-    def __init__(self, config: PretrainedConfig, weight: Optional[Tensor] = None):
+    def __init__(self, config: PretrainedConfig, weight: Tensor | None = None):
         super().__init__()
         self.config = config.lm_head if hasattr(config, "lm_head") else config.head
         if self.config.hidden_size is None:
@@ -185,8 +185,8 @@ def __init__(self, config: PretrainedConfig):
     def forward(  # pylint: disable=arguments-renamed
         self,
         outputs: ModelOutput | Tuple[Tensor, ...],
-        attention_mask: Optional[Tensor] = None,
-        input_ids: Optional[Tensor] = None,
+        attention_mask: Tensor | None = None,
+        input_ids: Tensor | None = None,
     ) -> Tensor:
         if attention_mask is None:
             if input_ids is None:
@@ -214,8 +214,8 @@ def __init__(self, config: PretrainedConfig):
     def forward(  # pylint: disable=arguments-renamed
         self,
         outputs: ModelOutput | Tuple[Tensor, ...],
-        attention_mask: Optional[Tensor] = None,
-        input_ids: Optional[Tensor] = None,
+        attention_mask: Tensor | None = None,
+        input_ids: Tensor | None = None,
     ) -> Tensor:
         if attention_mask is None:
             if input_ids is None:
@@ -267,8 +267,8 @@ def __init__(self, config: PretrainedConfig):
     def forward(  # pylint: disable=arguments-renamed
         self,
         outputs: ModelOutput | Tuple[Tensor, ...],
-        attention_mask: Optional[Tensor] = None,
-        input_ids: Optional[Tensor] = None,
+        attention_mask: Tensor | None = None,
+        input_ids: Tensor | None = None,
     ) -> Tensor:
         if attention_mask is None:
             if input_ids is None:
@@ -348,8 +348,8 @@ def unfold_kmer_embeddings(
     embeddings: Tensor,
     attention_mask: Tensor,
     nmers: int,
-    bos_token_id: Optional[int] = None,
-    eos_token_id: Optional[int] = None,
+    bos_token_id: int | None = None,
+    eos_token_id: int | None = None,
 ) -> Tensor:
     r"""
     Unfold k-mer embeddings to token embeddings.
diff --git a/multimolecule/models/rnabert/convert_checkpoint.py b/multimolecule/models/rnabert/convert_checkpoint.py index e4a3be3d..21759d05 100644 --- a/multimolecule/models/rnabert/convert_checkpoint.py +++ b/multimolecule/models/rnabert/convert_checkpoint.py @@ -1,5 +1,4 @@ import os -from typing import Optional import chanfig import torch @@ -98,8 +97,8 @@ class ConvertConfig: output_path: str = Config.model_type push_to_hub: bool = False delete_existing: bool = False - repo_id: Optional[str] = None - token: Optional[str] = None + repo_id: str | None = None + token: str | None = None def post(self): if self.repo_id is None: diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py index 39763f1b..deb211e8 100644 --- a/multimolecule/models/rnabert/modeling_rnabert.py +++ b/multimolecule/models/rnabert/modeling_rnabert.py @@ -2,7 +2,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Tuple import torch from danling import NestedTensor @@ -81,7 +81,7 @@ def __init__(self, config: RnaBertConfig, add_pooling_layer: bool = True): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, + attention_mask: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -139,8 +139,8 @@ def __init__(self, config: RnaBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[torch.Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -192,10 +192,10 @@ def __init__(self, config: RnaBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[torch.Tensor] = None, - labels_ss: Optional[torch.Tensor] = None, - next_sentence_label: Optional[torch.Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, + labels_ss: Tensor | None = None, + next_sentence_label: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -257,8 +257,8 @@ def __init__(self, config: RnaBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -336,8 +336,8 @@ def __init__(self, config: RnaBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -415,8 +415,8 @@ def __init__(self, config: RnaBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -698,11 +698,11 @@ def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tuple[Tensor, Te @dataclass class RnaBertForPretrainingOutput(ModelOutput): - loss: 
Optional[torch.FloatTensor] = None + loss: torch.FloatTensor | None = None logits: torch.FloatTensor = None # type: ignore[assignment] logits_ss: torch.FloatTensor = None # type: ignore[assignment] - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + hidden_states: Tuple[torch.FloatTensor, ...] | None = None + attentions: Tuple[torch.FloatTensor, ...] | None = None class RnaBertLayerNorm(nn.Module): diff --git a/multimolecule/models/rnafm/convert_checkpoint.py b/multimolecule/models/rnafm/convert_checkpoint.py index c0daaf6d..fda5f6fb 100644 --- a/multimolecule/models/rnafm/convert_checkpoint.py +++ b/multimolecule/models/rnafm/convert_checkpoint.py @@ -1,5 +1,4 @@ import os -from typing import Optional import chanfig import torch @@ -135,8 +134,8 @@ class ConvertConfig: output_path: str = Config.model_type push_to_hub: bool = False delete_existing: bool = False - repo_id: Optional[str] = None - token: Optional[str] = None + repo_id: str | None = None + token: str | None = None def post(self): if self.repo_id is None: diff --git a/multimolecule/models/rnafm/modeling_rnafm.py b/multimolecule/models/rnafm/modeling_rnafm.py index 659ecdb4..09219f6e 100755 --- a/multimolecule/models/rnafm/modeling_rnafm.py +++ b/multimolecule/models/rnafm/modeling_rnafm.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import torch import torch.utils.checkpoint @@ -100,17 +100,17 @@ class PreTrainedModel def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - encoder_hidden_states: Optional[Tensor] = None, - encoder_attention_mask: Optional[Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + encoder_hidden_states: Tensor | None = None, + encoder_attention_mask: Tensor | None = None, + past_key_values: List[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Union[Tuple[Tensor, ...], BaseModelOutputWithPoolingAndCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -255,16 +255,16 @@ def __init__(self, config: RnaFmConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - encoder_hidden_states: Optional[Tensor] = None, - encoder_attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + encoder_hidden_states: Tensor | None = None, + 
encoder_attention_mask: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Union[Tuple[Tensor, ...], MaskedLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -340,17 +340,17 @@ def set_output_embeddings(self, embeddings): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - encoder_hidden_states: Optional[Tensor] = None, - encoder_attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - labels_contact: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + encoder_hidden_states: Tensor | None = None, + encoder_attention_mask: Tensor | None = None, + labels: Tensor | None = None, + labels_contact: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Union[Tuple[Tensor, ...], RnaFmForPretrainingOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -421,14 +421,14 @@ def __init__(self, config: RnaFmConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Union[Tuple[Tensor, ...], SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -505,14 +505,14 @@ def __init__(self, config: RnaFmConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Union[Tuple[Tensor, ...], TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -587,8 +587,8 @@ def __init__(self, config: RnaFmConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, 
output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -778,12 +778,12 @@ def __init__(self, config: RnaFmConfig): def forward( self, hidden_states: Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, - use_cache: Optional[bool] = None, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + past_key_values: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None, + use_cache: bool | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -882,11 +882,11 @@ def __init__(self, config: RnaFmConfig): def forward( self, hidden_states: Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 @@ -979,11 +979,11 @@ def prune_heads(self, heads): def forward( self, hidden_states: Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] 
| None = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: hidden_states_ln = self.layer_norm(hidden_states) @@ -1002,7 +1002,7 @@ def forward( class RnaFmSelfAttention(nn.Module): - def __init__(self, config: RnaFmConfig, position_embedding_type: Optional[str] = None): + def __init__(self, config: RnaFmConfig, position_embedding_type: str | None = None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -1037,11 +1037,11 @@ def transpose_for_scores(self, x: Tensor) -> Tensor: def forward( self, hidden_states: Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: mixed_query_layer = self.query(hidden_states) @@ -1199,7 +1199,7 @@ def __init__(self, config: RnaFmConfig): def forward( self, outputs: BaseModelOutputWithPastAndCrossAttentions | Tuple[Tensor, ...], - attention_mask: Optional[Tensor] = None, + attention_mask: Tensor | None = None, input_ids: Tensor | NestedTensor | None = None, ) -> Tuple[Tensor, Tensor]: logits = self.predictions(outputs) @@ -1209,11 +1209,11 @@ def forward( @dataclass class RnaFmForPretrainingOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None + loss: torch.FloatTensor | None = None logits: torch.FloatTensor = None - contact_map: Optional[torch.FloatTensor] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + contact_map: torch.FloatTensor | None = None + hidden_states: Tuple[torch.FloatTensor, ...] | None = None + attentions: Tuple[torch.FloatTensor, ...] 
| None = None def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): diff --git a/multimolecule/models/rnamsm/convert_checkpoint.py b/multimolecule/models/rnamsm/convert_checkpoint.py index 64662ce5..49c83ebb 100644 --- a/multimolecule/models/rnamsm/convert_checkpoint.py +++ b/multimolecule/models/rnamsm/convert_checkpoint.py @@ -1,5 +1,4 @@ import os -from typing import Optional import chanfig import torch @@ -98,8 +97,8 @@ class ConvertConfig: output_path: str = Config.model_type push_to_hub: bool = False delete_existing: bool = False - repo_id: Optional[str] = None - token: Optional[str] = None + repo_id: str | None = None + token: str | None = None def post(self): if self.repo_id is None: diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py index 0e4e355a..3197d8fe 100644 --- a/multimolecule/models/rnamsm/modeling_rnamsm.py +++ b/multimolecule/models/rnamsm/modeling_rnamsm.py @@ -3,7 +3,7 @@ import math from dataclasses import dataclass from functools import partial -from typing import Optional, Tuple +from typing import Tuple import torch from chanfig import ConfigRegistry @@ -77,7 +77,7 @@ def __init__(self, config: RnaMsmConfig, add_pooling_layer: bool = True): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, + attention_mask: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -138,8 +138,8 @@ def __init__(self, config: RnaMsmConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[torch.Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -192,9 +192,9 @@ def __init__(self, config: RnaMsmConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[torch.Tensor] = None, - labels_contact: Optional[torch.Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, + labels_contact: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -255,8 +255,8 @@ def __init__(self, config: RnaMsmConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -335,8 +335,8 @@ def __init__(self, config: RnaMsmConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -415,8 +415,8 @@ def __init__(self, config: RnaMsmConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -492,7 +492,7 @@ def __init__(self, config: RnaMsmConfig): self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12) self.dropout = 
nn.Dropout(config.hidden_dropout) - def forward(self, input_ids: Tensor | NestedTensor, attention_mask: Optional[Tensor] = None) -> Tensor: + def forward(self, input_ids: Tensor | NestedTensor, attention_mask: Tensor | None = None) -> Tensor: assert input_ids.ndim == 3 if attention_mask is None: attention_mask = input_ids.ne(self.pad_token_id) @@ -558,7 +558,7 @@ def __init__(self, config: RnaMsmConfig): def forward( self, hidden_states: Tensor, - key_padding_mask: Optional[torch.FloatTensor] = None, + key_padding_mask: torch.FloatTensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -628,8 +628,8 @@ def __init__(self, config) -> None: def forward( self, hidden_states: Tensor, - self_attention_mask: Optional[Tensor] = None, - self_attention_padding_mask: Optional[Tensor] = None, + self_attention_mask: Tensor | None = None, + self_attention_padding_mask: Tensor | None = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: """ @@ -1029,9 +1029,9 @@ def attention_fn( query, key, value, - key_padding_mask: Optional[Tensor] = None, + key_padding_mask: Tensor | None = None, output_attentions: bool = False, - attention_mask: Optional[Tensor] = None, + attention_mask: Tensor | None = None, ): return self._attention_fn( query, @@ -1063,8 +1063,8 @@ def reset_parameters(self): def forward( self, hidden_states: Tensor, - key_padding_mask: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, + key_padding_mask: Tensor | None = None, + attention_mask: Tensor | None = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: """Input shape: Time x Batch x Channel @@ -1166,8 +1166,8 @@ def attention_fn(self, query, key, value): def forward( self, hidden_states: Tensor, - key_padding_mask: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, + key_padding_mask: Tensor | None = None, + attention_mask: Tensor | None = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: from einops import rearrange @@ -1253,7 +1253,7 @@ def forward(self, hidden_states: Tensor) -> Tensor: class RnaMsmPreTrainingHeads(nn.Module): - def __init__(self, config: RnaMsmConfig, weight: Optional[Tensor] = None): + def __init__(self, config: RnaMsmConfig, weight: Tensor | None = None): super().__init__() self.predictions = MaskedLMHead(config, weight=weight) self.contact = ContactPredictionHead(config) @@ -1261,7 +1261,7 @@ def __init__(self, config: RnaMsmConfig, weight: Optional[Tensor] = None): def forward( self, outputs: RnaMsmModelOutput | Tuple[Tensor, ...], - attention_mask: Optional[Tensor] = None, + attention_mask: Tensor | None = None, input_ids: Tensor | NestedTensor | None = None, ) -> Tuple[Tensor, Tensor]: sequence_output, row_attentions = outputs[0], torch.stack(outputs[-1], 1) @@ -1272,53 +1272,53 @@ def forward( @dataclass class RnaMsmForPretrainingOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None + loss: torch.FloatTensor | None = None logits: torch.FloatTensor = None - contact_map: Optional[torch.FloatTensor] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + contact_map: torch.FloatTensor | None = None + hidden_states: Tuple[torch.FloatTensor, ...] | None = None + col_attentions: Tuple[torch.FloatTensor, ...] | None = None + row_attentions: Tuple[torch.FloatTensor, ...] 
| None = None @dataclass class RnaMsmForMaskedLMOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None + loss: torch.FloatTensor | None = None logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + hidden_states: Tuple[torch.FloatTensor, ...] | None = None + col_attentions: Tuple[torch.FloatTensor, ...] | None = None + row_attentions: Tuple[torch.FloatTensor, ...] | None = None @dataclass class RnaMsmForSequenceClassifierOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None + loss: torch.FloatTensor | None = None logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + hidden_states: Tuple[torch.FloatTensor, ...] | None = None + col_attentions: Tuple[torch.FloatTensor, ...] | None = None + row_attentions: Tuple[torch.FloatTensor, ...] | None = None @dataclass class RnaMsmForTokenClassifierOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None + loss: torch.FloatTensor | None = None logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + hidden_states: Tuple[torch.FloatTensor, ...] | None = None + col_attentions: Tuple[torch.FloatTensor, ...] | None = None + row_attentions: Tuple[torch.FloatTensor, ...] | None = None @dataclass class RnaMsmModelOutputWithPooling(ModelOutput): last_hidden_state: torch.FloatTensor = None pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + hidden_states: Tuple[torch.FloatTensor, ...] | None = None + col_attentions: Tuple[torch.FloatTensor, ...] | None = None + row_attentions: Tuple[torch.FloatTensor, ...] | None = None @dataclass class RnaMsmModelOutput(ModelOutput): last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + hidden_states: Tuple[torch.FloatTensor, ...] | None = None + col_attentions: Tuple[torch.FloatTensor, ...] | None = None + row_attentions: Tuple[torch.FloatTensor, ...] 
| None = None diff --git a/multimolecule/models/splicebert/convert_checkpoint.py b/multimolecule/models/splicebert/convert_checkpoint.py index 0f2f4c23..9271da68 100644 --- a/multimolecule/models/splicebert/convert_checkpoint.py +++ b/multimolecule/models/splicebert/convert_checkpoint.py @@ -1,5 +1,4 @@ import os -from typing import Optional import chanfig import torch @@ -108,11 +107,11 @@ def convert_checkpoint(convert_config): @chanfig.configclass class ConvertConfig: checkpoint_path: str - output_path: Optional[str] = None + output_path: str | None = None push_to_hub: bool = False delete_existing: bool = False - repo_id: Optional[str] = None - token: Optional[str] = None + repo_id: str | None = None + token: str | None = None def post(self): if self.output_path is None: diff --git a/multimolecule/models/splicebert/modeling_splicebert.py b/multimolecule/models/splicebert/modeling_splicebert.py index 876fdb9b..fde7731f 100644 --- a/multimolecule/models/splicebert/modeling_splicebert.py +++ b/multimolecule/models/splicebert/modeling_splicebert.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from typing import List, Optional, Tuple +from typing import List, Tuple import torch import torch.utils.checkpoint @@ -108,17 +108,17 @@ class PreTrainedModel def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - encoder_hidden_states: Optional[Tensor] = None, - encoder_attention_mask: Optional[Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + encoder_hidden_states: Tensor | None = None, + encoder_attention_mask: Tensor | None = None, + past_key_values: List[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] | BaseModelOutputWithPoolingAndCrossAttentions: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -268,16 +268,16 @@ def set_output_embeddings(self, new_embeddings): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - encoder_hidden_states: Optional[Tensor] = None, - encoder_attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + encoder_hidden_states: Tensor | None = None, + encoder_attention_mask: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] 
| MaskedLMOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -363,16 +363,16 @@ def __init__(self, config: SpliceBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - encoder_hidden_states: Optional[Tensor] = None, - encoder_attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + encoder_hidden_states: Tensor | None = None, + encoder_attention_mask: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] | MaskedLMOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -437,14 +437,14 @@ def __init__(self, config: SpliceBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] | SequenceClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -522,14 +522,14 @@ def __init__(self, config: SpliceBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] 
| TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -605,8 +605,8 @@ def __init__(self, config: SpliceBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, + attention_mask: Tensor | None = None, + labels: Tensor | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -679,9 +679,9 @@ def __init__(self, config: SpliceBertConfig): def forward( self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, past_key_values_length: int = 0, ) -> Tensor: if input_ids is not None: @@ -720,12 +720,12 @@ def __init__(self, config: SpliceBertConfig): def forward( self, hidden_states: Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, - use_cache: Optional[bool] = None, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + past_key_values: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None, + use_cache: bool | None = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -825,11 +825,11 @@ def __init__(self, config: SpliceBertConfig): def forward( self, hidden_states: Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] 
| None = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 @@ -894,7 +894,7 @@ def feed_forward_chunk(self, attention_output): class SpliceBertAttention(nn.Module): - def __init__(self, config: SpliceBertConfig, position_embedding_type: Optional[str] = None): + def __init__(self, config: SpliceBertConfig, position_embedding_type: str | None = None): super().__init__() self.self = SpliceBertSelfAttention(config, position_embedding_type=position_embedding_type) self.output = SpliceBertSelfOutput(config) @@ -921,11 +921,11 @@ def prune_heads(self, heads): def forward( self, hidden_states: Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: self_outputs = self.self( @@ -943,7 +943,7 @@ def forward( class SpliceBertSelfAttention(nn.Module): - def __init__(self, config: SpliceBertConfig, position_embedding_type: Optional[str] = None): + def __init__(self, config: SpliceBertConfig, position_embedding_type: str | None = None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -976,11 +976,11 @@ def transpose_for_scores(self, x: Tensor) -> Tensor: def forward( self, hidden_states: Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] 
| None = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: mixed_query_layer = self.query(hidden_states) diff --git a/multimolecule/models/utrbert/convert_checkpoint.py b/multimolecule/models/utrbert/convert_checkpoint.py index e6e07e93..2ce99963 100644 --- a/multimolecule/models/utrbert/convert_checkpoint.py +++ b/multimolecule/models/utrbert/convert_checkpoint.py @@ -1,5 +1,4 @@ import os -from typing import Optional import chanfig import torch @@ -111,11 +110,11 @@ def convert_checkpoint(convert_config): @chanfig.configclass class ConvertConfig: checkpoint_path: str - output_path: Optional[str] = None + output_path: str | None = None push_to_hub: bool = False delete_existing: bool = False - repo_id: Optional[str] = None - token: Optional[str] = None + repo_id: str | None = None + token: str | None = None def post(self): if self.output_path is None: diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py index 4b3d27fd..426e14d3 100644 --- a/multimolecule/models/utrbert/modeling_utrbert.py +++ b/multimolecule/models/utrbert/modeling_utrbert.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from typing import List, Optional, Tuple +from typing import List, Tuple import torch import torch.utils.checkpoint @@ -93,17 +93,17 @@ class PreTrainedModel def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - encoder_hidden_states: Optional[Tensor] = None, - encoder_attention_mask: Optional[Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + encoder_hidden_states: Tensor | None = None, + encoder_attention_mask: Tensor | None = None, + past_key_values: List[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] | BaseModelOutputWithPoolingAndCrossAttentions: r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -253,16 +253,16 @@ def set_output_embeddings(self, new_embeddings): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - encoder_hidden_states: Optional[Tensor] = None, - encoder_attention_mask: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + encoder_hidden_states: Tensor | None = None, + encoder_attention_mask: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] 
| MaskedLMOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -348,14 +348,14 @@ def __init__(self, config: UtrBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] | MaskedLMOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -418,14 +418,14 @@ def __init__(self, config: UtrBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] | SequenceClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -503,14 +503,14 @@ def __init__(self, config: UtrBertConfig): def forward( self, input_ids: Tensor | NestedTensor, - attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - inputs_embeds: Optional[Tensor] = None, - labels: Optional[Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + attention_mask: Tensor | None = None, + position_ids: Tensor | None = None, + head_mask: Tensor | None = None, + inputs_embeds: Tensor | None = None, + labels: Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, ) -> Tuple[Tensor, ...] 
| TokenClassifierOutput:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -586,8 +586,8 @@ def __init__(self, config: UtrBertConfig):
     def forward(
         self,
         input_ids: Tensor | NestedTensor,
-        attention_mask: Optional[Tensor] = None,
-        labels: Optional[Tensor] = None,
+        attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -660,9 +660,9 @@ def __init__(self, config: UtrBertConfig):
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
+        input_ids: torch.LongTensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
         past_key_values_length: int = 0,
     ) -> Tensor:
         if input_ids is not None:
@@ -701,12 +701,12 @@ def __init__(self, config: UtrBertConfig):
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
-        use_cache: Optional[bool] = None,
+        attention_mask: torch.FloatTensor | None = None,
+        head_mask: torch.FloatTensor | None = None,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        encoder_attention_mask: torch.FloatTensor | None = None,
+        past_key_values: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None,
+        use_cache: bool | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -801,11 +801,11 @@ def __init__(self, config: UtrBertConfig):
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
+        attention_mask: torch.FloatTensor | None = None,
+        head_mask: torch.FloatTensor | None = None,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        encoder_attention_mask: torch.FloatTensor | None = None,
+        past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None,
         output_attentions: bool = False,
     ) -> Tuple[Tensor, ...]:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
@@ -870,7 +870,7 @@ def feed_forward_chunk(self, attention_output):


 class UtrBertAttention(nn.Module):
-    def __init__(self, config: UtrBertConfig, position_embedding_type: Optional[str] = None):
+    def __init__(self, config: UtrBertConfig, position_embedding_type: str | None = None):
         super().__init__()
         self.self = UtrBertSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = UtrBertSelfOutput(config)
@@ -897,11 +897,11 @@ def prune_heads(self, heads):
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
+        attention_mask: torch.FloatTensor | None = None,
+        head_mask: torch.FloatTensor | None = None,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        encoder_attention_mask: torch.FloatTensor | None = None,
+        past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None,
         output_attentions: bool = False,
     ) -> Tuple[Tensor, ...]:
         self_outputs = self.self(
@@ -919,7 +919,7 @@ def forward(


 class UtrBertSelfAttention(nn.Module):
-    def __init__(self, config: UtrBertConfig, position_embedding_type: Optional[str] = None):
+    def __init__(self, config: UtrBertConfig, position_embedding_type: str | None = None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -951,11 +951,11 @@ def transpose_for_scores(self, x: Tensor) -> Tensor:
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
+        attention_mask: torch.FloatTensor | None = None,
+        head_mask: torch.FloatTensor | None = None,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        encoder_attention_mask: torch.FloatTensor | None = None,
+        past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None,
         output_attentions: bool = False,
     ) -> Tuple[Tensor, ...]:
         mixed_query_layer = self.query(hidden_states)
diff --git a/multimolecule/models/utrlm/convert_checkpoint.py b/multimolecule/models/utrlm/convert_checkpoint.py
index 3a7e25e9..3b9dd31e 100644
--- a/multimolecule/models/utrlm/convert_checkpoint.py
+++ b/multimolecule/models/utrlm/convert_checkpoint.py
@@ -1,5 +1,4 @@
 import os
-from typing import Optional

 import chanfig
 import torch
@@ -117,8 +116,8 @@ class ConvertConfig:
     output_path: str = Config.model_type
     push_to_hub: bool = False
     delete_existing: bool = False
-    repo_id: Optional[str] = None
-    token: Optional[str] = None
+    repo_id: str | None = None
+    token: str | None = None

     def post(self):
         if self.repo_id is None:
diff --git a/multimolecule/models/utrlm/modeling_utrlm.py b/multimolecule/models/utrlm/modeling_utrlm.py
index 15aba14c..64f9ee03 100755
--- a/multimolecule/models/utrlm/modeling_utrlm.py
+++ b/multimolecule/models/utrlm/modeling_utrlm.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -100,17 +100,17 @@ class PreTrainedModel
     def forward(
         self,
         input_ids: Tensor | NestedTensor,
-        attention_mask: Optional[Tensor] = None,
-        position_ids: Optional[Tensor] = None,
-        head_mask: Optional[Tensor] = None,
-        inputs_embeds: Optional[Tensor] = None,
-        encoder_hidden_states: Optional[Tensor] = None,
-        encoder_attention_mask: Optional[Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        attention_mask: Tensor | None = None,
+        position_ids: Tensor | None = None,
+        head_mask: Tensor | None = None,
+        inputs_embeds: Tensor | None = None,
+        encoder_hidden_states: Tensor | None = None,
+        encoder_attention_mask: Tensor | None = None,
+        past_key_values: List[torch.FloatTensor] | None = None,
+        use_cache: bool | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
     ) -> Union[Tuple[Tensor, ...], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -255,16 +255,16 @@ def __init__(self, config: UtrLmConfig):
     def forward(
         self,
         input_ids: Tensor | NestedTensor,
-        attention_mask: Optional[Tensor] = None,
-        position_ids: Optional[Tensor] = None,
-        head_mask: Optional[Tensor] = None,
-        inputs_embeds: Optional[Tensor] = None,
-        encoder_hidden_states: Optional[Tensor] = None,
-        encoder_attention_mask: Optional[Tensor] = None,
-        labels: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        attention_mask: Tensor | None = None,
+        position_ids: Tensor | None = None,
+        head_mask: Tensor | None = None,
+        inputs_embeds: Tensor | None = None,
+        encoder_hidden_states: Tensor | None = None,
+        encoder_attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
     ) -> Union[Tuple[Tensor, ...], MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -340,19 +340,19 @@ def set_output_embeddings(self, embeddings):
     def forward(
         self,
         input_ids: Tensor | NestedTensor,
-        attention_mask: Optional[Tensor] = None,
-        position_ids: Optional[Tensor] = None,
-        head_mask: Optional[Tensor] = None,
-        inputs_embeds: Optional[Tensor] = None,
-        encoder_hidden_states: Optional[Tensor] = None,
-        encoder_attention_mask: Optional[Tensor] = None,
-        labels: Optional[Tensor] = None,
-        labels_contact: Optional[Tensor] = None,
-        labels_structure: Optional[Tensor] = None,
-        labels_supervised: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        attention_mask: Tensor | None = None,
+        position_ids: Tensor | None = None,
+        head_mask: Tensor | None = None,
+        inputs_embeds: Tensor | None = None,
+        encoder_hidden_states: Tensor | None = None,
+        encoder_attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
+        labels_contact: Tensor | None = None,
+        labels_structure: Tensor | None = None,
+        labels_supervised: Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
     ) -> Union[Tuple[Tensor, ...], UtrLmForPretrainingOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -432,14 +432,14 @@ def __init__(self, config: UtrLmConfig):
     def forward(
         self,
         input_ids: Tensor | NestedTensor,
-        attention_mask: Optional[Tensor] = None,
-        position_ids: Optional[Tensor] = None,
-        head_mask: Optional[Tensor] = None,
-        inputs_embeds: Optional[Tensor] = None,
-        labels: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        attention_mask: Tensor | None = None,
+        position_ids: Tensor | None = None,
+        head_mask: Tensor | None = None,
+        inputs_embeds: Tensor | None = None,
+        labels: Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
     ) -> Union[Tuple[Tensor, ...], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -516,14 +516,14 @@ def __init__(self, config: UtrLmConfig):
     def forward(
         self,
         input_ids: Tensor | NestedTensor,
-        attention_mask: Optional[Tensor] = None,
-        position_ids: Optional[Tensor] = None,
-        head_mask: Optional[Tensor] = None,
-        inputs_embeds: Optional[Tensor] = None,
-        labels: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        attention_mask: Tensor | None = None,
+        position_ids: Tensor | None = None,
+        head_mask: Tensor | None = None,
+        inputs_embeds: Tensor | None = None,
+        labels: Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
     ) -> Union[Tuple[Tensor, ...], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -598,8 +598,8 @@ def __init__(self, config: UtrLmConfig):
     def forward(
         self,
         input_ids: Tensor | NestedTensor,
-        attention_mask: Optional[Tensor] = None,
-        labels: Optional[Tensor] = None,
+        attention_mask: Tensor | None = None,
+        labels: Tensor | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -789,12 +789,12 @@ def __init__(self, config: UtrLmConfig):
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
-        use_cache: Optional[bool] = None,
+        attention_mask: torch.FloatTensor | None = None,
+        head_mask: torch.FloatTensor | None = None,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        encoder_attention_mask: torch.FloatTensor | None = None,
+        past_key_values: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None,
+        use_cache: bool | None = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = True,
@@ -893,11 +893,11 @@ def __init__(self, config: UtrLmConfig):
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
+        attention_mask: torch.FloatTensor | None = None,
+        head_mask: torch.FloatTensor | None = None,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        encoder_attention_mask: torch.FloatTensor | None = None,
+        past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None,
         output_attentions: bool = False,
     ) -> Tuple[Tensor, ...]:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
@@ -990,11 +990,11 @@ def prune_heads(self, heads):
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
+        attention_mask: torch.FloatTensor | None = None,
+        head_mask: torch.FloatTensor | None = None,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        encoder_attention_mask: torch.FloatTensor | None = None,
+        past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None,
         output_attentions: bool = False,
     ) -> Tuple[Tensor, ...]:
         hidden_states_ln = self.layer_norm(hidden_states)
@@ -1013,7 +1013,7 @@ def forward(


 class UtrLmSelfAttention(nn.Module):
-    def __init__(self, config: UtrLmConfig, position_embedding_type: Optional[str] = None):
+    def __init__(self, config: UtrLmConfig, position_embedding_type: str | None = None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -1048,11 +1048,11 @@ def transpose_for_scores(self, x: Tensor) -> Tensor:
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
+        attention_mask: torch.FloatTensor | None = None,
+        head_mask: torch.FloatTensor | None = None,
+        encoder_hidden_states: torch.FloatTensor | None = None,
+        encoder_attention_mask: torch.FloatTensor | None = None,
+        past_key_value: Tuple[Tuple[torch.FloatTensor, ...], ...] | None = None,
         output_attentions: bool = False,
     ) -> Tuple[Tensor, ...]:
         mixed_query_layer = self.query(hidden_states)
@@ -1220,9 +1220,9 @@ def __init__(self, config: UtrLmConfig):
     def forward(
         self,
         outputs: BaseModelOutputWithPastAndCrossAttentions | Tuple[Tensor, ...],
-        attention_mask: Optional[Tensor] = None,
+        attention_mask: Tensor | None = None,
         input_ids: Tensor | NestedTensor | None = None,
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+    ) -> Tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
         logits = self.predictions(outputs)
         contact_map = self.contact(torch.stack(outputs[-1], 1), attention_mask, input_ids)
         structure = self.structure(outputs) if self.structure else None
@@ -1232,13 +1232,13 @@ def forward(

 @dataclass
 class UtrLmForPretrainingOutput(ModelOutput):
-    loss: Optional[torch.FloatTensor] = None
+    loss: torch.FloatTensor | None = None
     logits: torch.FloatTensor = None
-    contact_map: Optional[torch.FloatTensor] = None
-    structure: Optional[torch.FloatTensor] = None
-    supervised: Optional[torch.FloatTensor] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    contact_map: torch.FloatTensor | None = None
+    structure: torch.FloatTensor | None = None
+    supervised: torch.FloatTensor | None = None
+    hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+    attentions: Tuple[torch.FloatTensor, ...] | None = None


 def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
diff --git a/multimolecule/tokenisers/rna/tokenization_rna.py b/multimolecule/tokenisers/rna/tokenization_rna.py
index 5eeafacd..f97fd7e5 100755
--- a/multimolecule/tokenisers/rna/tokenization_rna.py
+++ b/multimolecule/tokenisers/rna/tokenization_rna.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import os
-from typing import List, Optional
+from typing import List

 from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import logging
@@ -84,7 +84,7 @@ def id_to_token(self, index: int) -> str:
         return self._id_to_token.get(index, self.unk_token)

     def build_inputs_with_special_tokens(
-        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+        self, token_ids_0: List[int], token_ids_1: List[int] | None = None
     ) -> List[int]:
         cls = [self.cls_token_id]
         sep = [self.eos_token_id]  # No sep token in RnaBert vocabulary
@@ -98,7 +98,7 @@ def build_inputs_with_special_tokens(
         return cls + token_ids_0 + sep + token_ids_1 + sep  # Multiple inputs always have an EOS token

     def get_special_tokens_mask(
-        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+        self, token_ids_0: List, token_ids_1: List | None = None, already_has_special_tokens: bool = False
     ) -> List[int]:
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
@@ -128,7 +128,7 @@ def get_special_tokens_mask(
             mask += [0] * len(token_ids_1) + [1]
         return mask

-    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
+    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None):
         vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
         with open(vocab_file, "w") as f:
             f.write("\n".join(self.all_tokens))
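
Note on the pattern above: every hunk in this patch applies the same mechanical change, rewriting `typing.Optional[X]` (and `Optional[Tuple[...]]`, `Optional[List[...]]`) annotations as PEP 604 `X | None` unions. The modeling and tokeniser modules whose headers are visible here keep `from __future__ import annotations`, so the new spelling is stored as an unevaluated string and the files still import on interpreters older than Python 3.10. A minimal, hypothetical sketch of the before/after (`forward_sketch` is illustrative only and not part of the patch):

from __future__ import annotations  # PEP 563: annotations are kept as strings, never evaluated at runtime

from torch import Tensor


# Before: attention_mask: Optional[Tensor] = None
# After:  attention_mask: Tensor | None = None
def forward_sketch(input_ids: Tensor, attention_mask: Tensor | None = None) -> Tensor:
    # Runtime behaviour is unchanged; only the annotation spelling differs.
    if attention_mask is None:
        return input_ids
    return input_ids * attention_mask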