add head for nucleotide-level tasks
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Apr 29, 2024
1 parent e28b926 commit 71ea23c
Showing 3 changed files with 126 additions and 26 deletions.
147 changes: 123 additions & 24 deletions multimolecule/models/modeling_utils.py
@@ -12,6 +12,7 @@
from .configuration_utils import HeadConfig, PretrainedConfig

TokenHeads = ConfigRegistry(key="tokenizer_type")
+NucleotideHeads = ConfigRegistry(key="tokenizer_type")
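NucleotideHeads mirrors TokenHeads: head classes register under a key, and a model picks the class matching its tokenizer_type. A minimal dict-based sketch of this dispatch pattern (an illustration with hypothetical names, not chanfig's actual ConfigRegistry API):

HEADS = {}

def register(key):
    # hypothetical stand-in for NucleotideHeads.register
    def decorator(cls):
        HEADS[key] = cls
        return cls
    return decorator

@register("single")
class SingleHead: ...

@register("kmer")
class KMerHead: ...

head_cls = HEADS["kmer"]  # selected from config.tokenizer_type at model build time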


class ContactPredictionHead(nn.Module):
@@ -71,9 +72,9 @@ def forward(
eos_mask = input_ids.ne(self.eos_token_id).to(attentions)
input_ids = input_ids[..., 1:]
else:
-last_valid_indices = attention_mask.sum(dim=-1) - 1
+last_valid_indices = attention_mask.sum(dim=-1)
seq_length = attention_mask.size(-1)
-eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) != last_valid_indices
+eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices
eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
attentions *= eos_mask[:, None, None, :, :]
attentions = attentions[..., :-1, :-1]
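For intuition on the mask arithmetic above, the per-sequence <eos> position can be located from attention_mask alone; a minimal standalone check (it uses the index-of-last-token form, with an explicit unsqueeze so the comparison broadcasts to (batch, seq_length)):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # valid lengths 3 and 5
last_valid_indices = attention_mask.sum(dim=-1) - 1  # index of each <eos>: [2, 4]
seq_length = attention_mask.size(-1)
eos_mask = torch.arange(seq_length).unsqueeze(0) != last_valid_indices.unsqueeze(1)
# eos_mask is False exactly at the <eos> position of each sequence:
# [[ True,  True, False,  True,  True],
#  [ True,  True,  True,  True, False]]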
@@ -200,6 +201,109 @@ def forward( # pylint: disable=arguments-renamed
return output


@NucleotideHeads.register("single", default=True)
class NucleotideClassificationHead(ClassificationHead):
"""Head for nucleotide-level tasks."""

def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
self.pad_token_id = config.pad_token_id

def forward( # pylint: disable=arguments-renamed
self,
outputs: ModelOutput | Tuple[Tensor, ...],
attention_mask: Optional[Tensor] = None,
input_ids: Optional[Tensor] = None,
) -> Tensor:
if attention_mask is None:
if input_ids is None:
raise ValueError(
"Either attention_mask or input_ids must be provided for NucleotideClassificationHead to work."
)
if self.pad_token_id is None:
raise ValueError(
"pad_token_id must be provided when attention_mask is not passed to NucleotideClassificationHead."
)
attention_mask = input_ids.ne(self.pad_token_id)

output = outputs[0]
# remove cls token embeddings
if self.bos_token_id is not None:
output = output[..., 1:, :]
attention_mask = attention_mask[..., 1:]
if input_ids is not None:
input_ids = input_ids[..., 1:]
# remove eos token embeddings
if self.eos_token_id is not None:
if input_ids is not None:
eos_mask = input_ids.ne(self.eos_token_id).to(output)
input_ids = input_ids[..., 1:]
else:
last_valid_indices = attention_mask.sum(dim=-1)
seq_length = attention_mask.size(-1)
eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
output *= eos_mask[:, :, None]
output = output[..., :-1, :]
attention_mask = attention_mask[..., 1:]

output = super().forward(output)
return output
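The effect of the trimming on shapes can be checked in isolation; a minimal sketch with toy tensors (not the full head, which also applies the eos mask and the classifier):

import torch

batch_size, seq_length, hidden_size = 2, 6, 8
output = torch.randn(batch_size, seq_length, hidden_size)  # <cls>, n1, n2, n3, <eos>, <pad>

output = output[..., 1:, :]       # drop <cls>: (2, 5, 8)
output = output[..., :-1, :]      # drop the trailing position left by <eos>: (2, 4, 8)
assert output.shape == (2, 4, 8)  # nucleotide positions (plus any padding) remain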


@NucleotideHeads.register("kmer", default=True)
class NucleotideKMerHead(ClassificationHead):
"""Head for nucleotide-level tasks."""

def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.nmers = config.nmers
self.bos_token_id = None  # <cls> is handled during unfolding, so the shared removal logic is skipped.
self.eos_token_id = None  # <eos> is handled during unfolding, so the shared removal logic is skipped.
self.pad_token_id = config.pad_token_id
self.unfold_kmer_embeddings = partial(unfold_kmer_embeddings, nmers=self.nmers)

def forward( # pylint: disable=arguments-renamed
self,
outputs: ModelOutput | Tuple[Tensor, ...],
attention_mask: Optional[Tensor] = None,
input_ids: Optional[Tensor] = None,
) -> Tensor:
if attention_mask is None:
if input_ids is None:
raise ValueError("Either attention_mask or input_ids must be provided for NucleotideKMerHead to work.")
if self.pad_token_id is None:
raise ValueError(
"pad_token_id must be provided when attention_mask is not passed to NucleotideKMerHead."
)
attention_mask = input_ids.ne(self.pad_token_id)

output = outputs[0]
# remove cls token embeddings
if self.bos_token_id is not None:
output = output[..., 1:, :]
attention_mask = attention_mask[..., 1:]
if input_ids is not None:
input_ids = input_ids[..., 1:]
# remove eos token embeddings
if self.eos_token_id is not None:
if input_ids is not None:
eos_mask = input_ids.ne(self.eos_token_id).to(output)
input_ids = input_ids[..., 1:]
else:
last_valid_indices = attention_mask.sum(dim=-1)
seq_length = attention_mask.size(-1)
eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
output *= eos_mask[:, :, None]
output = output[..., :-1, :]
attention_mask = attention_mask[..., 1:]

output = self.unfold_kmer_embeddings(output, attention_mask)
output = super().forward(output)
return output
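The unfolding relies on a simple length identity: a sequence of L nucleotides tokenized into overlapping k-mers yields L - k + 1 tokens, and unfolding those token embeddings back to nucleotide resolution restores L positions. A quick standalone check (a sketch of the identity only, independent of the real unfold_kmer_embeddings):

def unfolded_length(num_tokens, k):
    # num_tokens overlapping k-mers, each shifted by one nucleotide,
    # together span num_tokens + k - 1 nucleotides
    return num_tokens + k - 1

L, k = 7, 3
num_tokens = L - k + 1                       # 5 tokens for a 7-nt sequence of 3-mers
assert unfolded_length(num_tokens, k) == L   # unfolding restores all 7 positions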


PredictionHeadTransform = ConfigRegistry(key="transform")


@@ -276,42 +380,36 @@ def unfold_kmer_embeddings(
Examples:
>>> from danling import NestedTensor
>>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 2, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.0, 3.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 4.0]
->>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+>>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(5).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True)
>>> output[0, :, 0].tolist()
-[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
+[1.0, 2.0, 2.0, 2.0, 3.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
-[1.0, 2.0, 2.5, 3.0, 4.0, 5.0, 5.5, 6.0, 7.0]
->>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
->>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
->>> output[0, :, 0].tolist()
-[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
->>> output[1, :, 0].tolist()
-[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.5, 6.5, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
+[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0]
>>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True)
>>> output[0, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
[1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0]
->>> embeddings = NestedTensor(torch.arange(6).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+>>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True)
>>> output[0, :, 0].tolist()
-[1.0, 2.0, 2.5, 3.0, 3.5, 3.5, 4.0, 4.5, 5.0, 6.0, 0.0]
+[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
>>> output[1, :, 0].tolist()
-[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0]
->>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+[1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
+>>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
>>> output[0, :, 0].tolist()
+[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.0]
>>> output[1, :, 0].tolist()
+[1.0, 2.0, 2.5, 2.5, 2.5, 2.5, 2.5, 3.0, 4.0]
+>>> embeddings = NestedTensor(torch.arange(1).repeat(2, 1).T, torch.arange(2).repeat(2, 1).T) + 1
+>>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6)
>>> output[0, :, 0].tolist()
-[1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 4.5, 5.0, 0.0, 0.0]
+[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
>>> output[1, :, 0].tolist()
-[1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0]
+[1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0]
"""

batch_size, seq_length, hidden_size = embeddings.size()
@@ -333,7 +431,8 @@ def unfold_kmer_embeddings(
end = torch.stack([embedding[-i:].mean(0) for i in range(nmers, 0, -1)])
embedding = torch.cat([begin, end])
elif len(embedding) == 2:
-embedding = torch.stack([embedding[0], embedding.mean(0), embedding[1]])
+medium = embedding.mean(0).repeat(nmers - 1, 1)
+embedding = torch.cat([embedding[0][None, :], medium, embedding[1][None, :]])
elif len(embedding) == 1:
embedding = embedding.repeat(nmers, 1)
else:
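The new two-token branch can be checked numerically against the doctest above; a minimal sketch of just that branch:

import torch

nmers = 6
embedding = torch.tensor([[1.0], [2.0]])          # two tokens, hidden_size 1
medium = embedding.mean(0).repeat(nmers - 1, 1)   # (5, 1), filled with 1.5
unfolded = torch.cat([embedding[0][None, :], medium, embedding[1][None, :]])
assert unfolded[:, 0].tolist() == [1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0]
assert len(unfolded) == len(embedding) + nmers - 1  # 2 + 6 - 1 == 7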
4 changes: 2 additions & 2 deletions multimolecule/models/utrbert/modeling_utrbert.py
@@ -481,7 +481,7 @@ class UtrBertForTokenClassification(UtrBertPreTrainedModel):
Examples:
>>> from multimolecule import UtrBertConfig, UtrBertForTokenClassification, RnaTokenizer
>>> tokenizer = RnaTokenizer(nmers=2, strameline=True)
->>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
+>>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size, nmers=2)
>>> model = UtrBertForTokenClassification(config)
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input)
@@ -525,7 +525,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
-logits = self.token_head(outputs)
+logits = self.token_head(outputs, attention_mask, input_ids)

loss = None
if labels is not None:
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dynamic = [
]
dependencies = [
"chanfig>=0.0.99",
"danling",
"transformers",
]
[project.urls]
