
Commit

add TokenKMerHead
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Apr 29, 2024
1 parent ea6079c commit e28b926
Showing 2 changed files with 151 additions and 8 deletions.
147 changes: 145 additions & 2 deletions multimolecule/models/modeling_utils.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from typing import Optional, Tuple

import torch
@@ -10,6 +11,8 @@

from .configuration_utils import HeadConfig, PretrainedConfig

TokenHeads = ConfigRegistry(key="tokenizer_type")


class ContactPredictionHead(nn.Module):
"""
@@ -43,11 +46,11 @@ def forward(
if attention_mask is None:
if input_ids is None:
raise ValueError(
"Either attention_mask or input_ids must be provided for contact prediction head to work."
"Either attention_mask or input_ids must be provided for ContactPredictionHead to work."
)
if self.pad_token_id is None:
raise ValueError(
"pad_token_id must be provided when attention_mask is not passed to contact prediction head."
"pad_token_id must be provided when attention_mask is not passed to ContactPredictionHead."
)
attention_mask = input_ids.ne(self.pad_token_id)
# In the original model, attentions for padding tokens are completely zeroed out.
@@ -155,6 +158,7 @@ def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor: # pylin
return output


@TokenHeads.register("single", default=True)
class TokenClassificationHead(ClassificationHead):
"""Head for token-level tasks."""

@@ -163,6 +167,39 @@ def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor: # pylin
return output


@TokenHeads.register("kmer", default=True)
class TokenKMerHead(ClassificationHead):
"""Head for token-level tasks."""

def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.nmers = config.nmers
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
self.pad_token_id = config.pad_token_id
self.unfold_kmer_embeddings = partial(
unfold_kmer_embeddings, nmers=self.nmers, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id
)

def forward( # pylint: disable=arguments-renamed
self,
outputs: ModelOutput | Tuple[Tensor, ...],
attention_mask: Optional[Tensor] = None,
input_ids: Optional[Tensor] = None,
) -> Tensor:
if attention_mask is None:
if input_ids is None:
raise ValueError("Either attention_mask or input_ids must be provided for TokenKMerHead to work.")
if self.pad_token_id is None:
raise ValueError("pad_token_id must be provided when attention_mask is not passed to TokenKMerHead.")
attention_mask = input_ids.ne(self.pad_token_id)

output = outputs[0]
output = self.unfold_kmer_embeddings(output, attention_mask)
output = super().forward(output)
return output
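
# Editor's sketch (illustrative only, not part of this commit): a model using this
# head is expected to pass its encoder outputs together with attention_mask or
# input_ids, so the head can unfold k-mer embeddings back to per-token positions:
#
#     logits = self.token_head(outputs, attention_mask=attention_mask, input_ids=input_ids)
#
# The surrounding variable names are assumptions; only the forward signature above
# is taken from this diff.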


PredictionHeadTransform = ConfigRegistry(key="transform")


@@ -203,6 +240,112 @@ def __init__(self, config: HeadConfig): # pylint: disable=unused-argument
super().__init__()


def unfold_kmer_embeddings(
embeddings: Tensor,
attention_mask: Tensor,
nmers: int,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> Tensor:
r"""
Unfold k-mer embeddings to token embeddings.
For k-mer input, each embedding position represents k tokens.
This is fine for sequence-level tasks, but sacrifices resolution for token-level tasks.
This function unfolds k-mer embeddings to token embeddings by averaging the overlapping k-mer embeddings with a sliding window.
For example:
input tokens = `ACGU`
2-mer embeddings = `[<CLS>, AC, CG, GU, <SEP>]`.
token embeddings = `[<CLS>, AC, (AC + CG) / 2, (CG + GU) / 2, GU, <SEP>]`.
Args:
embeddings: The k-mer embeddings.
attention_mask: The attention mask.
nmers: The number of tokens in each k-mer.
bos_token_id: The id of the beginning of sequence token.
If not None, the first valid token will not be included in sliding averaging.
eos_token_id: The id of the end of sequence token.
If not None, the last valid token will not be included in sliding averaging.
Returns:
The token embeddings.
Examples:
>>> from danling import NestedTensor
>>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 2, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.0, 3.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 4.0]
>>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 4.0, 5.0, 5.5, 6.0, 7.0]
>>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.5, 6.5, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
>>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0]
>>> embeddings = NestedTensor(torch.arange(6).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 3.5, 4.0, 4.5, 5.0, 6.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0]
>>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6)
>>> output[0, :, 0].tolist()
[1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 4.5, 5.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0]
"""

batch_size, seq_length, hidden_size = embeddings.size()
last_valid_indices = attention_mask.sum(dim=-1)
output = torch.zeros(batch_size, seq_length + nmers - 1, hidden_size, device=embeddings.device)
for index, (tensor, seq_len) in enumerate(zip(embeddings, last_valid_indices)):
embedding = tensor[:seq_len]
if bos_token_id is not None:
embedding = embedding[1:]
if eos_token_id is not None:
embedding = embedding[:-1]
if len(embedding) > nmers:
begin = torch.stack([embedding[:i].mean(0) for i in range(1, nmers)])
medium = embedding.unfold(0, nmers, 1).mean(-1)
end = torch.stack([embedding[-i:].mean(0) for i in range(nmers - 1, 0, -1)])
embedding = torch.cat([begin, medium, end])
elif len(embedding) >= 2:
# covers len(embedding) == 2 as well, so short sequences with nmers > 2 still
# produce the required len(embedding) + nmers - 1 rows
begin = torch.stack([embedding[:i].mean(0) for i in range(1, len(embedding))])
end = torch.stack([embedding[-i:].mean(0) for i in range(nmers, 0, -1)])
embedding = torch.cat([begin, end])
elif len(embedding) == 1:
embedding = embedding.repeat(nmers, 1)
else:
raise ValueError("Sequence length is less than nmers.")
if bos_token_id is not None:
embedding = torch.cat([tensor[0][None, :], embedding])
if eos_token_id is not None:
embedding = torch.cat([embedding, tensor[seq_len - 1][None, :]])
output[index, : seq_len + nmers - 1] = embedding
return output
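
# Editor's note (illustrative, not part of this commit): for S valid k-mer columns
# (special tokens excluded), the unfolded result has S + nmers - 1 token columns.
# When S > nmers, the first nmers - 1 positions average a growing prefix of k-mers,
# interior positions average a full window of nmers k-mers via Tensor.unfold, and
# the last nmers - 1 positions average a shrinking suffix; <CLS>/<SEP> embeddings
# are copied through unchanged when bos_token_id/eos_token_id are given.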


def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
12 changes: 6 additions & 6 deletions multimolecule/models/utrbert/modeling_utrbert.py
@@ -19,7 +19,7 @@
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import logging

from ..modeling_utils import MaskedLMHead, SequenceClassificationHead, TokenClassificationHead
from ..modeling_utils import MaskedLMHead, SequenceClassificationHead, TokenKMerHead
from .configuration_utrbert import UtrBertConfig

logger = logging.get_logger(__name__)
@@ -219,7 +219,7 @@ class UtrBertForMaskedLM(UtrBertPreTrainedModel):
"""
Examples:
>>> from multimolecule import UtrBertConfig, UtrBertForMaskedLM, RnaTokenizer
>>> tokenizer = RnaTokenizer(nmers=3, strameline=True)
>>> tokenizer = RnaTokenizer(nmers=2, strameline=True)
>>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
>>> model = UtrBertForMaskedLM(config)
>>> input = tokenizer("ACGUN", return_tensors="pt")
@@ -320,7 +320,7 @@ class UtrBertForPretraining(UtrBertPreTrainedModel):
"""
Examples:
>>> from multimolecule import UtrBertConfig, UtrBertForPretraining, RnaTokenizer
>>> tokenizer = RnaTokenizer(nmers=4, strameline=True)
>>> tokenizer = RnaTokenizer(nmers=3, strameline=True)
>>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
>>> model = UtrBertForPretraining(config)
>>> input = tokenizer("ACGUN", return_tensors="pt")
@@ -394,7 +394,7 @@ class UtrBertForSequenceClassification(UtrBertPreTrainedModel):
"""
Examples:
>>> from multimolecule import UtrBertConfig, UtrBertForSequenceClassification, RnaTokenizer
>>> tokenizer = RnaTokenizer(nmers=5, strameline=True)
>>> tokenizer = RnaTokenizer(nmers=4, strameline=True)
>>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
>>> model = UtrBertForSequenceClassification(config)
>>> input = tokenizer("ACGUN", return_tensors="pt")
@@ -480,7 +480,7 @@ class UtrBertForTokenClassification(UtrBertPreTrainedModel):
"""
Examples:
>>> from multimolecule import UtrBertConfig, UtrBertForTokenClassification, RnaTokenizer
>>> tokenizer = RnaTokenizer(nmers=6, strameline=True)
>>> tokenizer = RnaTokenizer(nmers=2, strameline=True)
>>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
>>> model = UtrBertForTokenClassification(config)
>>> input = tokenizer("ACGUN", return_tensors="pt")
@@ -491,7 +491,7 @@ def __init__(self, config: UtrBertConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.utrbert = UtrBertModel(config, add_pooling_layer=False)
self.token_head = TokenClassificationHead(config)
self.token_head = TokenKMerHead(config)
self.head_config = self.token_head.config

# Initialize weights and apply final processing
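Editor's sketch (illustrative only, not part of this commit): the snippet below shows how the two changed files fit together, following the doctest setup in modeling_utrbert.py. The output field name and shapes are assumptions rather than verified behaviour.

from multimolecule import RnaTokenizer, UtrBertConfig, UtrBertForTokenClassification

tokenizer = RnaTokenizer(nmers=2, strameline=True)  # 2-mer tokenizer, as in the doctest above
config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
model = UtrBertForTokenClassification(config)       # now builds a TokenKMerHead internally

inputs = tokenizer("ACGUN", return_tensors="pt")    # 5 nucleotides -> 4 overlapping 2-mers plus specials
outputs = model(**inputs)                           # TokenKMerHead unfolds k-mer embeddings to tokens
# outputs.logits is expected to hold one prediction per nucleotide position rather
# than one per k-mer; the exact field name and shape are assumptions.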
