diff --git a/multimolecule/models/modeling_utils.py b/multimolecule/models/modeling_utils.py
index 99380631..5ba06bb9 100644
--- a/multimolecule/models/modeling_utils.py
+++ b/multimolecule/models/modeling_utils.py
@@ -12,6 +12,7 @@
 from .configuration_utils import HeadConfig, PretrainedConfig
 
 TokenHeads = ConfigRegistry(key="tokenizer_type")
+NucleotideHeads = ConfigRegistry(key="tokenizer_type")
 
 
 class ContactPredictionHead(nn.Module):
@@ -71,9 +72,9 @@ def forward(
                 eos_mask = input_ids.ne(self.eos_token_id).to(attentions)
                 input_ids = input_ids[..., 1:]
             else:
                 last_valid_indices = attention_mask.sum(dim=-1) - 1
                 seq_length = attention_mask.size(-1)
-                eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) != last_valid_indices
+                eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) != last_valid_indices.unsqueeze(1)
             eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
             attentions *= eos_mask[:, None, None, :, :]
             attentions = attentions[..., :-1, :-1]
@@ -200,6 +201,109 @@ def forward(  # pylint: disable=arguments-renamed
         return output
 
 
+@NucleotideHeads.register("single", default=True)
+class NucleotideClassificationHead(ClassificationHead):
+    """Head for nucleotide-level tasks."""
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__(config)
+        self.bos_token_id = config.bos_token_id
+        self.eos_token_id = config.eos_token_id
+        self.pad_token_id = config.pad_token_id
+
+    def forward(  # pylint: disable=arguments-renamed
+        self,
+        outputs: ModelOutput | Tuple[Tensor, ...],
+        attention_mask: Optional[Tensor] = None,
+        input_ids: Optional[Tensor] = None,
+    ) -> Tensor:
+        if attention_mask is None:
+            if input_ids is None:
+                raise ValueError(
+                    "Either attention_mask or input_ids must be provided for NucleotideClassificationHead to work."
+                )
+            if self.pad_token_id is None:
+                raise ValueError(
+                    "pad_token_id must be provided when attention_mask is not passed to NucleotideClassificationHead."
+                )
+            attention_mask = input_ids.ne(self.pad_token_id)
+
+        output = outputs[0]
+        # remove cls token embeddings
+        if self.bos_token_id is not None:
+            output = output[..., 1:, :]
+            attention_mask = attention_mask[..., 1:]
+            if input_ids is not None:
+                input_ids = input_ids[..., 1:]
+        # remove eos token embeddings
+        if self.eos_token_id is not None:
+            if input_ids is not None:
+                eos_mask = input_ids.ne(self.eos_token_id).to(output)
+                input_ids = input_ids[..., 1:]
+            else:
+                last_valid_indices = attention_mask.sum(dim=-1) - 1
+                seq_length = attention_mask.size(-1)
+                eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
+            output *= eos_mask[:, :, None]
+            output = output[..., :-1, :]
+            attention_mask = attention_mask[..., 1:]
+
+        output = super().forward(output)
+        return output
+
+
+@NucleotideHeads.register("kmer")
+class NucleotideKMerHead(ClassificationHead):
+    """Head for nucleotide-level tasks on k-mer tokenized sequences."""
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__(config)
+        self.nmers = config.nmers
+        self.bos_token_id = None  # Nucleotide-level head removes token.
+        self.eos_token_id = None  # Nucleotide-level head removes token.
+        self.pad_token_id = config.pad_token_id
+        self.unfold_kmer_embeddings = partial(unfold_kmer_embeddings, nmers=self.nmers)
+
+    def forward(  # pylint: disable=arguments-renamed
+        self,
+        outputs: ModelOutput | Tuple[Tensor, ...],
+        attention_mask: Optional[Tensor] = None,
+        input_ids: Optional[Tensor] = None,
+    ) -> Tensor:
+        if attention_mask is None:
+            if input_ids is None:
+                raise ValueError("Either attention_mask or input_ids must be provided for NucleotideKMerHead to work.")
+            if self.pad_token_id is None:
+                raise ValueError(
+                    "pad_token_id must be provided when attention_mask is not passed to NucleotideKMerHead."
+                )
+            attention_mask = input_ids.ne(self.pad_token_id)
+
+        output = outputs[0]
+        # remove cls token embeddings
+        if self.bos_token_id is not None:
+            output = output[..., 1:, :]
+            attention_mask = attention_mask[..., 1:]
+            if input_ids is not None:
+                input_ids = input_ids[..., 1:]
+        # remove eos token embeddings
+        if self.eos_token_id is not None:
+            if input_ids is not None:
+                eos_mask = input_ids.ne(self.eos_token_id).to(output)
+                input_ids = input_ids[..., 1:]
+            else:
+                last_valid_indices = attention_mask.sum(dim=-1) - 1
+                seq_length = attention_mask.size(-1)
+                eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
+            output *= eos_mask[:, :, None]
+            output = output[..., :-1, :]
+            attention_mask = attention_mask[..., 1:]
+
+        output = self.unfold_kmer_embeddings(output, attention_mask)
+        output = super().forward(output)
+        return output
+
+
 PredictionHeadTransform = ConfigRegistry(key="transform")
 
 
@@ -276,42 +380,36 @@ def unfold_kmer_embeddings(
 
     Examples:
         >>> from danling import NestedTensor
-        >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
-        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 2, True, True)
-        >>> output[0, :, 0].tolist()
-        [1.0, 2.0, 2.0, 3.0, 0.0]
-        >>> output[1, :, 0].tolist()
-        [1.0, 2.0, 2.5, 3.0, 4.0]
-        >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+        >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(5).repeat(2, 1).T) + 1
         >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 3, True, True)
         >>> output[0, :, 0].tolist()
-        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
+        [1.0, 2.0, 2.0, 2.0, 3.0, 0.0, 0.0]
         >>> output[1, :, 0].tolist()
-        [1.0, 2.0, 2.5, 3.0, 4.0, 5.0, 5.5, 6.0, 7.0]
-        >>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
-        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
-        >>> output[0, :, 0].tolist()
-        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
-        >>> output[1, :, 0].tolist()
-        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.5, 6.5, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
+        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0]
         >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
         >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 4, True, True)
         >>> output[0, :, 0].tolist()
         [1.0, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 5.0, 0.0, 0.0]
        >>> output[1, :, 0].tolist()
         [1.0, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 7.0]
-        >>> embeddings = NestedTensor(torch.arange(6).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+        >>> embeddings = NestedTensor(torch.arange(7).repeat(2, 1).T, torch.arange(11).repeat(2, 1).T) + 1
         >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 5, True, True)
         >>> output[0, :, 0].tolist()
-        [1.0, 2.0, 2.5, 3.0, 3.5, 3.5, 4.0, 4.5, 5.0, 6.0, 0.0]
+        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0]
         >>> output[1, :, 0].tolist()
-        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 7.0]
-        >>> embeddings = NestedTensor(torch.arange(5).repeat(2, 1).T, torch.arange(7).repeat(2, 1).T) + 1
+        [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0, 7.0, 8.0, 8.5, 9.0, 9.5, 10.0, 11.0]
+        >>> embeddings = NestedTensor(torch.arange(3).repeat(2, 1).T, torch.arange(4).repeat(2, 1).T) + 1
+        >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6, True, True)
+        >>> output[0, :, 0].tolist()
+        [1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.0]
+        >>> output[1, :, 0].tolist()
+        [1.0, 2.0, 2.5, 2.5, 2.5, 2.5, 2.5, 3.0, 4.0]
+        >>> embeddings = NestedTensor(torch.arange(1).repeat(2, 1).T, torch.arange(2).repeat(2, 1).T) + 1
         >>> output = unfold_kmer_embeddings(embeddings.tensor.float(), embeddings.mask, 6)
         >>> output[0, :, 0].tolist()
-        [1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.5, 4.0, 4.5, 5.0, 0.0, 0.0]
+        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
         >>> output[1, :, 0].tolist()
-        [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0]
+        [1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0]
 
     """
     batch_size, seq_length, hidden_size = embeddings.size()
@@ -333,7 +431,8 @@ def unfold_kmer_embeddings(
             end = torch.stack([embedding[-i:].mean(0) for i in range(nmers, 0, -1)])
             embedding = torch.cat([begin, end])
         elif len(embedding) == 2:
-            embedding = torch.stack([embedding[0], embedding.mean(0), embedding[1]])
+            medium = embedding.mean(0).repeat(nmers - 1, 1)
+            embedding = torch.cat([embedding[0][None, :], medium, embedding[1][None, :]])
         elif len(embedding) == 1:
             embedding = embedding.repeat(nmers, 1)
         else:
diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py
index d9044b41..ed35c245 100644
--- a/multimolecule/models/utrbert/modeling_utrbert.py
+++ b/multimolecule/models/utrbert/modeling_utrbert.py
@@ -481,7 +481,7 @@ class UtrBertForTokenClassification(UtrBertPreTrainedModel):
     Examples:
         >>> from multimolecule import UtrBertConfig, UtrBertForTokenClassification, RnaTokenizer
         >>> tokenizer = RnaTokenizer(nmers=2, strameline=True)
-        >>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size)
+        >>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size, nmers=2)
         >>> model = UtrBertForTokenClassification(config)
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input)
@@ -525,7 +525,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        logits = self.token_head(outputs)
+        logits = self.token_head(outputs, attention_mask, input_ids)
 
         loss = None
         if labels is not None:
diff --git a/pyproject.toml b/pyproject.toml
index 23e56241..205f50e9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dynamic = [
 ]
 dependencies = [
     "chanfig>=0.0.99",
+    "danling",
     "transformers",
 ]
 [project.urls]
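Note: the two sketches below are not part of the patch; they are minimal, self-contained illustrations of the tensor manipulations the patch relies on. The tensors, shapes, and printed values are invented for demonstration only.

First, how the `<eos>` position can be recovered from `attention_mask` alone, mirroring the `else` branch of the heads above, assuming `attention_mask` marks every non-padding token and `<eos>` is the last valid token of each sequence:

```python
import torch

attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 0],  # 5 valid tokens, <eos> at index 4
    [1, 1, 1, 0, 0, 0],  # 3 valid tokens, <eos> at index 2
])
output = torch.ones(2, 6, 8)  # (batch_size, seq_length, hidden_size)

last_valid_indices = attention_mask.sum(dim=-1) - 1  # per-sequence index of <eos>
seq_length = attention_mask.size(-1)
eos_mask = torch.arange(seq_length) != last_valid_indices.unsqueeze(1)  # True everywhere except <eos>
output = output * eos_mask[:, :, None]  # zero out the <eos> embedding

print(eos_mask.long())
# tensor([[1, 1, 1, 1, 0, 1],
#         [1, 1, 0, 1, 1, 1]])
```

Second, the new `len(embedding) == 2` branch of `unfold_kmer_embeddings`, which expands two overlapping k-mer embeddings into `nmers + 1` nucleotide embeddings by repeating their mean for the shared middle positions; it reproduces the updated doctest value `[1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0]`:

```python
import torch

nmers = 6
embedding = torch.tensor([[1.0, 1.0], [2.0, 2.0]])  # two k-mer embeddings, hidden_size 2
medium = embedding.mean(0).repeat(nmers - 1, 1)  # mean embedding for the shared positions
embedding = torch.cat([embedding[0][None, :], medium, embedding[1][None, :]])

print(embedding[:, 0].tolist())
# [1.0, 1.5, 1.5, 1.5, 1.5, 1.5, 2.0]
```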