Commit

reorganise code
ZhiyuanChen committed Apr 3, 2024
1 parent 1682649 commit 492b2e7
Showing 2 changed files with 89 additions and 89 deletions.
28 changes: 14 additions & 14 deletions multimolecule/models/rnabert/modeling_rnabert.py
@@ -433,20 +433,6 @@ def forward(self, hidden_states: Tensor):
        return pooled_output


class RnaBertLayerNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # this is the weight (gamma)
        self.bias = nn.Parameter(torch.zeros(hidden_size))  # this is the bias (beta)
        self.variance_epsilon = eps

    def forward(self, x: Tensor):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class RnaBertLMHead(nn.Module):
    def __init__(self, config: RnaBertConfig):
        super().__init__()
@@ -500,3 +486,17 @@ class RnaBertMaskedLMOutput(MaskedLMOutput):
    logits_ss: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class RnaBertLayerNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # this is the weight (gamma)
        self.bias = nn.Parameter(torch.zeros(hidden_size))  # this is the bias (beta)
        self.variance_epsilon = eps

    def forward(self, x: Tensor):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
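
For reference, the relocated RnaBertLayerNorm applies the standard layer-normalisation formula (mean and biased variance over the last dimension, followed by a learnable scale and shift), so at default initialisation it should agree numerically with torch.nn.LayerNorm using the same eps. A minimal standalone sketch, illustrative only and not part of the changed files:

import torch
from torch import nn

hidden_size, eps = 8, 1e-12
x = torch.randn(2, 5, hidden_size)

# Manual computation mirroring RnaBertLayerNorm.forward (weight=1, bias=0 at init)
u = x.mean(-1, keepdim=True)                 # per-token mean
s = (x - u).pow(2).mean(-1, keepdim=True)    # per-token biased variance
manual = (x - u) / torch.sqrt(s + eps)

# nn.LayerNorm initialises its weight to ones and bias to zeros, so outputs should match
builtin = nn.LayerNorm(hidden_size, eps=eps)(x)
print(torch.allclose(manual, builtin, atol=1e-6))  # expected: True
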
150 changes: 75 additions & 75 deletions multimolecule/models/rnamsm/modeling_rnamsm.py
@@ -425,36 +425,6 @@ def forward(
        return hidden_states, attention_probs


class RnaMsmLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, *args, **kwargs):
        num_embeddings += 2
        super().__init__(num_embeddings, *args, **kwargs)
        self.max_positions = num_embeddings

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
        """Input is expected to be of size [bsz x seqlen]."""
        if attention_mask is None:
            attention_mask = input_ids.ne(self.padding_idx).int()
        # This is a bug in the original implementation
        positions = (torch.cumsum(attention_mask, dim=1, dtype=attention_mask.dtype) * attention_mask).long() + 1
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )


class RowSelfAttention(nn.Module):
    """Compute self-attention over rows of a 2D input."""

@@ -951,6 +921,81 @@ def forward(self, hidden_states):
        return hidden_states


class RnaMsmEmbeddings(nn.Module):
    def __init__(self, config: RnaMsmConfig):
        super().__init__()
        self.max_position_embeddings = config.max_position_embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = RnaMsmLearnedPositionalEmbedding(
            self.max_position_embeddings, config.hidden_size, padding_idx=config.pad_token_id
        )
        if config.embed_positions_msa:
            self.msa_embeddings = nn.Parameter(
                0.01 * torch.randn(1, self.max_position_embeddings, 1, 1), requires_grad=True
            )
        else:
            self.register_parameter("msa_embeddings", None)  # type: ignore
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
        assert input_ids.ndim == 3
        batch_size, num_alignments, seqlen = input_ids.size()
        words_embeddings = self.word_embeddings(input_ids.long())
        # words_embeddings = self.word_embeddings(tokens)
        position_embeddings = self.position_embeddings(input_ids.view(batch_size * num_alignments, seqlen)).view(
            words_embeddings.size()
        )
        msa_embeddings = 0
        if self.msa_embeddings is not None:
            if input_ids.size(1) > self.max_position_embeddings:
                raise RuntimeError(
                    "Using model with MSA position embedding trained on maximum MSA "
                    f"depth of {self.max_position_embeddings}, but received {position_embeddings.size(1)} alignments."
                )
            msa_embeddings += self.msa_embeddings[:, :num_alignments]

        embeddings = words_embeddings + position_embeddings + msa_embeddings
        embeddings = self.layer_norm(embeddings)

        embeddings = self.dropout(embeddings)

        if attention_mask is not None:
            embeddings = embeddings * (1 - attention_mask.unsqueeze(-1).type_as(embeddings))

        return embeddings
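
The forward pass above expects 3-D input_ids of shape [batch_size, num_alignments, seqlen] and broadcasts the per-alignment MSA embedding over the batch, length and hidden dimensions. A toy shape walk-through with random tensors standing in for the real embedding tables (all sizes are assumed, illustrative only and not part of the changed files):

import torch

batch_size, num_alignments, seqlen, hidden_size = 2, 3, 7, 16
max_position_embeddings = 32  # assumed config value for illustration

words_embeddings = torch.randn(batch_size, num_alignments, seqlen, hidden_size)
position_embeddings = torch.randn(batch_size, num_alignments, seqlen, hidden_size)
msa_embeddings = 0.01 * torch.randn(1, max_position_embeddings, 1, 1)[:, :num_alignments]

# msa_embeddings has shape (1, num_alignments, 1, 1) and broadcasts against the other two terms
embeddings = words_embeddings + position_embeddings + msa_embeddings
print(embeddings.shape)  # torch.Size([2, 3, 7, 16])
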


class RnaMsmLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, *args, **kwargs):
        num_embeddings += 2
        super().__init__(num_embeddings, *args, **kwargs)
        self.max_positions = num_embeddings

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
        """Input is expected to be of size [bsz x seqlen]."""
        if attention_mask is None:
            attention_mask = input_ids.ne(self.padding_idx).int()
        # This is a bug in the original implementation
        positions = (torch.cumsum(attention_mask, dim=1, dtype=attention_mask.dtype) * attention_mask).long() + 1
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
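
The position indices above are a masked cumulative sum offset by 1, so real tokens are numbered from 2 upward while every padded slot maps to index 1. A toy run of that computation (padding_idx assumed to be 0, illustrative only and not part of the changed files):

import torch

padding_idx = 0  # assumed for this example
input_ids = torch.tensor([[5, 6, 7, 0, 0]])       # two trailing pad tokens
attention_mask = input_ids.ne(padding_idx).int()  # tensor([[1, 1, 1, 0, 0]])

# Masked cumulative sum, offset by 1, exactly as in the forward method above
positions = (torch.cumsum(attention_mask, dim=1, dtype=attention_mask.dtype) * attention_mask).long() + 1
print(positions)  # tensor([[2, 3, 4, 1, 1]])
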


class RnaMsmPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -1024,51 +1069,6 @@ def forward(self, input_ids, row_attentions):
        return self.activation(self.regression(row_attentions).squeeze(3))


class RnaMsmEmbeddings(nn.Module):
    def __init__(self, config: RnaMsmConfig):
        super().__init__()
        self.max_position_embeddings = config.max_position_embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = RnaMsmLearnedPositionalEmbedding(
            self.max_position_embeddings, config.hidden_size, padding_idx=config.pad_token_id
        )
        if config.embed_positions_msa:
            self.msa_embeddings = nn.Parameter(
                0.01 * torch.randn(1, self.max_position_embeddings, 1, 1), requires_grad=True
            )
        else:
            self.register_parameter("msa_embeddings", None)  # type: ignore
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
        assert input_ids.ndim == 3
        batch_size, num_alignments, seqlen = input_ids.size()
        words_embeddings = self.word_embeddings(input_ids.long())
        # words_embeddings = self.word_embeddings(tokens)
        position_embeddings = self.position_embeddings(input_ids.view(batch_size * num_alignments, seqlen)).view(
            words_embeddings.size()
        )
        msa_embeddings = 0
        if self.msa_embeddings is not None:
            if input_ids.size(1) > self.max_position_embeddings:
                raise RuntimeError(
                    "Using model with MSA position embedding trained on maximum MSA "
                    f"depth of {self.max_position_embeddings}, but received {position_embeddings.size(1)} alignments."
                )
            msa_embeddings += self.msa_embeddings[:, :num_alignments]

        embeddings = words_embeddings + position_embeddings + msa_embeddings
        embeddings = self.layer_norm(embeddings)

        embeddings = self.dropout(embeddings)

        if attention_mask is not None:
            embeddings = embeddings * (1 - attention_mask.unsqueeze(-1).type_as(embeddings))

        return embeddings


@dataclass
class RnaMsmMaskedLMOutput(ModelOutput):
    """
