Showing 17 changed files with 1,862 additions and 5 deletions.
Modified file (a package `__init__.py`):

@@ -1,5 +1,8 @@
-from ._protein_tokenizer import ProteinTokenizer
-
-__all__ = [
-    "ProteinTokenizer",
-]
+from ._hyena_tokenizer import HyenaTokenizer
+from ._hyena_tokenizer_transform import HyenaTokenizerTransform
+from ._pmlm_tokenizer import PmlmTokenizer, TrainablePmlmTokenizer
+from ._pmlm_tokenizer_transform import (
+    PmlmTokenizerTransform,
+    PT5TeacherForcingTransform,
+    PT5TokenizerTransform,
+)
Empty file.
New file: `CachedBertTokenizerFast` (@@ -0,0 +1,80 @@):
from typing import Optional

import torch
from cachetools import LRUCache, cached
from torch import LongTensor, Tensor
from transformers import BertTokenizerFast


class CachedBertTokenizerFast(BertTokenizerFast):
    """
    A wrapper around the BertTokenizerFast class from the transformers
    library. It adds a cached encoding method for faster runtimes on smaller
    datasets, provides attributes indicating which tokens can be corrupted
    and sampled by denoising models, and offers a convenience method for
    getting a mask of corruptible tokens from a sequence of token ids.
    """

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        tokenizer_file: Optional[str] = None,
        do_lower_case: bool = False,
        unk_token: str = "[UNK]",
        sep_token: str = "[SEP]",
        pad_token: str = "[PAD]",
        cls_token: str = "[CLS]",
        mask_token: str = "[MASK]",
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )
        self.padding_idx = self.vocab[pad_token]
        self.masking_idx = self.vocab[mask_token]

        # prevent utility (special) tokens from being corrupted or sampled
        utility_tokens = [
            unk_token,
            sep_token,
            pad_token,
            cls_token,
            mask_token,
        ]
        self.corruption_vocab_excluded = set(utility_tokens)
        self.sampling_vocab_excluded = set(utility_tokens)

    @property
    def corruption_vocab_included(self):
        return set(self.vocab.keys()) - self.corruption_vocab_excluded

    @property
    def sampling_vocab_included(self):
        return set(self.vocab.keys()) - self.sampling_vocab_excluded

    @cached(cache=LRUCache(maxsize=int(1e6)))
    def cached_encode(self, text: str):
        # memoize whitespace-tokenized encodings for repeated inputs
        return self.convert_tokens_to_ids(text.split(" "))

    def get_corruptible_mask(self, token_batch: LongTensor) -> Tensor:
        """
        Args:
            token_batch: a batch of token ids (LongTensor).

        Returns:
            A boolean mask of corruptible tokens (True where corruptible).
        """
        # compare each token against every excluded id, then AND across ids
        excluded_idxs = (
            torch.tensor([self.vocab[tok] for tok in self.corruption_vocab_excluded])
            .view(-1, 1, 1)
            .to(token_batch)
        )
        return token_batch.ne(excluded_idxs).prod(dim=0).bool()
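For illustration, a minimal usage sketch: the `vocab.txt` path and the token strings below are hypothetical stand-ins for a real BERT-style vocabulary, not part of this commit.

```python
import torch

# Hypothetical example: "vocab.txt" stands in for a real BERT-style vocab file.
tokenizer = CachedBertTokenizerFast(vocab_file="vocab.txt")

# Repeated encodings of the same string hit the LRU cache.
ids = tokenizer.cached_encode("[CLS] A B C [SEP]")

# Mask is False at special-token positions, True elsewhere.
batch = torch.tensor([ids])
mask = tokenizer.get_corruptible_mask(batch)
```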
New file: `HyenaTokenizer` (@@ -0,0 +1,148 @@):
"""Adapted from https://github.com/huggingface/transformers/tree/v4.23.1/src/transformers/models""" | ||
|
||
import importlib.resources | ||
import os | ||
from typing import List, Optional, Union | ||
|
||
from transformers.tokenization_utils import PreTrainedTokenizer, Trie | ||
from transformers.tokenization_utils_base import AddedToken | ||
from transformers.utils import logging | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} | ||
|
||
VOCAB_PATH = ( | ||
importlib.resources.files("lobster") / "assets" / "hyena_tokenizer" / "vocab.txt" | ||
) | ||
|
||
|
||
def load_vocab_file(vocab_file): | ||
with open(vocab_file, "r") as f: | ||
lines = f.read().splitlines() | ||
return [ll.strip() for ll in lines] | ||
|
||
|
||
class HyenaTokenizer(PreTrainedTokenizer): | ||
""" | ||
Constructs a Hyena tokenizer. | ||
""" | ||
|
||
model_input_names = ["input_ids", "attention_mask"] | ||
|
||
def __init__( | ||
self, | ||
model_max_length: int, | ||
vocab_file=VOCAB_PATH, | ||
bos_token="<bos>", | ||
sep_token="<sep>", | ||
unk_token="<unk>", | ||
cls_token="<cls>", | ||
pad_token="<pad>", | ||
mask_token="<mask>", | ||
eos_token="<eos>", | ||
**kwargs, | ||
): | ||
self._characters = ("A", "C", "G", "T", "U", "N") | ||
self._model_max_length = model_max_length | ||
self.all_tokens = load_vocab_file(vocab_file) | ||
self._id_to_token = dict(enumerate(self.all_tokens)) | ||
super().__init__(**kwargs) | ||
|
||
self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)} | ||
self.unk_token = unk_token | ||
self.bos_token = bos_token | ||
self.cls_token = cls_token | ||
self.pad_token = pad_token | ||
self.sep_token = sep_token | ||
self.mask_token = mask_token | ||
self.eos_token = eos_token | ||
self.unique_no_split_tokens = self.all_tokens | ||
self._create_trie(self.unique_no_split_tokens) | ||
|
||
def _convert_id_to_token(self, index: int) -> str: | ||
return self._id_to_token.get(index, self.unk_token) | ||
|
||
def _convert_token_to_id(self, token: str) -> int: | ||
return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) | ||
|
||
def convert_tokens_to_string(self, tokens): | ||
return "".join(tokens) | ||
|
||
def _create_trie(self, unique_no_split_tokens): | ||
trie = Trie() | ||
for token in unique_no_split_tokens: | ||
if ( | ||
hasattr(self, "do_lower_case") | ||
and self.do_lower_case | ||
and token not in self.all_special_tokens | ||
): | ||
trie.add(token.lower()) | ||
else: | ||
trie.add(token) | ||
self.tokens_trie = trie | ||
|
||
# def _tokenize(self, text, **kwargs): | ||
# return text.split() | ||
def _tokenize(self, text: str) -> List[str]: | ||
return list(text) | ||
|
||
def get_vocab_size(self, with_added_tokens=False): | ||
return len(self._id_to_token) | ||
|
||
def get_vocab(self): | ||
return {token: i for i, token in enumerate(self.all_tokens)} | ||
|
||
def token_to_id(self, token: str) -> int: | ||
return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) | ||
|
||
def id_to_token(self, index: int) -> str: | ||
return self._id_to_token.get(index, self.unk_token) | ||
|
||
def build_inputs_with_special_tokens( | ||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None | ||
) -> List[int]: | ||
sep = [self.sep_token_id] | ||
# cls = [self.cls_token_id] | ||
result = token_ids_0 + sep | ||
if token_ids_1 is not None: | ||
result += token_ids_1 + sep | ||
return result | ||
|
||
def get_special_tokens_mask( | ||
self, | ||
token_ids_0: List[int], | ||
token_ids_1: Optional[List[int]] = None, | ||
already_has_special_tokens: bool = False, | ||
) -> List[int]: | ||
if already_has_special_tokens: | ||
return super().get_special_tokens_mask( | ||
token_ids_0=token_ids_0, | ||
token_ids_1=token_ids_1, | ||
already_has_special_tokens=True, | ||
) | ||
|
||
result = ([0] * len(token_ids_0)) + [1] | ||
if token_ids_1 is not None: | ||
result += ([0] * len(token_ids_1)) + [1] | ||
return result | ||
|
||
def save_vocabulary(self, save_directory, filename_prefix): | ||
vocab_file = os.path.join( | ||
save_directory, | ||
(filename_prefix + "-" if filename_prefix else "") + "vocab.txt", | ||
) | ||
with open(vocab_file, "w") as f: | ||
f.write("\n".join(self.all_tokens)) | ||
return (vocab_file,) | ||
|
||
@property | ||
def vocab_size(self) -> int: | ||
return self.get_vocab_size(with_added_tokens=False) | ||
|
||
def _add_tokens( | ||
self, | ||
new_tokens: Union[List[str], List[AddedToken]], | ||
special_tokens: bool = False, | ||
) -> int: | ||
return super()._add_tokens(new_tokens, special_tokens=True) |
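Because `_tokenize` simply splits the input into characters, encoding produces one id per nucleotide, with `build_inputs_with_special_tokens` appending a trailing separator. A minimal sketch, assuming the packaged `lobster/assets/hyena_tokenizer` vocab is installed:

```python
# Minimal sketch; assumes the packaged hyena_tokenizer vocab file exists.
tokenizer = HyenaTokenizer(model_max_length=1024)

enc = tokenizer("ACGTN")
print(enc["input_ids"])  # one id per nucleotide, plus the trailing <sep>
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))
```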
New file: `HyenaTokenizerTransform` (@@ -0,0 +1,128 @@):
import importlib.resources
import json
from os import PathLike
from typing import Any, Dict, List, Optional, Union

import torch
from transformers.tokenization_utils_base import (
    BatchEncoding,
    PaddingStrategy,
    TruncationStrategy,
)

from beignet.transforms import Transform

from ._hyena_tokenizer import HyenaTokenizer

filename = (
    importlib.resources.files("lobster")
    / "assets"
    / "codon_tables"
    / "codon_table.json"
)
with filename.open("r") as f:
    DNA_CODON_DICT = json.load(f)


class HyenaTokenizerTransform(Transform):
    def __init__(
        self,
        pretrained_model_name_or_path: Optional[Union[str, PathLike]] = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = False,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        tokenizer_dir: Optional[str] = "hyena_tokenizer",
        mlm: bool = False,
        aa_to_dna: bool = False,
    ):
        super().__init__()

        self._pretrained_model_name_or_path = pretrained_model_name_or_path
        self._padding = padding
        self._truncation = truncation
        self._max_length = max_length
        self._return_token_type_ids = return_token_type_ids
        self._return_attention_mask = return_attention_mask
        self._return_overflowing_tokens = return_overflowing_tokens
        self._return_special_tokens_mask = return_special_tokens_mask
        self._return_offsets_mapping = return_offsets_mapping
        self._return_length = return_length
        self._verbose = verbose
        self._tokenizer_dir = tokenizer_dir
        self._mlm = mlm
        self._aa_to_dna = aa_to_dna

        if self._pretrained_model_name_or_path is not None:
            self._auto_tokenizer = HyenaTokenizer.from_pretrained(
                self._pretrained_model_name_or_path,
                do_lower_case=False,
                add_special_tokens=True,
                padding_side="left",  # since HyenaDNA is causal, we pad on the left
                use_fast=True,
            )
        elif self._tokenizer_dir is not None:
            path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
            self._auto_tokenizer = HyenaTokenizer.from_pretrained(
                path,
                do_lower_case=False,
                add_special_tokens=True,
                padding_side="left",  # since HyenaDNA is causal, we pad on the left
                use_fast=True,
            )

    def transform(
        self,
        text: Union[str, List[str], List[int]],
        parameters: Dict[str, Any],
    ) -> BatchEncoding:
        if self._aa_to_dna:
            text = [self.translate_aa_to_dna(seq) for seq in text]

        tokenized = self._auto_tokenizer(
            text,
            padding=self._padding,
            truncation=self._truncation,
            max_length=self._max_length,
            return_tensors="pt",
            return_token_type_ids=self._return_token_type_ids,
            return_attention_mask=self._return_attention_mask,
            return_overflowing_tokens=self._return_overflowing_tokens,
            return_special_tokens_mask=self._return_special_tokens_mask,
            return_offsets_mapping=self._return_offsets_mapping,
            return_length=self._return_length,
            verbose=self._verbose,
        )

        # causal LM labels: copy the input ids and ignore padding in the loss
        labels = tokenized["input_ids"].clone()
        if self._auto_tokenizer.pad_token_id is not None:
            labels[labels == self._auto_tokenizer.pad_token_id] = -100
        tokenized["labels"] = labels

        return tokenized

    def validate(self, flat_inputs: List[Any]) -> None:
        pass

    def _check_inputs(self, inputs: List[Any]) -> None:
        pass

    def _transform(self, input: Any, parameters: Dict[str, Any]) -> Any:
        return self.transform(input, parameters)

    def translate_aa_to_dna(self, aa_sequence: str) -> str:
        # TODO: update DNA frequencies
        # back-translate each amino acid to a uniformly sampled synonymous codon
        return "".join(
            DNA_CODON_DICT[aa][torch.randint(0, len(DNA_CODON_DICT[aa]), (1,)).item()]
            for aa in aa_sequence
        )
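The transform wires the tokenizer into a dataset pipeline; with `aa_to_dna=True`, amino-acid sequences are first back-translated to DNA via randomly sampled codons before character-level tokenization. A minimal sketch, assuming the packaged tokenizer and codon-table assets are installed (the sequences are illustrative):

```python
# Minimal sketch; assumes packaged hyena_tokenizer and codon_table assets exist.
transform = HyenaTokenizerTransform(
    padding="max_length",
    truncation=True,
    max_length=64,
    aa_to_dna=True,  # back-translate amino acids to codons first
)

batch = transform.transform(["MKTAYIAK", "MVLSPADK"], parameters={})
print(batch["input_ids"].shape)  # (2, 64), left-padded
print(batch["labels"][0][:8])    # pad positions are set to -100
```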