diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py
index 51b6e707e..fa824449a 100644
--- a/outlines/fsm/guide.py
+++ b/outlines/fsm/guide.py
@@ -118,9 +118,7 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
                 regex_pattern.to_fsm().reduce(), keep_utf8=True
             )
             regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
+            states_to_token_maps, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)
 
             # We make sure that it is possible to generate strings in the language
             # of the regular expression with the tokens present in the model's
@@ -133,11 +131,10 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
                     "The vocabulary does not allow us to build a sequence that matches the input regex"
                 )
 
-            return states_to_token_maps, empty_token_ids, regex_fsm.finals
+            return states_to_token_maps, regex_fsm.finals
 
         (
             self.states_to_token_maps,
-            self.empty_token_ids,
             fsm_finals,
         ) = create_states_mapping(regex_string)
         self.eos_token_id = tokenizer.eos_token_id
@@ -218,9 +215,7 @@ def create_states_mapping_from_interegular_fsm(
             """
             byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True)
             regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
+            states_to_token_maps, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)
 
             # We make sure that it is possible to generate strings in the language
             # of the regular expression with the tokens present in the model's
@@ -233,11 +228,10 @@ def create_states_mapping_from_interegular_fsm(
                     "The vocabulary does not allow us to build a sequence that matches the input regex"
                 )
 
-            return states_to_token_maps, empty_token_ids
+            return states_to_token_maps
 
         (
             from_interegular_instance.states_to_token_maps,
-            from_interegular_instance.empty_token_ids,
         ) = create_states_mapping_from_interegular_fsm(interegular_fsm)
         from_interegular_instance.eos_token_id = tokenizer.eos_token_id
         return from_interegular_instance
diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py
index 378a35d91..038c80e60 100644
--- a/outlines/models/vllm.py
+++ b/outlines/models/vllm.py
@@ -1,7 +1,9 @@
+import concurrent.futures
 import dataclasses
 from typing import TYPE_CHECKING, List, Optional, Union
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.integrations.utils import adapt_tokenizer
 
 if TYPE_CHECKING:
     from vllm import LLM
@@ -18,7 +20,8 @@ class VLLM:
 
     """
 
-    def __init__(self, model: "LLM"):
+    def __init__(self, model: "LLM", tokenizer):
         self.model = model
+        self.tokenizer = tokenizer
         self.lora_request = None
 
@@ -73,6 +76,9 @@ def generate(
         # are specified by the user when calling the generator.
         if max_tokens is not None:
             sampling_params.max_tokens = max_tokens
+        else:
+            sampling_params.max_tokens = None
+
         if stop_at is not None:
             if isinstance(stop_at, str):
                 stop_at = [stop_at]
@@ -140,7 +146,40 @@ def load_lora(self, adapter_path: Optional[str]):
             self.lora_request = LoRARequest(adapter_path, 1, adapter_path)
 
 
-def vllm(model_name: str, **vllm_model_params):
+def load_model(model_name, **vllm_model_params):
+    """Load the model in GPU memory."""
+    from vllm import LLM
+
+    model = LLM(model_name, **vllm_model_params)
+    return model
+
+
+def load_and_convert_tokenizer(model_name):
+    """Load the tokenizer, convert vocabulary types and JIT-compile helpers."""
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = convert_tokenizer(tokenizer)
+    jit_compile_numba(tokenizer.vocabulary_numba)
+
+    return tokenizer
+
+
+def convert_tokenizer(tokenizer):
+    from outlines.fsm.regex import reduced_vocabulary
+
+    tokenizer = adapt_tokenizer(tokenizer)
+    vocabulary, _ = reduced_vocabulary(tokenizer)
+    tokenizer.vocabulary_numba = vocabulary
+
+    return tokenizer
+
+
+def jit_compile_numba(vocabulary):
+    pass
+
+
+def vllm(model_or_model_name: Union[str, "LLM"], **vllm_model_params):
     """Load a vLLM model.
 
     Arguments
@@ -154,6 +193,16 @@ def vllm(model_name: str, **vllm_model_params):
     """
     from vllm import LLM
 
-    model = LLM(model_name, **vllm_model_params)
+    if isinstance(model_or_model_name, LLM):
+        model = model_or_model_name
+        tokenizer = convert_tokenizer(model.tokenizer)
+    else:
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            model_name = model_or_model_name
+            future_model = executor.submit(load_model, model_name, **vllm_model_params)
+            future_tokenizer = executor.submit(load_and_convert_tokenizer, model_name)
+
+            tokenizer = future_tokenizer.result()
+            model = future_model.result()
 
-    return VLLM(model)
+    return VLLM(model, tokenizer)
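
A usage sketch (illustration only, not part of the patch): with this change
the `vllm` loader accepts either a model name or a pre-built `vllm.LLM`
instance, and when given a name it loads the weights and converts the
tokenizer concurrently. The `outlines.models.vllm` entry point and the
"facebook/opt-125m" model name below are assumed for illustration:

    import outlines
    from vllm import LLM

    # Pass a model name: weight loading and tokenizer conversion run in
    # parallel threads, hiding the tokenizer work behind the GPU weight load.
    model = outlines.models.vllm("facebook/opt-125m")

    # Pass an already-initialized vllm.LLM instance: only the tokenizer
    # conversion runs, and the loaded weights are reused as-is.
    llm = LLM("facebook/opt-125m")
    model = outlines.models.vllm(llm)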