fix Optional type hints
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed May 2, 2024
1 parent 183a11b commit e086816
Showing 16 changed files with 451 additions and 458 deletions.
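
The change is mechanical throughout: `typing.Optional[X]` becomes the PEP 604 spelling `X | None`, and `Optional` imports that are no longer used are dropped. In modules that carry `from __future__ import annotations` (PEP 563), annotations are kept as unevaluated strings, so the `|` spelling is accepted even on Python versions older than 3.10; without that import, annotations that are evaluated at definition time (such as dataclass fields) need Python 3.10+. A minimal illustrative sketch of the pattern follows; it is not code from this repository, and `HeadConfigSketch` and `forward` are hypothetical names.

# Illustrative sketch only -- not code from this commit.
from __future__ import annotations  # PEP 563: annotations stay unevaluated strings

from dataclasses import dataclass


@dataclass
class HeadConfigSketch:  # hypothetical stand-in for a config dataclass
    hidden_size: int | None = None  # previously: Optional[int] = None
    dropout: float = 0.0
    act: str | None = None          # previously: Optional[str] = None


def forward(input_ids: list[int], labels: list[int] | None = None) -> int:
    # Optional arguments keep the same runtime behaviour; only the spelling changes.
    return 0 if labels is None else len(labels)


print(HeadConfigSketch(hidden_size=128))  # HeadConfigSketch(hidden_size=128, dropout=0.0, act=None)
print(forward([1, 2, 3], labels=[0, 1]))  # 2
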
74 changes: 37 additions & 37 deletions multimolecule/downstream/crispr_off_target.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
- from typing import Optional, Tuple
+ from typing import Tuple

import torch
from torch import Tensor
@@ -45,10 +45,10 @@ def __init__(self, config: RnaBertConfig):
def forward(
self,
input_ids: Tensor,
- target_input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[Tensor] = None,
- target_attention_mask: Optional[torch.Tensor] = None,
- labels: Optional[Tensor] = None,
+ target_input_ids: Tensor | None = None,
+ attention_mask: Tensor | None = None,
+ target_attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -142,10 +142,10 @@ def __init__(self, config: RnaFmConfig):
def forward(
self,
input_ids: Tensor,
- target_input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[Tensor] = None,
- target_attention_mask: Optional[torch.Tensor] = None,
- labels: Optional[Tensor] = None,
+ target_input_ids: Tensor | None = None,
+ attention_mask: Tensor | None = None,
+ target_attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -239,10 +239,10 @@ def __init__(self, config: RnaMsmConfig):
def forward(
self,
input_ids: Tensor,
- target_input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[Tensor] = None,
- target_attention_mask: Optional[torch.Tensor] = None,
- labels: Optional[Tensor] = None,
+ target_input_ids: Tensor | None = None,
+ attention_mask: Tensor | None = None,
+ target_attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -338,10 +338,10 @@ def __init__(self, config: SpliceBertConfig):
def forward(
self,
input_ids: Tensor,
- target_input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[Tensor] = None,
- target_attention_mask: Optional[torch.Tensor] = None,
- labels: Optional[Tensor] = None,
+ target_input_ids: Tensor | None = None,
+ attention_mask: Tensor | None = None,
+ target_attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -435,10 +435,10 @@ def __init__(self, config: UtrBertConfig):
def forward(
self,
input_ids: Tensor,
- target_input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[Tensor] = None,
- target_attention_mask: Optional[torch.Tensor] = None,
- labels: Optional[Tensor] = None,
+ target_input_ids: Tensor | None = None,
+ attention_mask: Tensor | None = None,
+ target_attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -532,10 +532,10 @@ def __init__(self, config: UtrLmConfig):
def forward(
self,
input_ids: Tensor,
- target_input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[Tensor] = None,
- target_attention_mask: Optional[torch.Tensor] = None,
- labels: Optional[Tensor] = None,
+ target_input_ids: Tensor | None = None,
+ attention_mask: Tensor | None = None,
+ target_attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -603,21 +603,21 @@ def forward(

@dataclass
class CrisprOffTargetOutput(ModelOutput):
- loss: Optional[torch.FloatTensor] = None
+ loss: torch.FloatTensor | None = None
logits: torch.FloatTensor = None
- sgrna_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
- sgrna_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
- target_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
- target_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ sgrna_hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+ sgrna_attentions: Tuple[torch.FloatTensor, ...] | None = None
+ target_hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+ target_attentions: Tuple[torch.FloatTensor, ...] | None = None


@dataclass
class RnaMsmForCrisprOffTargetOutput(ModelOutput):
- loss: Optional[torch.FloatTensor] = None
+ loss: torch.FloatTensor | None = None
logits: torch.FloatTensor = None
- sgrna_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
- sgrna_col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
- sgrna_row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
- target_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
- target_col_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
- target_row_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ sgrna_hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+ sgrna_col_attentions: Tuple[torch.FloatTensor, ...] | None = None
+ sgrna_row_attentions: Tuple[torch.FloatTensor, ...] | None = None
+ target_hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+ target_col_attentions: Tuple[torch.FloatTensor, ...] | None = None
+ target_row_attentions: Tuple[torch.FloatTensor, ...] | None = None
19 changes: 9 additions & 10 deletions multimolecule/models/configuration_utils.py
@@ -1,7 +1,6 @@
from __future__ import annotations

from dataclasses import asdict, dataclass, is_dataclass
- from typing import Optional

from transformers.configuration_utils import PretrainedConfig as _PretrainedConfig

@@ -70,15 +69,15 @@ class HeadConfig:
`"single_label_classification"` or `"multi_label_classification"`.
"""

- hidden_size: Optional[int] = None
+ hidden_size: int | None = None
dropout: float = 0.0
- transform: Optional[str] = None
- transform_act: Optional[str] = "gelu"
+ transform: str | None = None
+ transform_act: str | None = "gelu"
bias: bool = True
- act: Optional[str] = None
+ act: str | None = None
layer_norm_eps: float = 1e-12
num_labels: int = 1
- problem_type: Optional[str] = None
+ problem_type: str | None = None


@dataclass
@@ -108,10 +107,10 @@ class MaskedLMHeadConfig:
The epsilon used by the layer normalization layers.
"""

- hidden_size: Optional[int] = None
+ hidden_size: int | None = None
dropout: float = 0.0
- transform: Optional[str] = "nonlinear"
- transform_act: Optional[str] = "gelu"
+ transform: str | None = "nonlinear"
+ transform_act: str | None = "gelu"
bias: bool = True
- act: Optional[str] = None
+ act: str | None = None
act: str | None = None
layer_norm_eps: float = 1e-12
22 changes: 11 additions & 11 deletions multimolecule/models/modeling_utils.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
- from typing import Optional, Tuple
+ from typing import Tuple

import torch
from chanfig import ConfigRegistry
@@ -42,7 +42,7 @@ def __init__(self, config: PretrainedConfig):
self.activation = ACT2FN[self.config.act] if self.config.act is not None else None

def forward(
- self, attentions: Tensor, attention_mask: Optional[Tensor] = None, input_ids: Optional[Tensor] = None
+ self, attentions: Tensor, attention_mask: Tensor | None = None, input_ids: Tensor | None = None
) -> Tensor:
if attention_mask is None:
if input_ids is None:
@@ -96,7 +96,7 @@ def forward(
class MaskedLMHead(nn.Module):
"""Head for masked language modeling."""

- def __init__(self, config: PretrainedConfig, weight: Optional[Tensor] = None):
+ def __init__(self, config: PretrainedConfig, weight: Tensor | None = None):
super().__init__()
self.config = config.lm_head if hasattr(config, "lm_head") else config.head
if self.config.hidden_size is None:
@@ -185,8 +185,8 @@ def __init__(self, config: PretrainedConfig):
def forward( # pylint: disable=arguments-renamed
self,
outputs: ModelOutput | Tuple[Tensor, ...],
- attention_mask: Optional[Tensor] = None,
- input_ids: Optional[Tensor] = None,
+ attention_mask: Tensor | None = None,
+ input_ids: Tensor | None = None,
) -> Tensor:
if attention_mask is None:
if input_ids is None:
@@ -214,8 +214,8 @@ def __init__(self, config: PretrainedConfig):
def forward( # pylint: disable=arguments-renamed
self,
outputs: ModelOutput | Tuple[Tensor, ...],
- attention_mask: Optional[Tensor] = None,
- input_ids: Optional[Tensor] = None,
+ attention_mask: Tensor | None = None,
+ input_ids: Tensor | None = None,
) -> Tensor:
if attention_mask is None:
if input_ids is None:
@@ -267,8 +267,8 @@ def __init__(self, config: PretrainedConfig):
def forward( # pylint: disable=arguments-renamed
self,
outputs: ModelOutput | Tuple[Tensor, ...],
- attention_mask: Optional[Tensor] = None,
- input_ids: Optional[Tensor] = None,
+ attention_mask: Tensor | None = None,
+ input_ids: Tensor | None = None,
) -> Tensor:
if attention_mask is None:
if input_ids is None:
@@ -348,8 +348,8 @@ def unfold_kmer_embeddings(
embeddings: Tensor,
attention_mask: Tensor,
nmers: int,
- bos_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
+ bos_token_id: int | None = None,
+ eos_token_id: int | None = None,
) -> Tensor:
r"""
Unfold k-mer embeddings to token embeddings.
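
The `attention_mask: Tensor | None = None` parameters above keep their existing `if attention_mask is None` guards, which is what lets a type checker narrow the union before the tensor is used. A minimal sketch of that guard pattern under the new spelling; this is an illustrative assumption, not repository code, and `masked_mean` is a hypothetical helper.

from __future__ import annotations

import torch
from torch import Tensor


def masked_mean(hidden: Tensor, attention_mask: Tensor | None = None) -> Tensor:
    # Narrow `Tensor | None` to `Tensor` before any arithmetic.
    if attention_mask is None:
        attention_mask = torch.ones(hidden.shape[:-1], dtype=hidden.dtype)
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)


print(masked_mean(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 8])
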
5 changes: 2 additions & 3 deletions multimolecule/models/rnabert/convert_checkpoint.py
@@ -1,5 +1,4 @@
import os
- from typing import Optional

import chanfig
import torch
@@ -98,8 +97,8 @@ class ConvertConfig:
output_path: str = Config.model_type
push_to_hub: bool = False
delete_existing: bool = False
- repo_id: Optional[str] = None
- token: Optional[str] = None
+ repo_id: str | None = None
+ token: str | None = None

def post(self):
if self.repo_id is None:
34 changes: 17 additions & 17 deletions multimolecule/models/rnabert/modeling_rnabert.py
@@ -2,7 +2,7 @@

import math
from dataclasses import dataclass
- from typing import Optional, Tuple
+ from typing import Tuple

import torch
from danling import NestedTensor
@@ -81,7 +81,7 @@ def __init__(self, config: RnaBertConfig, add_pooling_layer: bool = True):
def forward(
self,
input_ids: Tensor | NestedTensor,
- attention_mask: Optional[Tensor] = None,
+ attention_mask: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -139,8 +139,8 @@ def __init__(self, config: RnaBertConfig):
def forward(
self,
input_ids: Tensor | NestedTensor,
- attention_mask: Optional[Tensor] = None,
- labels: Optional[torch.Tensor] = None,
+ attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -192,10 +192,10 @@ def __init__(self, config: RnaBertConfig):
def forward(
self,
input_ids: Tensor | NestedTensor,
- attention_mask: Optional[Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- labels_ss: Optional[torch.Tensor] = None,
- next_sentence_label: Optional[torch.Tensor] = None,
+ attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
+ labels_ss: Tensor | None = None,
+ next_sentence_label: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -257,8 +257,8 @@ def __init__(self, config: RnaBertConfig):
def forward(
self,
input_ids: Tensor | NestedTensor,
- attention_mask: Optional[Tensor] = None,
- labels: Optional[Tensor] = None,
+ attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -336,8 +336,8 @@ def __init__(self, config: RnaBertConfig):
def forward(
self,
input_ids: Tensor | NestedTensor,
- attention_mask: Optional[Tensor] = None,
- labels: Optional[Tensor] = None,
+ attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -415,8 +415,8 @@ def __init__(self, config: RnaBertConfig):
def forward(
self,
input_ids: Tensor | NestedTensor,
- attention_mask: Optional[Tensor] = None,
- labels: Optional[Tensor] = None,
+ attention_mask: Tensor | None = None,
+ labels: Tensor | None = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
@@ -698,11 +698,11 @@ def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tuple[Tensor, Te

@dataclass
class RnaBertForPretrainingOutput(ModelOutput):
- loss: Optional[torch.FloatTensor] = None
+ loss: torch.FloatTensor | None = None
logits: torch.FloatTensor = None # type: ignore[assignment]
logits_ss: torch.FloatTensor = None # type: ignore[assignment]
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ hidden_states: Tuple[torch.FloatTensor, ...] | None = None
+ attentions: Tuple[torch.FloatTensor, ...] | None = None


class RnaBertLayerNorm(nn.Module):
5 changes: 2 additions & 3 deletions multimolecule/models/rnafm/convert_checkpoint.py
@@ -1,5 +1,4 @@
import os
- from typing import Optional

import chanfig
import torch
@@ -135,8 +134,8 @@ class ConvertConfig:
output_path: str = Config.model_type
push_to_hub: bool = False
delete_existing: bool = False
- repo_id: Optional[str] = None
- token: Optional[str] = None
+ repo_id: str | None = None
+ token: str | None = None

def post(self):
if self.repo_id is None:

