
[temp] remove cache and add profiling script
rlouf committed Apr 24, 2024
1 parent 4d6ec1f commit d5068d6
Showing 3 changed files with 57 additions and 9 deletions.
2 changes: 0 additions & 2 deletions outlines/fsm/guide.py
@@ -5,7 +5,6 @@
 from lark import Lark
 
 from outlines import grammars
-from outlines.caching import cache
 from outlines.fsm.regex import (
     create_fsm_index_tokenizer,
     make_byte_level_fsm,
@@ -110,7 +109,6 @@ class RegexGuide(Guide):
     initial_state = 0
 
     def __init__(self, regex_string: str, tokenizer):
-        @cache()
         def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
             """Create the variables related to the mapping between states and tokens
             The parameters of the function are used for caching purpose
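
Note on what this hunk removes: outlines' @cache() decorator memoizes create_states_mapping to disk, so the token-index FSM for a given regex is built only once across runs; with it removed, the index is rebuilt on every RegexGuide construction, which is what the profiling script below is meant to measure. A minimal in-memory sketch of the same memoization idea, using functools.lru_cache as a stand-in (not outlines' disk cache) and a placeholder build step:

import time
from functools import lru_cache

# Placeholder for the expensive regex -> FSM token-index construction that
# @cache() used to memoize; the sleep stands in for the real work.
def _build_index(regex_string: str) -> dict:
    time.sleep(0.5)
    return {"regex": regex_string, "states_to_token_maps": {}}

build_index = lru_cache(maxsize=None)(_build_index)

t0 = time.perf_counter()
build_index(r'\{"name":"[^"]*"\}')  # first call pays the full build cost
build_index(r'\{"name":"[^"]*"\}')  # second call is served from the cache
print(f"elapsed: {time.perf_counter() - t0:.2f}s")  # ~0.5s, not ~1.0s
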
10 changes: 3 additions & 7 deletions outlines/fsm/regex.py
@@ -1,6 +1,5 @@
 import re
 from collections import namedtuple
-from functools import lru_cache
 from typing import (
     TYPE_CHECKING,
     Dict,
@@ -112,7 +111,7 @@ def fsm_info(self):
 nb_unichar_2_type = numba.types.UnicodeCharSeq(2)
 
 
-@numba.njit(cache=True)
+@numba.njit(cache=False)
 def create_fsm_info(
     py_initial,
     py_finals,
@@ -411,7 +410,7 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
     return new_fsm, old_to_new_states
 
 
-@numba.njit(nogil=True, cache=True)
+@numba.njit(nogil=True, cache=False)
 def _walk_fsm(
     fsm_transitions: Dict[Tuple[int, int], int],
     alphabet_symbol_mapping: Dict[str, int],
@@ -647,7 +646,7 @@ def get_sub_fsms_from_seq(
     )
 
 
-@numba.njit(cache=True, nogil=True)
+@numba.njit(cache=False, nogil=True)
 def state_scan_tokens(
     fsm_transitions: Dict[Tuple[int, int], int],
     alphabet_symbol_mapping: Dict[str, int],
@@ -723,7 +722,6 @@ def create_fsm_index_end_to_end(
 
 
 # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
-@lru_cache()
 def gpt2_bytes_to_unicode():
     """
     Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
@@ -750,14 +748,12 @@ def gpt2_bytes_to_unicode():
     return dict(zip(bs, cs))
 
 
-@lru_cache()
 def gpt2_unicode_to_bytes():
     return {v: k for k, v in gpt2_bytes_to_unicode().items()}
 
 
 # TODO: Cannot cache typed collections to disk, yet. See
 # https://github.com/numba/numba/issues/4698
-@lru_cache
 def reduced_vocabulary(
     tokenizer: "Tokenizer",
 ) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
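
Side note on the cache= flag flipped throughout this file: with cache=True, numba writes the compiled machine code to an on-disk cache so later processes can skip JIT compilation, while cache=False forces recompilation in every fresh interpreter, which makes the compilation cost visible to the profiler. A minimal sketch of the difference; the function and workload below are illustrative, not part of outlines:

import time
import numba

# With cache=False the first call in every new process pays the JIT cost;
# flipping this to cache=True would persist the compiled code to disk.
@numba.njit(cache=False, nogil=True)
def count_accepting_steps(n: int) -> int:
    total = 0
    for i in range(n):
        if i % 7 == 0:  # stand-in for "transition reaches an accepting state"
            total += 1
    return total

t0 = time.perf_counter()
count_accepting_steps(10)  # first call triggers compilation
print(f"first call (incl. JIT): {time.perf_counter() - t0:.3f}s")

t0 = time.perf_counter()
count_accepting_steps(10_000_000)
print(f"warm call: {time.perf_counter() - t0:.3f}s")
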
54 changes: 54 additions & 0 deletions profile.py
@@ -0,0 +1,54 @@
+import interegular
+import line_profiler
+from transformers import SPIECE_UNDERLINE
+
+from outlines import generate, models
+from outlines.fsm.regex import (
+    _walk_fsm,
+    create_fsm_index_end_to_end,
+    create_fsm_index_tokenizer,
+    make_byte_level_fsm,
+    make_deterministic_fsm,
+    reduced_vocabulary,
+    state_scan_tokens,
+)
+
+
+def run_model():
+    model = models.vllm(
+        "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ", quantization="gptq", dtype="half"
+    )
+    tokenizer = model.model.get_tokenizer()
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token):
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces to HF's Llama tokenizers
+        if (
+            type(token) is str
+            and token.startswith(SPIECE_UNDERLINE)
+            or token == "<0x20>"
+        ):
+            return " " + string
+
+        return string
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+
+    regex_string = '\\{[\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.){,10}"[\n ]*,[\n ]*"age"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*,[\n ]*"armor"[\n ]*:[\n ]*("leather"|"chainmail"|"plate")[\n ]*,[\n ]*"strength"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*\\}'
+    regex_pattern = interegular.parse_pattern(regex_string)
+    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
+    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+    states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
+        regex_fsm, tokenizer
+    )
+
+
+profile = line_profiler.LineProfiler()
+profile.add_function(create_fsm_index_tokenizer)
+profile.add_function(create_fsm_index_end_to_end)
+profile.add_function(reduced_vocabulary)
+profile(run_model)()
+profile.print_stats()
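
For anyone reproducing this, the script is run directly (no kernprof wrapper is needed because the LineProfiler instance is created and invoked explicitly); it assumes line_profiler, vllm, and a GPU able to load the GPTQ-quantized Mistral model are available. A self-contained sketch of the same profiling pattern on a toy function, including dumping the raw stats for later inspection:

import line_profiler

def tokenize_lengths(words):
    # Toy stand-in for the FSM-index construction profiled above.
    return [len(w.encode("utf-8")) for w in words]

profiler = line_profiler.LineProfiler()
profiler.add_function(tokenize_lengths)
profiler(tokenize_lengths)(["byte", "level", "fsm"])
profiler.print_stats()            # per-line hit counts and timings on stdout
profiler.dump_stats("toy.lprof")  # view later with: python -m line_profiler toy.lprof
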
