Convert vocabulary types and load model concurrently
rlouf committed Apr 24, 2024
1 parent d5068d6 commit 6deef74
Showing 2 changed files with 56 additions and 14 deletions.
14 changes: 4 additions & 10 deletions outlines/fsm/guide.py
@@ -118,9 +118,7 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
                 regex_pattern.to_fsm().reduce(), keep_utf8=True
             )
             regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
+            states_to_token_maps, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)
 
             # We make sure that it is possible to generate strings in the language
             # of the regular expression with the tokens present in the model's
@@ -133,11 +131,10 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
                 "The vocabulary does not allow us to build a sequence that matches the input regex"
             )
 
-            return states_to_token_maps, empty_token_ids, regex_fsm.finals
+            return states_to_token_maps, regex_fsm.finals
 
         (
             self.states_to_token_maps,
-            self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string)
         self.eos_token_id = tokenizer.eos_token_id
@@ -218,9 +215,7 @@ def create_states_mapping_from_interegular_fsm(
             """
             byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
             regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
+            states_to_token_maps, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)
 
             # We make sure that it is possible to generate strings in the language
             # of the regular expression with the tokens present in the model's
@@ -233,11 +228,10 @@ def create_states_mapping_from_interegular_fsm(
                 "The vocabulary does not allow us to build a sequence that matches the input regex"
             )
 
-            return states_to_token_maps, empty_token_ids
+            return states_to_token_maps
 
         (
             from_interegular_instance.states_to_token_maps,
-            from_interegular_instance.empty_token_ids,
         ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
         from_interegular_instance.eos_token_id = tokenizer.eos_token_id
         return from_interegular_instance
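Both hunks in this file make the same simplification: create_fsm_index_tokenizer still returns a pair, but the set of token ids that decode to the empty string is now discarded instead of being stored on the guide. A minimal sketch of the pipeline these hunks touch, using only function names visible in this commit; "gpt2" is a placeholder model name, and it assumes adapt_tokenizer (imported in the second file below) also works standalone here:

    import interegular
    from transformers import AutoTokenizer

    from outlines.fsm.regex import (
        create_fsm_index_tokenizer,
        make_byte_level_fsm,
        make_deterministic_fsm,
    )
    from outlines.integrations.utils import adapt_tokenizer

    tokenizer = adapt_tokenizer(AutoTokenizer.from_pretrained("gpt2"))

    # Regex -> interegular FSM -> byte-level FSM -> deterministic FSM.
    regex_pattern = interegular.parse_pattern(r"[0-9]{3}")
    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
    regex_fsm, _ = make_deterministic_fsm(byte_fsm)

    # After this commit only the transition table is kept; the second
    # element (token ids that decode to the empty string) is dropped.
    states_to_token_maps, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)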
56 changes: 52 additions & 4 deletions outlines/models/vllm.py
@@ -1,7 +1,9 @@
+import concurrent.futures
 import dataclasses
 from typing import TYPE_CHECKING, List, Optional, Union
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.integrations.utils import adapt_tokenizer
 
 if TYPE_CHECKING:
     from vllm import LLM
@@ -18,7 +20,7 @@ class VLLM:
     """
 
-    def __init__(self, model: "LLM"):
+    def __init__(self, model: "LLM", tokenizer):
         self.model = model
         self.lora_request = None
 
@@ -73,6 +75,9 @@ def generate(
         # are specified by the user when calling the generator.
         if max_tokens is not None:
             sampling_params.max_tokens = max_tokens
+        else:
+            sampling_params.max_tokens = None
 
         if stop_at is not None:
             if isinstance(stop_at, str):
                 stop_at = [stop_at]
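Context for the new else branch: vLLM's SamplingParams defaults max_tokens to 16, so a generator built without an explicit limit would otherwise stop after 16 tokens; setting it to None explicitly asks vLLM (as of this commit) to generate until an EOS token or the context limit. A small illustration, assuming vllm is installed:

    from vllm import SamplingParams

    params = SamplingParams()
    print(params.max_tokens)  # 16: vLLM's implicit cap if left untouched

    params.max_tokens = None  # generate until EOS / context limit instead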
@@ -140,7 +145,40 @@ def load_lora(self, adapter_path: Optional[str]):
         self.lora_request = LoRARequest(adapter_path, 1, adapter_path)
 
 
-def vllm(model_name: str, **vllm_model_params):
+def load_model(model_name, **vllm_model_params):
+    """Load the model in GPU memory."""
+    from vllm import LLM
+
+    model = LLM(model_name, **vllm_model_params)
+    return model
+
+
+def load_and_convert_tokenizer(model_name):
+    """Convert vocabulary types and JIT-compile function."""
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = convert_tokenizer(tokenizer)
+    jit_compile_numba(tokenizer.vocabulary_numba)
+
+    return tokenizer
+
+
+def convert_tokenizer(tokenizer):
+    from outlines.fsm.regex import reduced_vocabulary
+
+    tokenizer = adapt_tokenizer(tokenizer)
+    vocabulary, _ = reduced_vocabulary(tokenizer)
+    tokenizer.vocabulary_numba = vocabulary
+
+    return tokenizer
+
+
+def jit_compile_numba(vocabulary):
+    pass
+
+
+def vllm(model_or_model_name: Union[str, "LLM"], **vllm_model_params):
     """Load a vLLM model.
 
     Arguments
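jit_compile_numba is deliberately left as a stub in this commit. A minimal sketch of the idea it reserves space for, with a hypothetical njit helper standing in for outlines' real numba-compiled index routines (_touch_vocabulary is not part of the library): trigger numba compilation while the model is still loading, instead of on the first generation.

    from numba import njit

    @njit(cache=True)
    def _touch_vocabulary(n):
        # Hypothetical stand-in for outlines' numba-compiled FSM helpers.
        total = 0
        for i in range(n):
            total += i
        return total

    def jit_compile_numba(vocabulary):
        # Calling an njit function once forces compilation here, in the
        # tokenizer thread, rather than during the first structured
        # generation call.
        _touch_vocabulary(len(vocabulary))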
@@ -154,6 +192,16 @@ def vllm(model_name: str, **vllm_model_params):
"""
from vllm import LLM

model = LLM(model_name, **vllm_model_params)
if isinstance(model_or_model_name, LLM):
model = model_or_model_name
tokenizer = convert_tokenizer(model.tokenizer)
else:
with concurrent.futures.ThreadPoolExecutor() as executor:
model_name = model_or_model_name
future_model = executor.submit(load_model, model_name, **vllm_model_params)
future_tokenizer = executor.submit(load_and_convert_tokenizer, model_name)

tokenizer = future_tokenizer.result()
model = future_model.result()

return VLLM(model)
return VLLM(model, tokenizer)
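The branch above keeps the old behavior when an already-constructed LLM is passed in; for a model name, the two slow steps now overlap: weight loading and tokenizer conversion run in separate threads, so wall time is roughly the maximum of the two rather than their sum. The same pattern in isolation, with time.sleep standing in for the real work (slow_load and slow_convert are stand-ins, not outlines functions):

    import concurrent.futures
    import time

    def slow_load(name):
        time.sleep(2.0)  # stand-in for LLM(model_name, ...)
        return f"model:{name}"

    def slow_convert(name):
        time.sleep(1.0)  # stand-in for tokenizer conversion + JIT warm-up
        return f"tokenizer:{name}"

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_model = executor.submit(slow_load, "my-model")
        future_tokenizer = executor.submit(slow_convert, "my-model")

        tokenizer = future_tokenizer.result()  # blocks ~1 s
        model = future_model.result()          # already ~1 s into its 2 s

    # Total: ~2 s instead of ~3 s sequentially.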
