diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py
index 6ae17d59..b35ddf46 100644
--- a/multimolecule/models/rnabert/modeling_rnabert.py
+++ b/multimolecule/models/rnabert/modeling_rnabert.py
@@ -433,20 +433,6 @@ def forward(self, hidden_states: Tensor):
         return pooled_output
 
 
-class RnaBertLayerNorm(nn.Module):
-    def __init__(self, hidden_size: int, eps: float = 1e-12):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))  # i.e., the weight
-        self.bias = nn.Parameter(torch.zeros(hidden_size))  # i.e., the bias
-        self.variance_epsilon = eps
-
-    def forward(self, x: Tensor):
-        u = x.mean(-1, keepdim=True)
-        s = (x - u).pow(2).mean(-1, keepdim=True)
-        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
-        return self.weight * x + self.bias
-
-
 class RnaBertLMHead(nn.Module):
     def __init__(self, config: RnaBertConfig):
         super().__init__()
@@ -500,3 +486,17 @@ class RnaBertMaskedLMOutput(MaskedLMOutput):
     logits_ss: torch.FloatTensor = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+class RnaBertLayerNorm(nn.Module):
+    def __init__(self, hidden_size: int, eps: float = 1e-12):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))  # i.e., the weight
+        self.bias = nn.Parameter(torch.zeros(hidden_size))  # i.e., the bias
+        self.variance_epsilon = eps
+
+    def forward(self, x: Tensor):
+        u = x.mean(-1, keepdim=True)
+        s = (x - u).pow(2).mean(-1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+        return self.weight * x + self.bias
diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py
index 2e6cec30..7aab5edc 100644
--- a/multimolecule/models/rnamsm/modeling_rnamsm.py
+++ b/multimolecule/models/rnamsm/modeling_rnamsm.py
@@ -425,36 +425,6 @@ def forward(
         return hidden_states, attention_probs
 
 
-class RnaMsmLearnedPositionalEmbedding(nn.Embedding):
-    """
-    This module learns positional embeddings up to a fixed maximum size.
-    Padding ids are ignored by either offsetting based on padding_idx
-    or by setting padding_idx to None and ensuring that the appropriate
-    position ids are passed to the forward function.
-    """
-
-    def __init__(self, num_embeddings: int, *args, **kwargs):
-        num_embeddings += 2
-        super().__init__(num_embeddings, *args, **kwargs)
-        self.max_positions = num_embeddings
-
-    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
-        """Input is expected to be of size [bsz x seqlen]."""
-        if attention_mask is None:
-            attention_mask = input_ids.ne(self.padding_idx).int()
-        # This is a bug in the original implementation
-        positions = (torch.cumsum(attention_mask, dim=1, dtype=attention_mask.dtype) * attention_mask).long() + 1
-        return F.embedding(
-            positions,
-            self.weight,
-            self.padding_idx,
-            self.max_norm,
-            self.norm_type,
-            self.scale_grad_by_freq,
-            self.sparse,
-        )
-
-
 class RowSelfAttention(nn.Module):
     """Compute self-attention over rows of a 2D input."""
 
@@ -951,6 +921,81 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class RnaMsmEmbeddings(nn.Module):
+    def __init__(self, config: RnaMsmConfig):
+        super().__init__()
+        self.max_position_embeddings = config.max_position_embeddings
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = RnaMsmLearnedPositionalEmbedding(
+            self.max_position_embeddings, config.hidden_size, padding_idx=config.pad_token_id
+        )
+        if config.embed_positions_msa:
+            self.msa_embeddings = nn.Parameter(
+                0.01 * torch.randn(1, self.max_position_embeddings, 1, 1), requires_grad=True
+            )
+        else:
+            self.register_parameter("msa_embeddings", None)  # type: ignore
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
+        assert input_ids.ndim == 3
+        batch_size, num_alignments, seqlen = input_ids.size()
+        words_embeddings = self.word_embeddings(input_ids.long())
+        # words_embeddings = self.word_embeddings(tokens)
+        position_embeddings = self.position_embeddings(input_ids.view(batch_size * num_alignments, seqlen)).view(
+            words_embeddings.size()
+        )
+        msa_embeddings = 0
+        if self.msa_embeddings is not None:
+            if input_ids.size(1) > self.max_position_embeddings:
+                raise RuntimeError(
+                    "Using model with MSA position embedding trained on maximum MSA "
+                    f"depth of {self.max_position_embeddings}, but received {position_embeddings.size(1)} alignments."
+                )
+            msa_embeddings += self.msa_embeddings[:, :num_alignments]
+
+        embeddings = words_embeddings + position_embeddings + msa_embeddings
+        embeddings = self.layer_norm(embeddings)
+
+        embeddings = self.dropout(embeddings)
+
+        if attention_mask is not None:
+            embeddings = embeddings * (1 - attention_mask.unsqueeze(-1).type_as(embeddings))
+
+        return embeddings
+
+
+class RnaMsmLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    Padding ids are ignored by either offsetting based on padding_idx
+    or by setting padding_idx to None and ensuring that the appropriate
+    position ids are passed to the forward function.
+    """
+
+    def __init__(self, num_embeddings: int, *args, **kwargs):
+        num_embeddings += 2
+        super().__init__(num_embeddings, *args, **kwargs)
+        self.max_positions = num_embeddings
+
+    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
+        """Input is expected to be of size [bsz x seqlen]."""
+        if attention_mask is None:
+            attention_mask = input_ids.ne(self.padding_idx).int()
+        # This is a bug in the original implementation
+        positions = (torch.cumsum(attention_mask, dim=1, dtype=attention_mask.dtype) * attention_mask).long() + 1
+        return F.embedding(
+            positions,
+            self.weight,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+        )
+
+
 class RnaMsmPooler(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -1024,51 +1069,6 @@ def forward(self, input_ids, row_attentions):
         return self.activation(self.regression(row_attentions).squeeze(3))
 
 
-class RnaMsmEmbeddings(nn.Module):
-    def __init__(self, config: RnaMsmConfig):
-        super().__init__()
-        self.max_position_embeddings = config.max_position_embeddings
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-        self.position_embeddings = RnaMsmLearnedPositionalEmbedding(
-            self.max_position_embeddings, config.hidden_size, padding_idx=config.pad_token_id
-        )
-        if config.embed_positions_msa:
-            self.msa_embeddings = nn.Parameter(
-                0.01 * torch.randn(1, self.max_position_embeddings, 1, 1), requires_grad=True
-            )
-        else:
-            self.register_parameter("msa_embeddings", None)  # type: ignore
-        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
-    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None):
-        assert input_ids.ndim == 3
-        batch_size, num_alignments, seqlen = input_ids.size()
-        words_embeddings = self.word_embeddings(input_ids.long())
-        # words_embeddings = self.word_embeddings(tokens)
-        position_embeddings = self.position_embeddings(input_ids.view(batch_size * num_alignments, seqlen)).view(
-            words_embeddings.size()
-        )
-        msa_embeddings = 0
-        if self.msa_embeddings is not None:
-            if input_ids.size(1) > self.max_position_embeddings:
-                raise RuntimeError(
-                    "Using model with MSA position embedding trained on maximum MSA "
-                    f"depth of {self.max_position_embeddings}, but received {position_embeddings.size(1)} alignments."
-                )
-            msa_embeddings += self.msa_embeddings[:, :num_alignments]
-
-        embeddings = words_embeddings + position_embeddings + msa_embeddings
-        embeddings = self.layer_norm(embeddings)
-
-        embeddings = self.dropout(embeddings)
-
-        if attention_mask is not None:
-            embeddings = embeddings * (1 - attention_mask.unsqueeze(-1).type_as(embeddings))
-
-        return embeddings
-
-
 @dataclass
 class RnaMsmMaskedLMOutput(ModelOutput):
     """