Convert vocabulary types and load model concurrently #832

Closed · wants to merge 2 commits
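In brief: the PR removes the @cache() and @lru_cache decorators around FSM-index and vocabulary construction, switches the numba-jitted FSM helpers to cache=False, and reworks outlines.models.vllm so that, given a model name, vLLM loads the weights on one thread while a second thread converts the tokenizer vocabulary to numba-friendly types (the JIT warm-up hook is still a stub). The VLLM wrapper now receives the converted tokenizer at construction time, and a new profile.py script measures where index construction spends its time.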
16 changes: 4 additions & 12 deletions outlines/fsm/guide.py
@@ -5,7 +5,6 @@
 from lark import Lark

 from outlines import grammars
-from outlines.caching import cache
 from outlines.fsm.regex import (
     create_fsm_index_tokenizer,
     make_byte_level_fsm,
@@ -110,7 +109,6 @@ class RegexGuide(Guide):
     initial_state = 0

     def __init__(self, regex_string: str, tokenizer):
-        @cache()
         def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
             """Create the variables related to the mapping between states and tokens
             The parameters of the function are used for caching purpose
@@ -120,9 +118,7 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
                 regex_pattern.to_fsm().reduce(), keep_utf8=True
             )
             regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
+            states_to_token_maps, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)

             # We make sure that it is possible to generate strings in the language
             # of the regular expression with the tokens present in the model's
@@ -135,11 +131,10 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
                     "The vocabulary does not allow us to build a sequence that matches the input regex"
                 )

-            return states_to_token_maps, empty_token_ids, regex_fsm.finals
+            return states_to_token_maps, regex_fsm.finals

         (
             self.states_to_token_maps,
-            self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string)
         self.eos_token_id = tokenizer.eos_token_id
@@ -220,9 +215,7 @@ def create_states_mapping_from_interegular_fsm(
             """
             byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
             regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
+            states_to_token_maps, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)

             # We make sure that it is possible to generate strings in the language
             # of the regular expression with the tokens present in the model's
@@ -235,11 +228,10 @@ def create_states_mapping_from_interegular_fsm(
                     "The vocabulary does not allow us to build a sequence that matches the input regex"
                 )

-            return states_to_token_maps, empty_token_ids
+            return states_to_token_maps

         (
             from_interegular_instance.states_to_token_maps,
-            from_interegular_instance.empty_token_ids,
         ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
         from_interegular_instance.eos_token_id = tokenizer.eos_token_id
         return from_interegular_instance
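For context, a minimal sketch of the code path this change affects, assuming outlines' public RegexGuide and the adapt_tokenizer helper that the vllm.py diff below imports; the model name and regex are illustrative. With @cache() gone, each construction recomputes the token index instead of reading a cached copy:

    from transformers import AutoTokenizer
    from outlines.fsm.guide import RegexGuide
    from outlines.integrations.utils import adapt_tokenizer

    # Hypothetical usage sketch: RegexGuide.__init__ calls create_states_mapping,
    # which after this change rebuilds the FSM token index on every call.
    tokenizer = adapt_tokenizer(AutoTokenizer.from_pretrained("gpt2"))
    guide = RegexGuide(r"(0|[1-9][0-9]*)", tokenizer)
    state = guide.initial_state  # 0, per the class attribute above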
10 changes: 3 additions & 7 deletions outlines/fsm/regex.py
@@ -1,6 +1,5 @@
 import re
 from collections import namedtuple
-from functools import lru_cache
 from typing import (
     TYPE_CHECKING,
     Dict,
@@ -112,7 +111,7 @@ def fsm_info(self):
 nb_unichar_2_type = numba.types.UnicodeCharSeq(2)


-@numba.njit(cache=True)
+@numba.njit(cache=False)
 def create_fsm_info(
     py_initial,
     py_finals,
@@ -411,7 +410,7 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
     return new_fsm, old_to_new_states


-@numba.njit(nogil=True, cache=True)
+@numba.njit(nogil=True, cache=False)
 def _walk_fsm(
     fsm_transitions: Dict[Tuple[int, int], int],
     alphabet_symbol_mapping: Dict[str, int],
@@ -647,7 +646,7 @@ def get_sub_fsms_from_seq(
     )


-@numba.njit(cache=True, nogil=True)
+@numba.njit(cache=False, nogil=True)
 def state_scan_tokens(
     fsm_transitions: Dict[Tuple[int, int], int],
     alphabet_symbol_mapping: Dict[str, int],
@@ -723,7 +722,6 @@ def create_fsm_index_end_to_end(


 # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
-@lru_cache()
 def gpt2_bytes_to_unicode():
     """
     Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
@@ -750,14 +748,12 @@ def gpt2_bytes_to_unicode():
     return dict(zip(bs, cs))


-@lru_cache()
 def gpt2_unicode_to_bytes():
     return {v: k for k, v in gpt2_bytes_to_unicode().items()}


 # TODO: Cannot cache typed collections to disk, yet. See
 # https://github.com/numba/numba/issues/4698
-@lru_cache
 def reduced_vocabulary(
     tokenizer: "Tokenizer",
 ) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
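Switching cache=True to cache=False (and dropping lru_cache) means these helpers are recompiled in every fresh Python process instead of being loaded from numba's on-disk cache; the PR compensates by overlapping that compilation with model loading. A toy illustration of the tradeoff, not taken from the PR:

    import numba

    # Hypothetical stand-in for the jitted FSM helpers above.
    @numba.njit(cache=False, nogil=True)
    def walk(n):
        total = 0
        for i in range(n):
            total += i
        return total

    walk(10)  # the first call in each process pays the JIT cost; later calls are fast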
56 changes: 52 additions & 4 deletions outlines/models/vllm.py
@@ -1,7 +1,9 @@
+import concurrent.futures
 import dataclasses
 from typing import TYPE_CHECKING, List, Optional, Union

 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.integrations.utils import adapt_tokenizer

 if TYPE_CHECKING:
     from vllm import LLM
@@ -18,7 +20,7 @@ class VLLM:

     """

-    def __init__(self, model: "LLM"):
+    def __init__(self, model: "LLM", tokenizer):
         self.model = model
         self.lora_request = None
@@ -73,6 +75,9 @@ def generate(
         # are specified by the user when calling the generator.
         if max_tokens is not None:
             sampling_params.max_tokens = max_tokens
+        else:
+            sampling_params.max_tokens = None
+
         if stop_at is not None:
             if isinstance(stop_at, str):
                 stop_at = [stop_at]
@@ -140,7 +145,40 @@ def load_lora(self, adapter_path: Optional[str]):
         self.lora_request = LoRARequest(adapter_path, 1, adapter_path)


-def vllm(model_name: str, **vllm_model_params):
+def load_model(model_name, **vllm_model_params):
+    """Load the model in GPU memory."""
+    from vllm import LLM
+
+    model = LLM(model_name, **vllm_model_params)
+    return model
+
+
+def load_and_convert_tokenizer(model_name):
+    """Convert vocabulary types and JIT-compile function."""
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = convert_tokenizer(tokenizer)
+    jit_compile_numba(tokenizer.vocabulary_numba)
+
+    return tokenizer
+
+
+def convert_tokenizer(tokenizer):
+    from outlines.fsm.regex import reduced_vocabulary
+
+    tokenizer = adapt_tokenizer(tokenizer)
+    vocabulary, _ = reduced_vocabulary(tokenizer)
+    tokenizer.vocabulary_numba = vocabulary
+
+    return tokenizer
+
+
+def jit_compile_numba(vocabulary):
+    pass
+
+
+def vllm(model_or_model_name: Union[str, "LLM"], **vllm_model_params):
     """Load a vLLM model.

     Arguments
@@ -154,6 +192,16 @@ def vllm(model_name: str, **vllm_model_params):
     """
     from vllm import LLM

-    model = LLM(model_name, **vllm_model_params)
+    if isinstance(model_or_model_name, LLM):
+        model = model_or_model_name
+        tokenizer = convert_tokenizer(model.get_tokenizer())
+    else:
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            model_name = model_or_model_name
+            future_model = executor.submit(load_model, model_name, **vllm_model_params)
+            future_tokenizer = executor.submit(load_and_convert_tokenizer, model_name)
+
+            tokenizer = future_tokenizer.result()
+            model = future_model.result()

-    return VLLM(model)
+    return VLLM(model, tokenizer)
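A usage sketch of the patched entry point (the model name and keyword arguments mirror the profiling script below): passing a string triggers the concurrent load, while passing an existing vllm.LLM instance skips the weight load and only converts its tokenizer.

    from outlines import models

    # The weights load on one thread while the tokenizer vocabulary is
    # converted (and the numba warm-up hook runs) on another.
    model = models.vllm(
        "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ", quantization="gptq", dtype="half"
    )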
54 changes: 54 additions & 0 deletions profile.py
@@ -0,0 +1,54 @@
import interegular
import line_profiler
from transformers import SPIECE_UNDERLINE

from outlines import generate, models
from outlines.fsm.regex import (
_walk_fsm,
create_fsm_index_end_to_end,
create_fsm_index_tokenizer,
make_byte_level_fsm,
make_deterministic_fsm,
reduced_vocabulary,
state_scan_tokens,
)


def run_model():
model = models.vllm(
"TheBloke/Mistral-7B-Instruct-v0.2-GPTQ", quantization="gptq", dtype="half"
)
tokenizer = model.model.get_tokenizer()
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token):
string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces in HF's Llama tokenizers
if (
type(token) is str
and token.startswith(SPIECE_UNDERLINE)
or token == "<0x20>"
):
return " " + string

return string

tokenizer.convert_token_to_string = convert_token_to_string

regex_string = '\\{[\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.){,10}"[\n ]*,[\n ]*"age"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*,[\n ]*"armor"[\n ]*:[\n ]*("leather"|"chainmail"|"plate")[\n ]*,[\n ]*"strength"[\n ]*:[\n ]*(0|[1-9][0-9]*)[\n ]*\\}'
regex_pattern = interegular.parse_pattern(regex_string)
byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
regex_fsm, tokenizer
)


profile = line_profiler.LineProfiler()
profile.add_function(create_fsm_index_tokenizer)
profile.add_function(create_fsm_index_end_to_end)
profile.add_function(reduced_vocabulary)
profile(run_model)()
profile.print_stats()
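As written, the script is run directly (for example, python profile.py); LineProfiler.print_stats() then prints per-line timings for the three registered index-construction functions.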