Commit
a02aa49 (1 parent: 2bea376)
Showing 8 changed files with 676 additions and 379 deletions.
@@ -1,3 +1,17 @@
-from .rnabert import RnaBertConfig, RnaBertModel, RnaTokenizer
+from .rnabert import (
+    RnaBertConfig,
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+    RnaTokenizer,
+)

-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+    "RnaTokenizer",
+]
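
For orientation, a minimal sketch of how the newly exported classes would typically be used. The import path multimolecule.models, the checkpoint id "multimolecule/rnabert", and the example sequence are placeholders assumed for illustration, not taken from this commit:

    from multimolecule.models import RnaBertForMaskedLM, RnaTokenizer  # assumes this __init__ is multimolecule.models

    # Placeholder checkpoint id; substitute a real hub id or local path.
    tokenizer = RnaTokenizer.from_pretrained("multimolecule/rnabert")
    model = RnaBertForMaskedLM.from_pretrained("multimolecule/rnabert")

    inputs = tokenizer("ACGUAGUCUGAU", return_tensors="pt")
    outputs = model(**inputs)        # masked-LM logits over the RNA vocabulary
    print(outputs.logits.shape)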
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+from dataclasses import asdict, dataclass, is_dataclass
+from typing import Optional
+
+from transformers.configuration_utils import PretrainedConfig as _PretrainedConfig
+
+
+class PretrainedConfig(_PretrainedConfig):
+    head: HeadConfig
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+        """
+        output = super().to_dict()
+        for k, v in output.items():
+            if hasattr(v, "to_dict"):
+                output[k] = v.to_dict()
+            if is_dataclass(v):
+                output[k] = asdict(v)
+        return output
+
+
+@dataclass
+class HeadConfig:
+    r"""
+    This is the configuration class to store the configuration of a prediction head. It is used to instantiate a
+    prediction head according to the specified arguments, defining the head architecture.
+
+    Configuration objects inherit from [`HeadConfig`] and can be used to control the model outputs. Read the
+    documentation from [`HeadConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the hidden states.
+        transform (`str`, *optional*, defaults to None):
+            The transform operation applied to hidden states.
+        transform_act (`str`, *optional*, defaults to "tanh"):
+            The activation function of transform applied to hidden states.
+        bias (`bool`, *optional*, defaults to True):
+            Whether to apply bias to the final prediction layer.
+        act (`str`, *optional*, defaults to None):
+            The activation function of the final prediction output.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+    """
+
+    hidden_size: Optional[int] = None
+    dropout: float = 0.0
+    transform: Optional[str] = None
+    transform_act: Optional[str] = "tanh"
+    bias: bool = True
+    act: Optional[str] = None
+    layer_norm_eps: float = 1e-12
+
+
+@dataclass
+class MaskedLMHeadConfig:
+    r"""
+    This is the configuration class to store the configuration of a prediction head. It is used to instantiate a
+    prediction head according to the specified arguments, defining the head architecture.
+
+    Configuration objects inherit from [`HeadConfig`] and can be used to control the model outputs. Read the
+    documentation from [`HeadConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the hidden states.
+        transform (`str`, *optional*, defaults to "nonlinear"):
+            The transform operation applied to hidden states.
+        transform_act (`str`, *optional*, defaults to "gelu"):
+            The activation function of transform applied to hidden states.
+        bias (`bool`, *optional*, defaults to True):
+            Whether to apply bias to the final prediction layer.
+        act (`str`, *optional*, defaults to None):
+            The activation function of the final prediction output.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+    """
+
+    hidden_size: Optional[int] = None
+    dropout: float = 0.0
+    transform: Optional[str] = "nonlinear"
+    transform_act: Optional[str] = "gelu"
+    bias: bool = True
+    act: Optional[str] = None
+    layer_norm_eps: float = 1e-12
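
As a rough illustration of what the to_dict override buys: any dataclass-typed attribute, such as a HeadConfig, is flattened into a plain dict when the configuration is serialized. The DemoConfig subclass below is hypothetical, written as if it lived in the same module as the classes above, and exists only to show the behaviour; it is not part of this commit:

    # Hypothetical subclass for illustration only; not part of this commit.
    class DemoConfig(PretrainedConfig):
        model_type = "demo"

        def __init__(self, head=None, **kwargs):
            super().__init__(**kwargs)
            # Accept either a HeadConfig or a plain dict (e.g. when reloading from JSON).
            self.head = HeadConfig(**head) if isinstance(head, dict) else (head or HeadConfig())


    config = DemoConfig()
    serialized = config.to_dict()
    assert isinstance(serialized["head"], dict)   # asdict() flattened the dataclass
    print(serialized["head"]["transform_act"])    # "tanh"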
@@ -1,12 +1,36 @@
-from transformers import AutoConfig, AutoModel, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForMaskedLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoModelWithLMHead,
+    AutoTokenizer,
+)

 from multimolecule.tokenizers.rna import RnaTokenizer

 from .configuration_rnabert import RnaBertConfig
-from .modeling_rnabert import RnaBertModel
+from .modeling_rnabert import (
+    RnaBertForMaskedLM,
+    RnaBertForSequenceClassification,
+    RnaBertForTokenClassification,
+    RnaBertModel,
+)

-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = [
+    "RnaBertConfig",
+    "RnaBertModel",
+    "RnaTokenizer",
+    "RnaBertForMaskedLM",
+    "RnaBertForSequenceClassification",
+    "RnaBertForTokenClassification",
+]

 AutoConfig.register("rnabert", RnaBertConfig)
 AutoModel.register(RnaBertConfig, RnaBertModel)
+AutoModelForMaskedLM.register(RnaBertConfig, RnaBertForMaskedLM)
+AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification)
+AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification)
+AutoModelWithLMHead.register(RnaBertConfig, RnaBertForTokenClassification)
 AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
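
Once these registrations run, the generic Auto* factories can dispatch to the RnaBert classes by model type alone. A minimal sketch, assuming RnaBertConfig can be instantiated with its defaults; no checkpoint is downloaded and the models come back randomly initialised:

    from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM

    config = AutoConfig.for_model("rnabert")          # resolves to RnaBertConfig via the registration above
    encoder = AutoModel.from_config(config)           # builds an RnaBertModel
    mlm = AutoModelForMaskedLM.from_config(config)    # builds an RnaBertForMaskedLM
    print(type(encoder).__name__, type(mlm).__name__)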