diff --git a/multimolecule/models/modeling_utils.py b/multimolecule/models/modeling_utils.py
index 35fbcd2b..99380631 100644
--- a/multimolecule/models/modeling_utils.py
+++ b/multimolecule/models/modeling_utils.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from functools import partial
 from typing import Optional, Tuple
 
 import torch
@@ -10,6 +11,8 @@
 from .configuration_utils import HeadConfig, PretrainedConfig
 
+TokenHeads = ConfigRegistry(key="tokenizer_type")
+
 
 class ContactPredictionHead(nn.Module):
     """
@@ -43,11 +46,11 @@ def forward(
         if attention_mask is None:
             if input_ids is None:
                 raise ValueError(
-                    "Either attention_mask or input_ids must be provided for contact prediction head to work."
+                    "Either attention_mask or input_ids must be provided for ContactPredictionHead to work."
                 )
             if self.pad_token_id is None:
                 raise ValueError(
-                    "pad_token_id must be provided when attention_mask is not passed to contact prediction head."
+                    "pad_token_id must be provided when attention_mask is not passed to ContactPredictionHead."
                 )
             attention_mask = input_ids.ne(self.pad_token_id)
         # In the original model, attentions for padding tokens are completely zeroed out.
@@ -155,6 +158,7 @@ def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor:  # pylin
         return output
 
 
+@TokenHeads.register("single", default=True)
 class TokenClassificationHead(ClassificationHead):
     """Head for token-level tasks."""
 
@@ -163,6 +167,39 @@ def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor:  # pylin
         return output
 
 
+@TokenHeads.register("kmer")
+class TokenKMerHead(ClassificationHead):
+    """Head for token-level tasks on k-mer tokenized sequences."""
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__(config)
+        self.nmers = config.nmers
+        self.bos_token_id = config.bos_token_id
+        self.eos_token_id = config.eos_token_id
+        self.pad_token_id = config.pad_token_id
+        self.unfold_kmer_embeddings = partial(
+            unfold_kmer_embeddings, nmers=self.nmers, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id
+        )
+
+    def forward(  # pylint: disable=arguments-renamed
+        self,
+        outputs: ModelOutput | Tuple[Tensor, ...],
+        attention_mask: Optional[Tensor] = None,
+        input_ids: Optional[Tensor] = None,
+    ) -> Tensor:
+        if attention_mask is None:
+            if input_ids is None:
+                raise ValueError("Either attention_mask or input_ids must be provided for TokenKMerHead to work.")
+            if self.pad_token_id is None:
+                raise ValueError("pad_token_id must be provided when attention_mask is not passed to TokenKMerHead.")
+            attention_mask = input_ids.ne(self.pad_token_id)
+
+        output = outputs[0]
+        output = self.unfold_kmer_embeddings(output, attention_mask)
+        output = super().forward(output)
+        return output
+
+
 PredictionHeadTransform = ConfigRegistry(key="transform")
 
 
@@ -203,6 +240,112 @@ def __init__(self, config: HeadConfig):  # pylint: disable=unused-argument
         super().__init__()
 
 
+def unfold_kmer_embeddings(
+    embeddings: Tensor,
+    attention_mask: Tensor,
+    nmers: int,
+    bos_token_id: Optional[int] = None,
+    eos_token_id: Optional[int] = None,
+) -> Tensor:
+    r"""
+    Unfold k-mer embeddings to token embeddings.
+
+    For k-mer input, each embedding column represents k tokens.
+    This is fine for sequence-level tasks, but sacrifices resolution for token-level tasks.
+    This function unfolds k-mer embeddings back to token embeddings by averaging overlapping k-mers with a sliding window.
+
+    For example:
+
+    input tokens = `ACGU`
+
+    2-mer embeddings = `[<bos>, AC, CG, GU, <eos>]`.
+
+    token embeddings = `[<bos>, AC, (AC + CG) / 2, (CG + GU) / 2, GU, <eos>]`.
+
+    Args:
+        embeddings: The k-mer embeddings.
+        attention_mask: The attention mask.
+        nmers: The number of tokens in each k-mer.
+        bos_token_id: The id of the beginning of sequence token.
+            If not None, the first valid token will not be included in sliding averaging.
+        eos_token_id: The id of the end of sequence token.
+            If not None, the last valid token will not be included in sliding averaging.
+
+    Returns:
+        The token embeddings.
+
+    Examples:
+        >>> from danling import NestedTensor
+        >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
+        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 2, True, True)
+        >>> output[0, :, 0].tolist()
+        [1.0, 2.0, 2.0, 3.0, 0.0]
+        >>> output[1, :, 0].tolist()
+        [1.0, 2.0, 2.5, 3.0, 4.0]
+        >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True)
+        >>> output[0, :, 0].tolist()
+        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
+        >>> output[1, :, 0].tolist()
+        [1.0, 2.0, 2.5, 3.0, 4.0, 5.0, 5.5, 6.0, 7.0]
+        >>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
+        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
+        >>> output[0, :, 0].tolist()
+        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
+        >>> output[1, :, 0].tolist()
+        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.5, 6.5, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
+        >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True)
+        >>> output[0, :, 0].tolist()
+        [1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
+        >>> output[1, :, 0].tolist()
+        [1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0]
+        >>> embeddings = NestedTensor(torch.arange(6).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True)
+        >>> output[0, :, 0].tolist()
+        [1.0, 2.0, 2.5, 3.0, 3.5, 3.5, 4.0, 4.5, 5.0, 6.0, 0.0]
+        >>> output[1, :, 0].tolist()
+        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0]
+        >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6)
+        >>> output[0, :, 0].tolist()
+        [1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 4.5, 5.0, 0.0, 0.0]
+        >>> output[1, :, 0].tolist()
+        [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0]
+    """
+
+    batch_size, seq_length, hidden_size = embeddings.size()
+    last_valid_indices = attention_mask.sum(dim=-1)
+    output = torch.zeros(batch_size, seq_length + nmers - 1, hidden_size, device=embeddings.device)
+    for index, (tensor, seq_len) in enumerate(zip(embeddings, last_valid_indices)):
+        embedding = tensor[:seq_len]
+        if bos_token_id is not None:
+            embedding = embedding[1:]
+        if eos_token_id is not None:
+            embedding = embedding[:-1]
+        if len(embedding) > nmers:
+            begin = torch.stack([embedding[:i].mean(0) for i in range(1, nmers)])
+            medium = embedding.unfold(0, nmers, 1).mean(-1)
+            end = torch.stack([embedding[-i:].mean(0) for i in range(nmers - 1, 0, -1)])
+            embedding = torch.cat([begin, medium, end])
+        elif len(embedding) > 2:
+            begin = torch.stack([embedding[:i].mean(0) for i in range(1, len(embedding))])
+            end = torch.stack([embedding[-i:].mean(0) for i in range(nmers, 0, -1)])
+            embedding = torch.cat([begin, end])
+        elif len(embedding) == 2:
+            embedding = torch.cat([embedding[:1], embedding.mean(0, keepdim=True).repeat(nmers - 1, 1), embedding[1:]])
+        elif len(embedding) == 1:
+            embedding = embedding.repeat(nmers, 1)
+        else:
+            raise ValueError("Sequence length is less than nmers.")
+        if bos_token_id is not None:
+            embedding = torch.cat([tensor[0][None, :], embedding])
+        if eos_token_id is not None:
+            embedding = torch.cat([embedding, tensor[seq_len - 1][None, :]])
+        output[index, : seq_len + nmers - 1] = embedding
+    return output
+
+
 def rotate_half(x):
     x1, x2 = x.chunk(2, dim=-1)
     return torch.cat((-x2, x1), dim=-1)
diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py
index f476c43e..d9044b41 100644
--- a/multimolecule/models/utrbert/modeling_utrbert.py
+++ b/multimolecule/models/utrbert/modeling_utrbert.py
@@ -19,7 +19,7 @@
 from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from transformers.utils import logging
 
-from ..modeling_utils import MaskedLMHead, SequenceClassificationHead, TokenClassificationHead
+from ..modeling_utils import MaskedLMHead, SequenceClassificationHead, TokenKMerHead
 from .configuration_utrbert import UtrBertConfig
 
 logger = logging.get_logger(__name__)
@@ -219,7 +219,7 @@ class UtrBertForMaskedLM(UtrBertPreTrainedModel):
     """
     Examples:
         >>> from multimolecule import UtrBertConfig, UtrBertForMaskedLM, RnaTokenizer
-        >>> tokenizer = RnaTokenizer(nmers=3, strameline=True)
+        >>> tokenizer = RnaTokenizer(nmers=2, strameline=True)
         >>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
         >>> model = UtrBertForMaskedLM(config)
         >>> input = tokenizer("ACGUN", return_tensors="pt")
@@ -320,7 +320,7 @@ class UtrBertForPretraining(UtrBertPreTrainedModel):
     """
     Examples:
         >>> from multimolecule import UtrBertConfig, UtrBertForPretraining, RnaTokenizer
-        >>> tokenizer = RnaTokenizer(nmers=4, strameline=True)
+        >>> tokenizer = RnaTokenizer(nmers=3, strameline=True)
         >>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
         >>> model = UtrBertForPretraining(config)
         >>> input = tokenizer("ACGUN", return_tensors="pt")
@@ -394,7 +394,7 @@ class UtrBertForSequenceClassification(UtrBertPreTrainedModel):
     """
    Examples:
         >>> from multimolecule import UtrBertConfig, UtrBertForSequenceClassification, RnaTokenizer
-        >>> tokenizer = RnaTokenizer(nmers=5, strameline=True)
+        >>> tokenizer = RnaTokenizer(nmers=4, strameline=True)
         >>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
         >>> model = UtrBertForSequenceClassification(config)
         >>> input = tokenizer("ACGUN", return_tensors="pt")
@@ -480,7 +480,7 @@ class UtrBertForTokenClassification(UtrBertPreTrainedModel):
     """
     Examples:
         >>> from multimolecule import UtrBertConfig, UtrBertForTokenClassification, RnaTokenizer
-        >>> tokenizer = RnaTokenizer(nmers=6, strameline=True)
+        >>> tokenizer = RnaTokenizer(nmers=2, strameline=True)
         >>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
         >>> model = UtrBertForTokenClassification(config)
         >>> input = tokenizer("ACGUN", return_tensors="pt")
@@ -491,7 +491,7 @@ def __init__(self, config: UtrBertConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.utrbert = UtrBertModel(config, add_pooling_layer=False)
-        self.token_head = TokenClassificationHead(config)
+        self.token_head = TokenKMerHead(config)
         self.head_config = self.token_head.config
 
         # Initialize weights and apply final processing
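
A minimal usage sketch of the new `unfold_kmer_embeddings` helper (illustrative only, not part of the patch). The tensor shapes and the token ids 1 and 2 below are assumptions, chosen to mirror the 3-mer tokenization of "ACGUN" used in the doctests; any non-None bos/eos id enables the special handling of the first and last positions.

    # illustrative sketch, assuming the patched multimolecule/models/modeling_utils.py is installed
    import torch

    from multimolecule.models.modeling_utils import unfold_kmer_embeddings

    # 3-mer embeddings for "ACGUN": <bos> ACG CGU GUN <eos> -> 5 columns, hidden size 8
    embeddings = torch.randn(1, 5, 8)
    attention_mask = torch.ones(1, 5, dtype=torch.long)

    # unfold back to nucleotide resolution: <bos> + 5 nucleotides + <eos> -> 7 columns
    token_embeddings = unfold_kmer_embeddings(embeddings, attention_mask, nmers=3, bos_token_id=1, eos_token_id=2)
    print(token_embeddings.shape)  # torch.Size([1, 7, 8])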