diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py
index cf5be8d50..fa824449a 100644
--- a/outlines/fsm/guide.py
+++ b/outlines/fsm/guide.py
@@ -5,7 +5,6 @@
 from lark import Lark
 
 from outlines import grammars
-from outlines.caching import cache
 from outlines.fsm.regex import (
     create_fsm_index_tokenizer,
     make_byte_level_fsm,
@@ -110,7 +109,6 @@ class RegexGuide(Guide):
     initial_state = 0
 
     def __init__(self, regex_string: str, tokenizer):
-        @cache()
         def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
             """Create the variables related to the mapping between states and tokens
             The parameters of the function are used for caching purpose
@@ -120,9 +118,7 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
                 regex_pattern.to_fsm().reduce(), keep_utf8=True
             )
             regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
+            states_to_token_maps, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)
 
             # We make sure that it is possible to generate strings in the language
             # of the regular expression with the tokens present in the model's
@@ -135,11 +131,10 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
                     "The vocabulary does not allow us to build a sequence that matches the input regex"
                 )
 
-            return states_to_token_maps, empty_token_ids, regex_fsm.finals
+            return states_to_token_maps, regex_fsm.finals
 
         (
             self.states_to_token_maps,
-            self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string)
         self.eos_token_id = tokenizer.eos_token_id
@@ -220,9 +215,7 @@ def create_states_mapping_from_interegular_fsm(
             """
             byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
             regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
+            states_to_token_maps, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)
 
             # We make sure that it is possible to generate strings in the language
             # of the regular expression with the tokens present in the model's
@@ -235,11 +228,10 @@ def create_states_mapping_from_interegular_fsm(
                     "The vocabulary does not allow us to build a sequence that matches the input regex"
                 )
 
-            return states_to_token_maps, empty_token_ids
+            return states_to_token_maps
 
         (
             from_interegular_instance.states_to_token_maps,
-            from_interegular_instance.empty_token_ids,
         ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
         from_interegular_instance.eos_token_id = tokenizer.eos_token_id
         return from_interegular_instance
diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py
index 0941bbb9f..5242aea92 100644
--- a/outlines/fsm/regex.py
+++ b/outlines/fsm/regex.py
@@ -1,6 +1,5 @@
 import re
 from collections import namedtuple
-from functools import lru_cache
 from typing import (
     TYPE_CHECKING,
     Dict,
@@ -112,7 +111,7 @@ def fsm_info(self):
 nb_unichar_2_type = numba.types.UnicodeCharSeq(2)
 
 
-@numba.njit(cache=True)
+@numba.njit(cache=False)
 def create_fsm_info(
     py_initial,
     py_finals,
@@ -411,7 +410,7 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
     return new_fsm, old_to_new_states
 
 
-@numba.njit(nogil=True, cache=True)
+@numba.njit(nogil=True, cache=False)
 def _walk_fsm(
     fsm_transitions: Dict[Tuple[int, int], int],
     alphabet_symbol_mapping: Dict[str, int],
@@ -647,7 +646,7 @@ def get_sub_fsms_from_seq(
     )
 
 
-@numba.njit(cache=True, nogil=True)
+@numba.njit(cache=False, nogil=True)
 def state_scan_tokens(
     fsm_transitions: Dict[Tuple[int, int], int],
     alphabet_symbol_mapping: Dict[str, int],
@@ -723,7 +722,6 @@ def create_fsm_index_end_to_end(
 
 
 # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
-@lru_cache()
 def gpt2_bytes_to_unicode():
     """
     Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
@@ -750,14 +748,12 @@ def gpt2_bytes_to_unicode():
     return dict(zip(bs, cs))
 
 
-@lru_cache()
 def gpt2_unicode_to_bytes():
     return {v: k for k, v in gpt2_bytes_to_unicode().items()}
 
 
 # TODO: Cannot cache typed collections to disk, yet. See
 # https://github.com/numba/numba/issues/4698
-@lru_cache
 def reduced_vocabulary(
     tokenizer: "Tokenizer",
 ) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
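Note on the changes above: the cache=True -> cache=False flips and the dropped @cache()/@lru_cache decorators remove every layer of caching, so each fresh process pays the full index-construction and compilation cost; that cost is what the profiling script at the end of this patch measures. A minimal standalone sketch (not part of the patch) of what the numba flag controls:

import numba

# Sketch only: with cache=True numba writes the compiled machine code to disk
# (in __pycache__ next to the source) and reloads it in later processes; with
# cache=False every new interpreter recompiles on the first call.
@numba.njit(cache=False)
def add_one(x):
    return x + 1

add_one(41)  # first call in a process: JIT-compile, then run
add_one(41)  # subsequent calls: run the already-compiled code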
diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py
index 378a35d91..038c80e60 100644
--- a/outlines/models/vllm.py
+++ b/outlines/models/vllm.py
@@ -1,7 +1,9 @@
+import concurrent.futures
 import dataclasses
 from typing import TYPE_CHECKING, List, Optional, Union
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.integrations.utils import adapt_tokenizer
 
 if TYPE_CHECKING:
     from vllm import LLM
@@ -18,7 +20,8 @@ class VLLM:
 
     """
 
-    def __init__(self, model: "LLM"):
+    def __init__(self, model: "LLM", tokenizer):
         self.model = model
+        self.tokenizer = tokenizer
         self.lora_request = None
 
@@ -73,6 +76,9 @@ def generate(
         # are specified by the user when calling the generator.
         if max_tokens is not None:
             sampling_params.max_tokens = max_tokens
+        else:
+            sampling_params.max_tokens = None
+
         if stop_at is not None:
             if isinstance(stop_at, str):
                 stop_at = [stop_at]
@@ -140,7 +146,40 @@ def load_lora(self, adapter_path: Optional[str]):
             self.lora_request = LoRARequest(adapter_path, 1, adapter_path)
 
 
-def vllm(model_name: str, **vllm_model_params):
+def load_model(model_name, **vllm_model_params):
+    """Load the model in GPU memory."""
+    from vllm import LLM
+
+    model = LLM(model_name, **vllm_model_params)
+    return model
+
+
+def load_and_convert_tokenizer(model_name):
+    """Convert vocabulary types and JIT-compile the numba functions."""
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = convert_tokenizer(tokenizer)
+    jit_compile_numba(tokenizer.vocabulary_numba)
+
+    return tokenizer
+
+
+def convert_tokenizer(tokenizer):
+    from outlines.fsm.regex import reduced_vocabulary
+
+    tokenizer = adapt_tokenizer(tokenizer)
+    vocabulary, _ = reduced_vocabulary(tokenizer)
+    tokenizer.vocabulary_numba = vocabulary
+
+    return tokenizer
+
+
+def jit_compile_numba(vocabulary):
+    pass
+
+
+def vllm(model_or_model_name: Union[str, "LLM"], **vllm_model_params):
     """Load a vLLM model.
 
     Arguments
@@ -154,6 +193,16 @@ def vllm(model_name: str, **vllm_model_params):
     """
     from vllm import LLM
 
-    model = LLM(model_name, **vllm_model_params)
+    if isinstance(model_or_model_name, LLM):
+        model = model_or_model_name
+        tokenizer = convert_tokenizer(model.get_tokenizer())
+    else:
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            model_name = model_or_model_name
+            future_model = executor.submit(load_model, model_name, **vllm_model_params)
+            future_tokenizer = executor.submit(load_and_convert_tokenizer, model_name)
+
+            tokenizer = future_tokenizer.result()
+            model = future_model.result()
 
-    return VLLM(model)
+    return VLLM(model, tokenizer)
diff --git a/profile.py b/profile.py
new file mode 100644
index 000000000..cd2824b11
--- /dev/null
+++ b/profile.py
@@ -0,0 +1,54 @@
+import interegular
+import line_profiler
+from transformers import SPIECE_UNDERLINE
+
+from outlines import generate, models
+from outlines.fsm.regex import (
+    _walk_fsm,
+    create_fsm_index_end_to_end,
+    create_fsm_index_tokenizer,
+    make_byte_level_fsm,
+    make_deterministic_fsm,
+    reduced_vocabulary,
+    state_scan_tokens,
+)
+
+
+def run_model():
+    model = models.vllm(
+        "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ", quantization="gptq", dtype="half"
+    )
+    tokenizer = model.model.get_tokenizer()
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token):
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces in HF's Llama tokenizers
+        if (
+            type(token) is str
+            and token.startswith(SPIECE_UNDERLINE)
+            or token == "<0x20>"
+        ):
+            return " " + string
+
+        return string
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+
+    regex_string = '\\{[\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.){,10}"[\n ]*,[\n ]*"age"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*,[\n ]*"armor"[\n ]*:[\n ]*("leather"|"chainmail"|"plate")[\n ]*,[\n ]*"strength"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*\\}'
+    regex_pattern = interegular.parse_pattern(regex_string)
+    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
+    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+    states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
+        regex_fsm, tokenizer
+    )
+
+
+profile = line_profiler.LineProfiler()
+profile.add_function(create_fsm_index_tokenizer)
+profile.add_function(create_fsm_index_end_to_end)
+profile.add_function(reduced_vocabulary)
+profile(run_model)()
+profile.print_stats()
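The rewritten vllm() entry point overlaps the two independent startup costs, GPU weight loading and tokenizer conversion. A minimal sketch of the pattern with hypothetical stand-ins (load_weights and build_tokenizer are placeholders, not the patch's functions; time.sleep stands in for work that releases the GIL):

import concurrent.futures
import time

def load_weights(name):
    time.sleep(2.0)  # placeholder for loading model weights onto the GPU
    return f"model:{name}"

def build_tokenizer(name):
    time.sleep(1.0)  # placeholder for vocabulary conversion and JIT warm-up
    return f"tokenizer:{name}"

start = time.perf_counter()
with concurrent.futures.ThreadPoolExecutor() as executor:
    future_model = executor.submit(load_weights, "some-model")
    future_tokenizer = executor.submit(build_tokenizer, "some-model")
    tokenizer = future_tokenizer.result()
    model = future_model.result()

# Wall time is roughly max(2, 1) seconds instead of 2 + 1: the two submissions
# run concurrently to the extent the underlying work releases the GIL.
print(f"{time.perf_counter() - start:.1f}s", model, tokenizer)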
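With the patch applied, models.vllm() accepts either a model name or an already-constructed vllm.LLM engine. A hypothetical usage sketch, reusing the model name from profile.py (neither call is part of the patch itself):

from vllm import LLM

from outlines import models

# From a name: weight loading and tokenizer conversion run in parallel threads.
model = models.vllm(
    "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ", quantization="gptq", dtype="half"
)

# From an existing engine: only the tokenizer conversion work remains.
engine = LLM("TheBloke/Mistral-7B-Instruct-v0.2-GPTQ", quantization="gptq", dtype="half")
model = models.vllm(engine)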