
Commit ac5d5a3: tokenizers
0x00b1 committed Jul 19, 2024
1 parent 7bed8f1 commit ac5d5a3
Showing 17 changed files with 1,862 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/beignet/tokenizers/__init__.py
@@ -1,5 +1,8 @@
-from ._protein_tokenizer import ProteinTokenizer
-
-__all__ = [
-    "ProteinTokenizer",
-]
+from ._hyena_tokenizer import HyenaTokenizer
+from ._hyena_tokenizer_transform import HyenaTokenizerTransform
+from ._pmlm_tokenizer import PmlmTokenizer, TrainablePmlmTokenizer
+from ._pmlm_tokenizer_transform import (
+    PmlmTokenizerTransform,
+    PT5TeacherForcingTransform,
+    PT5TokenizerTransform,
+)
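
A minimal import sketch of the public surface this __init__.py now re-exports, assuming the beignet package and the modules it references are installed:

# Import sketch; the names below are exactly those re-exported by the new
# src/beignet/tokenizers/__init__.py.
from beignet.tokenizers import (
    HyenaTokenizer,
    HyenaTokenizerTransform,
    PmlmTokenizer,
    PmlmTokenizerTransform,
    PT5TeacherForcingTransform,
    PT5TokenizerTransform,
    TrainablePmlmTokenizer,
)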
Empty file.
80 changes: 80 additions & 0 deletions src/beignet/tokenizers/_cached_bert_tokenizer.py
@@ -0,0 +1,80 @@
import torch
from cachetools import LRUCache, cached
from torch import LongTensor, Tensor
from transformers import BertTokenizerFast


class CachedBertTokenizerFast(BertTokenizerFast):
"""
    A wrapper around the BertTokenizerFast class from the transformers
    library. It adds an LRU-cached encoding method for faster runtimes on
    smaller datasets, attributes indicating which tokens may be corrupted
    and sampled by denoising models, and a convenience method for building
    a boolean mask of corruptible tokens from a batch of token ids.
"""

def __init__(
self,
vocab_file: str = None,
tokenizer_file: str = None,
do_lower_case: bool = False,
unk_token: str = "[UNK]",
sep_token: str = "[SEP]",
pad_token: str = "[PAD]",
cls_token: str = "[CLS]",
mask_token: str = "[MASK]",
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs,
)
self.padding_idx = self.vocab[pad_token]
self.masking_idx = self.vocab[mask_token]

# prevent utility token input corruption
utility_tokens = [
unk_token,
sep_token,
pad_token,
cls_token,
mask_token,
]
self.corruption_vocab_excluded = set(utility_tokens)
self.sampling_vocab_excluded = set(utility_tokens)

@property
def corruption_vocab_included(self):
return set(self.vocab.keys()) - self.corruption_vocab_excluded

@property
def sampling_vocab_included(self):
return set(self.vocab.keys()) - self.sampling_vocab_excluded

@cached(cache=LRUCache(maxsize=int(1e6)))
def cached_encode(self, text: str):
res = self.convert_tokens_to_ids(text.split(" "))
return res

def get_corruptible_mask(self, token_batch: LongTensor) -> Tensor:
"""
Args:
token_batch: a batch of token ids (LongTensor).
Returns:
            a boolean mask with the same shape as token_batch that is True
            where the token may be corrupted (i.e., is not a utility token).
"""
        # shape (num_excluded, 1, 1) so the ids broadcast against the batch
        excluded_idxs = (
            torch.tensor([self.vocab[tok] for tok in self.corruption_vocab_excluded])
            .view(-1, 1, 1)
            .to(token_batch)
        )
        # True only where a token id matches none of the excluded ids
        is_corruptible = token_batch.ne(excluded_idxs).prod(dim=0).bool()
        return is_corruptible
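
A minimal usage sketch of CachedBertTokenizerFast; the checkpoint name and input text are illustrative, and a domain-specific vocabulary would normally be used instead:

# Usage sketch (assumptions: "bert-base-uncased" is downloadable and its vocab
# contains the whitespace-separated tokens used below).
import torch

from beignet.tokenizers._cached_bert_tokenizer import CachedBertTokenizerFast

tokenizer = CachedBertTokenizerFast.from_pretrained("bert-base-uncased")

# Repeated calls with the same string are served from the LRU cache.
ids = tokenizer.cached_encode("hello world")

# Utility tokens ([PAD], [CLS], ...) are excluded from corruption.
batch = torch.tensor([ids + [tokenizer.padding_idx]], dtype=torch.long)
mask = tokenizer.get_corruptible_mask(batch)  # True, True, False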
148 changes: 148 additions & 0 deletions src/beignet/tokenizers/_hyena_tokenizer.py
@@ -0,0 +1,148 @@
"""Adapted from https://github.com/huggingface/transformers/tree/v4.23.1/src/transformers/models"""

import importlib.resources
import os
from typing import List, Optional, Union

from transformers.tokenization_utils import PreTrainedTokenizer, Trie
from transformers.tokenization_utils_base import AddedToken
from transformers.utils import logging

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

VOCAB_PATH = (
importlib.resources.files("lobster") / "assets" / "hyena_tokenizer" / "vocab.txt"
)


def load_vocab_file(vocab_file):
with open(vocab_file, "r") as f:
lines = f.read().splitlines()
return [ll.strip() for ll in lines]


class HyenaTokenizer(PreTrainedTokenizer):
"""
    Constructs a character-level tokenizer for HyenaDNA nucleotide sequences
    (A, C, G, T, U, N).
"""

model_input_names = ["input_ids", "attention_mask"]

def __init__(
self,
model_max_length: int,
vocab_file=VOCAB_PATH,
bos_token="<bos>",
sep_token="<sep>",
unk_token="<unk>",
cls_token="<cls>",
pad_token="<pad>",
mask_token="<mask>",
eos_token="<eos>",
**kwargs,
):
self._characters = ("A", "C", "G", "T", "U", "N")
self._model_max_length = model_max_length
self.all_tokens = load_vocab_file(vocab_file)
self._id_to_token = dict(enumerate(self.all_tokens))
super().__init__(**kwargs)

self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
self.unk_token = unk_token
self.bos_token = bos_token
self.cls_token = cls_token
self.pad_token = pad_token
self.sep_token = sep_token
self.mask_token = mask_token
self.eos_token = eos_token
self.unique_no_split_tokens = self.all_tokens
self._create_trie(self.unique_no_split_tokens)

def _convert_id_to_token(self, index: int) -> str:
return self._id_to_token.get(index, self.unk_token)

def _convert_token_to_id(self, token: str) -> int:
return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

def convert_tokens_to_string(self, tokens):
return "".join(tokens)

def _create_trie(self, unique_no_split_tokens):
trie = Trie()
for token in unique_no_split_tokens:
if (
hasattr(self, "do_lower_case")
and self.do_lower_case
and token not in self.all_special_tokens
):
trie.add(token.lower())
else:
trie.add(token)
self.tokens_trie = trie

    def _tokenize(self, text: str) -> List[str]:
        # character-level tokenization: every character is its own token
        return list(text)

def get_vocab_size(self, with_added_tokens=False):
return len(self._id_to_token)

def get_vocab(self):
return {token: i for i, token in enumerate(self.all_tokens)}

def token_to_id(self, token: str) -> int:
return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

def id_to_token(self, index: int) -> str:
return self._id_to_token.get(index, self.unk_token)

def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
sep = [self.sep_token_id]
# cls = [self.cls_token_id]
result = token_ids_0 + sep
if token_ids_1 is not None:
result += token_ids_1 + sep
return result

def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True,
)

result = ([0] * len(token_ids_0)) + [1]
if token_ids_1 is not None:
result += ([0] * len(token_ids_1)) + [1]
return result

def save_vocabulary(self, save_directory, filename_prefix):
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
)
with open(vocab_file, "w") as f:
f.write("\n".join(self.all_tokens))
return (vocab_file,)

@property
def vocab_size(self) -> int:
return self.get_vocab_size(with_added_tokens=False)

def _add_tokens(
self,
new_tokens: Union[List[str], List[AddedToken]],
special_tokens: bool = False,
) -> int:
return super()._add_tokens(new_tokens, special_tokens=True)
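
A minimal usage sketch of HyenaTokenizer; it assumes the default vocab.txt resolved from the lobster package's hyena_tokenizer assets is available, and the DNA string is illustrative:

# Usage sketch (assumption: the default VOCAB_PATH, which points into the
# "lobster" package's assets, is resolvable in this environment).
from beignet.tokenizers._hyena_tokenizer import HyenaTokenizer

tokenizer = HyenaTokenizer(model_max_length=512)

encoded = tokenizer("ACGTN")
print(encoded["input_ids"])  # one id per nucleotide, plus the trailing <sep> id
print(tokenizer.vocab_size)  # size of the character-level vocabulary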
128 changes: 128 additions & 0 deletions src/beignet/tokenizers/_hyena_tokenizer_transform.py
@@ -0,0 +1,128 @@
import importlib.resources
import json
from os import PathLike
from typing import Any, Dict, List, Optional, Union

import torch
from transformers.tokenization_utils_base import (
BatchEncoding,
PaddingStrategy,
TruncationStrategy,
)

from beignet.transforms import Transform

from ._hyena_tokenizer import HyenaTokenizer

filename = (
importlib.resources.files("lobster")
/ "assets"
/ "codon_tables"
/ "codon_table.json"
)
with open(filename, "r") as f:
    DNA_CODON_DICT = json.load(f)


class HyenaTokenizerTransform(Transform):
def __init__(
self,
pretrained_model_name_or_path: Union[str, PathLike] = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = False,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
tokenizer_dir: Optional[str] = "hyena_tokenizer",
mlm: bool = False,
aa_to_dna: bool = False,
):
super().__init__()

self._pretrained_model_name_or_path = pretrained_model_name_or_path
self._padding = padding
self._truncation = truncation
self._max_length = max_length
self._return_token_type_ids = return_token_type_ids
self._return_attention_mask = return_attention_mask
self._return_overflowing_tokens = return_overflowing_tokens
self._return_special_tokens_mask = return_special_tokens_mask
self._return_offsets_mapping = return_offsets_mapping
self._return_length = return_length
self._verbose = verbose
self._tokenizer_dir = tokenizer_dir
self._mlm = mlm
self._aa_to_dna = aa_to_dna

if self._pretrained_model_name_or_path is not None:
self._auto_tokenizer = HyenaTokenizer.from_pretrained(
self._pretrained_model_name_or_path,
do_lower_case=False,
add_special_tokens=True,
padding_side="left", # since HyenaDNA is causal, we pad on the left
use_fast=True,
)
elif self._tokenizer_dir is not None:
path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
self._auto_tokenizer = HyenaTokenizer.from_pretrained(
path,
do_lower_case=False,
add_special_tokens=True,
padding_side="left", # since HyenaDNA is causal, we pad on the left
use_fast=True,
)

def transform(
self,
text: Union[str, List[str], List[int]],
parameters: dict[str, Any],
) -> BatchEncoding:
if self._aa_to_dna:
text = [self.translate_aa_to_dna(seq) for seq in text]

tokenized = self._auto_tokenizer(
text,
padding=self._padding,
truncation=self._truncation,
max_length=self._max_length,
return_tensors="pt",
return_token_type_ids=self._return_token_type_ids,
return_attention_mask=self._return_attention_mask,
return_overflowing_tokens=self._return_overflowing_tokens,
return_special_tokens_mask=self._return_special_tokens_mask,
return_offsets_mapping=self._return_offsets_mapping,
return_length=self._return_length,
verbose=self._verbose,
)

labels = tokenized["input_ids"].clone()
if self._auto_tokenizer.pad_token_id is not None:
labels[labels == self._auto_tokenizer.pad_token_id] = -100 # ignore in loss
tokenized["labels"] = labels

return tokenized

def validate(self, flat_inputs: list[Any]) -> None:
pass

def _check_inputs(self, inputs: List[Any]) -> None:
pass

def _transform(self, input: Any, parameters: Dict[str, Any]) -> Any:
return self.transform(input, parameters)

def translate_aa_to_dna(self, aa_sequence: str) -> str:
# TODO: update DNA frequencies
dna_sequence = "".join(
[
DNA_CODON_DICT[aa][
torch.randint(0, len(DNA_CODON_DICT[aa]), (1,)).item()
]
for aa in aa_sequence
]
)
return dna_sequence
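
A minimal usage sketch of HyenaTokenizerTransform with amino-acid-to-DNA back-translation; it assumes the bundled hyena_tokenizer assets and codon_table.json resolve from the lobster package, and the sequences are toy inputs:

# Usage sketch (assumptions: the "hyena_tokenizer" assets, including a
# tokenizer config, and codon_table.json resolve via the "lobster" package).
from beignet.tokenizers._hyena_tokenizer_transform import HyenaTokenizerTransform

transform = HyenaTokenizerTransform(
    padding="max_length",
    truncation=True,
    max_length=64,
    aa_to_dna=True,  # back-translate amino acids to DNA via the codon table
)

batch = transform.transform(["MKT", "MVLS"], parameters={})
# batch["input_ids"]: left-padded, character-level DNA token ids
# batch["labels"]:    input_ids with pad positions set to -100 (ignored in loss)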
