
Commit

add head for nucleotide-level tasks
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Apr 29, 2024
1 parent 13549c9 commit 4361241
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions multimolecule/models/modeling_utils.py
@@ -163,6 +163,54 @@ def forward(self, outputs: ModelOutput | Tuple[Tensor, ...]) -> Tensor:  # pylin
        return output


class NucleotideClassificationHead(ClassificationHead):
    """Head for nucleotide-level tasks."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.pad_token_id = config.pad_token_id

    def forward(  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Tuple[Tensor, ...],
        attention_mask: Optional[Tensor] = None,
        input_ids: Optional[Tensor] = None,
    ) -> Tensor:
        if attention_mask is None:
            if input_ids is None:
                raise ValueError(
                    "Either attention_mask or input_ids must be provided for nucleotide classification head to work."
                )
            if self.pad_token_id is None:
                raise ValueError(
                    "pad_token_id must be provided when attention_mask is not passed to nucleotide classification head."
                )
            attention_mask = input_ids.ne(self.pad_token_id)
        output = outputs[0]
        # remove cls token embeddings
        if self.bos_token_id is not None:
            output = output[..., 1:, :]
            attention_mask = attention_mask[..., 1:]
            if input_ids is not None:
                input_ids = input_ids[..., 1:]
        # remove eos token embeddings
        if self.eos_token_id is not None:
            if input_ids is not None:
                eos_mask = input_ids.ne(self.eos_token_id).to(output)
                input_ids = input_ids[..., 1:]
            else:
                # fall back to the attention mask to locate the eos position
                last_valid_indices = attention_mask.sum(dim=-1) - 1
                seq_length = attention_mask.size(-1)
                eos_mask = torch.arange(seq_length, device=output.device) != last_valid_indices.unsqueeze(1)
            output *= eos_mask[:, :, None]
            output = output[..., :-1, :]
            attention_mask = attention_mask[..., 1:]
        output = super().forward(output)
        return output


PredictionHeadTransform = ConfigRegistry(key="transform")


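For readers tracing the new head, below is a small self-contained sketch of the BOS/EOS handling that NucleotideClassificationHead.forward performs before delegating to ClassificationHead.forward. The token ids, tensor shapes, and values are arbitrary assumptions chosen for illustration; they are not part of this commit.

import torch

# assumed special-token ids for the toy example
bos_token_id, eos_token_id, pad_token_id = 1, 2, 0

input_ids = torch.tensor([[1, 5, 6, 7, 2, 0],
                          [1, 8, 9, 2, 0, 0]])   # (batch, seq)
output = torch.randn(2, 6, 4)                    # (batch, seq, hidden) hidden states

# derive the attention mask from padding, as the head does when none is passed
attention_mask = input_ids.ne(pad_token_id)

# drop the leading <bos>/<cls> position
output = output[..., 1:, :]
attention_mask = attention_mask[..., 1:]
input_ids = input_ids[..., 1:]

# zero the <eos> position, then drop the trailing slot
eos_mask = input_ids.ne(eos_token_id).to(output)
output = output * eos_mask[:, :, None]
output = output[..., :-1, :]
print(output.shape)  # torch.Size([2, 4, 4]) -- special-token positions removed or zeroed

# fallback used when input_ids is unavailable: locate <eos> from the attention mask
last_valid_indices = attention_mask.sum(dim=-1) - 1
eos_mask_from_lengths = torch.arange(attention_mask.size(-1)) != last_valid_indices.unsqueeze(1)

The zero-then-trim pattern reflects that each sequence's eos sits at a different index: it is zeroed in place per sequence, while only the final column, which is never a nucleotide position once eos is handled, is sliced off.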
