Faster tokenizer #137

Open · wants to merge 7 commits into main

Changes from all commits
85 changes: 80 additions & 5 deletions tokenizer/rwkv_tokenizer.py
@@ -213,12 +213,84 @@ def printTokens(self, tokens):
            print(f'{repr(s)}{i}', end=' ')
        print()

########################################################################################################
# Tokenizer #4 (fast) https://github.com/LoganDark
########################################################################################################

from typing import Dict, Generator, Iterable, Optional, Tuple, TypeAlias, Union
from ast import literal_eval

# Trie node: (token id that ends at this node, or None; children keyed by the next byte).
# TypeAlias needs Python 3.10+; the self-reference is quoted so the alias evaluates at runtime.
FastTokenizerEntry: TypeAlias = Tuple[Optional[int], Dict[int, 'FastTokenizerEntry']]
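# For illustration only (hypothetical tokens, not entries from the real vocab file):
# after add_token(1, b'a') and add_token(2, b'ab'), the trie would look like
#     {0x61: (1, {0x62: (2, {})})}
# i.e. byte 'a' completes token 1, and its child byte 'b' completes token 2.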

class FastTokenizer:
    __slots__ = ('tok2val', 'root')

    tok2val: Dict[int, bytes]
    root: Dict[int, FastTokenizerEntry]

    def __init__(self, file_name) -> None:
        self.tok2val = {}
        self.root = {}

        # Each vocab line has the form "<token id> <Python literal for the token> <byte length>";
        # the literal may be a bytes or str literal, so it is parsed with literal_eval.
        with open(file_name, 'rt', encoding = 'utf-8') as file:
            for line in file:
                token_str, value_repr = line.rstrip().split(' ', 1)
                value_repr, len_str = value_repr.rsplit(' ', 1)
                value_str: Union[bytes, str] = literal_eval(value_repr)
                value = value_str if isinstance(value_str, bytes) else value_str.encode('utf-8')
                assert len(value) == int(len_str)
                self.add_token(int(token_str), value)

    def add_token(self, token: int, value: bytes) -> None:
        # Walk/extend the trie one byte at a time; the final byte records the token id.
        self.tok2val[token] = value
        pos = self.root
        for byte in value[:-1]: pos = pos.setdefault(byte, (None, {}))[1]
        pos.setdefault(value[-1], (token, {}))

    def next_token(self, src: bytes) -> Optional[int]:
        # Greedy longest match: walk the trie along src and remember the last completed token.
        last_token: Optional[int] = None
        last = self.root
        for i in range(0, len(src)):
            if current := last.get(src[i]):
                if token := current[0]: last_token = token
                last = current[1]
            else: break
        return last_token

    def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
        # Repeatedly emit the longest token that matches at `start` (the same walk as next_token,
        # inlined so `start` can advance to the end of each match).
        start, stop = 0, len(src)
        while start < stop:
            last_token: Optional[int] = None
            last = self.root

            for i in range(start, stop):
                if current := last.get(src[i]):
                    if token := current[0]:
                        last_token = token
                        start = i + 1
                    last = current[1]
                else: break

            if last_token: yield last_token
            else: break

    def decode_bytes(self, tokens: Iterable[int]) -> bytes:
        return b''.join(map(self.tok2val.__getitem__, tokens))

    def encode(self, src: str) -> Generator[int, None, None]:
        return self.encode_bytes(src.encode('utf-8'))

    def decode(self, tokens: Iterable[int]) -> str:
        return self.decode_bytes(tokens).decode('utf-8')
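
# Illustrative sketch (hypothetical tokens, not part of the real vocab file): how the greedy
# longest match behaves on a toy trie. __new__ bypasses the file-loading __init__ for this demo.
_toy = FastTokenizer.__new__(FastTokenizer)
_toy.tok2val, _toy.root = {}, {}
_toy.add_token(1, b'a')
_toy.add_token(2, b'ab')
_toy.add_token(3, b'b')
assert list(_toy.encode('aba')) == [2, 1]  # 'ab' wins over 'a', then 'a' is matched
assert _toy.decode([2, 1]) == 'aba'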

########################################################################################################
# Demo
########################################################################################################

TOKENIZER = RWKV_TOKENIZER('rwkv_vocab_v20230424.txt')
TRIE_TEST = TRIE_TOKENIZER('rwkv_vocab_v20230424.txt')
FAST_TEST = FastTokenizer('rwkv_vocab_v20230424.txt')

src = '''起業家イーロン・マスク氏が創業した宇宙開発企業「スペースX(エックス)」の巨大新型ロケット「スターシップ」が20日朝、初めて打ち上げられたが、爆発した。
打ち上げは米テキサス州の東海岸で行われた。無人の試験で、負傷者はいなかった。
@@ -248,21 +320,22 @@ def printTokens(self, tokens):

def benchmark(XXX):
    min_t = 1e100
    for i in range(5):
        t_begin = time.time_ns() - 1  # -1 keeps the measured time nonzero even for very fast runs
        tokens = list(XXX.encode(src))
        min_t = min(time.time_ns() - t_begin, min_t)
    print('Encode', round(src_len / min_t * 1e3, 3), 'MB/s')  # src_len bytes / min_t ns * 1e3 = MB/s

    min_t = 1e100
    for i in range(10):
        t_begin = time.time_ns() - 1
        sss = XXX.decode(tokens)
        min_t = min(time.time_ns() - t_begin, min_t)
    print('Decode', round(src_len / min_t * 1e3, 3), 'MB/s')

benchmark(TOKENIZER)
benchmark(TRIE_TEST)
benchmark(FAST_TEST)

########################################################################################################
# Unit Test
@@ -1003,12 +1076,14 @@ def benchmark(XXX):
Lojban: mi kakne le nu citka le blaci .iku'i le se go'i na xrani mi
Nórdicg: Ljœr ye caudran créneþ ý jor cẃran.
''']

for q in QQQ:
    tokens = TOKENIZER.encode(q)
    if q != TOKENIZER.decode(tokens):
        print('ERROR', q)
    if str(tokens) != str(TRIE_TEST.encode(q)):
        print('ERROR', q)
    if str(tokens) != str(list(FAST_TEST.encode(q))):
        print('ERROR', q)

print('All OK\n')