diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py
index 74498726d..8e18c33e7 100644
--- a/outlines/integrations/llamacpp.py
+++ b/outlines/integrations/llamacpp.py
@@ -26,7 +26,7 @@
 """
 import math
-from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union
+from typing import TYPE_CHECKING, Optional, Type, Union
 
 import numpy as np
 import torch
@@ -36,47 +36,12 @@
 from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
 from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.integrations.utils import convert_json_schema_to_str
+from outlines.models.llamacpp import LlamaCppTokenizer
 
 if TYPE_CHECKING:
     from llama_cpp import Llama
 
 
-class LlamaCppTokenizer:
-    def __init__(self, model: "Llama"):
-        self.eos_token_id = model.token_eos()
-        self.eos_token = model.tokenizer().decode([self.eos_token_id])
-        self.pad_token_id = self.eos_token_id
-        self.special_tokens: Set[int] = set()
-
-        self.vocabulary: Dict[str, int] = dict()
-
-        tokenizer = model.tokenizer()
-
-        self.decode = tokenizer.decode
-
-        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
-        try:
-            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
-        except AttributeError:
-            # ###
-            for t in range(model.n_vocab()):
-                token_piece = model.tokenizer().decode([t])
-                self.vocabulary[token_piece] = t
-
-    def convert_token_to_string(self, token: str) -> str:
-        return token
-
-    def __getstate__(self):
-        """Allow tokenizer to be used as hash key by excluding self.decode"""
-        return (
-            self.vocabulary.items(),
-            self.eos_token_id,
-            self.eos_token,
-            self.pad_token_id,
-            sorted(self.special_tokens),
-        )
-
-
 class LogitsProcessor:
     """Bias LlamaCpp generation using a finite state machine.
 
diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py
index 5920f08d6..b85bd529d 100644
--- a/outlines/models/llamacpp.py
+++ b/outlines/models/llamacpp.py
@@ -1,15 +1,95 @@
 import dataclasses
+import pickle
 import warnings
-from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypedDict,
+    Union,
+)
 
 from typing_extensions import Unpack
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
     from llama_cpp import Llama, LogitsProcessorList
 
 
+class LlamaCppTokenizer(Tokenizer):
+    def __init__(self, model: "Llama"):
+        self.eos_token_id = model.token_eos()
+        self.eos_token = model.tokenizer().decode([self.eos_token_id])
+        self.pad_token_id = self.eos_token_id
+        self.special_tokens: Set[int] = set()
+
+        self.vocabulary: Dict[str, int] = dict()
+
+        self.tokenizer = model.tokenizer()
+
+        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
+        try:
+            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
+        except AttributeError:
+            # Fall back to building the vocabulary by decoding every token id
+            for t in range(model.n_vocab()):
+                token_piece = model.tokenizer().decode([t])
+                self.vocabulary[token_piece] = t
+
+        self._hash = None
+
+    def decode(self, token_ids: List[int]) -> List[str]:
+        decoded_bytes = self.tokenizer.detokenize(token_ids)
+        return [decoded_bytes.decode("utf-8", errors="ignore")]
+
+    def encode(
+        self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True
+    ) -> Tuple[List[int], List[int]]:
+        if isinstance(prompt, list):
+            raise NotImplementedError(
+                "llama-cpp-python tokenizer doesn't support batch tokenization"
+            )
+        token_ids = self.tokenizer.tokenize(
+            prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
+        )
+        # generate attention mask, missing from llama-cpp-python
+        attention_mask = [
+            1 if token_id != self.pad_token_id else 0 for token_id in token_ids
+        ]
+        return token_ids, attention_mask
+
+    def convert_token_to_string(self, token: str) -> str:
+        return token
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)
+
+    def __hash__(self):
+        # cache object hash
+        if self._hash is None:
+            self._hash = hash(pickle.dumps(self))
+        return self._hash
+
+    def __getstate__(self):
+        """Create a stable representation for outlines.caching"""
+        return (
+            sorted(self.vocabulary.items()),  # deterministic and picklable
+            self.eos_token_id,
+            self.eos_token,
+            self.pad_token_id,
+            sorted(self.special_tokens),
+        )
+
+    def __setstate__(self, state):
+        raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")
+
+
 class LlamaCppParams(TypedDict, total=False):
     suffix: Optional[str]
     temperature: float
diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py
index 531bf8fb9..89d41835e 100644
--- a/tests/generate/test_integration_llamacpp.py
+++ b/tests/generate/test_integration_llamacpp.py
@@ -281,9 +281,45 @@ def test_llama_cpp_pre_tokenizer_remains_broken():
         generate.choice(model, ["skirt", "dress", "pen", "jacket"])
 
 
-def test_create_states_mapping_llamacpp_tokenizer_regression(model):
-    """Minimal reproducer for #922, error passing llamacpp tokenizer to create_states_mapping"""
+def test_RegexGuide_caching(temp_cache_dir):
+    import outlines.caching
     from outlines.fsm.guide import create_states_mapping
-    from outlines.integrations.llamacpp import LlamaCppTokenizer
 
-    create_states_mapping("a", LlamaCppTokenizer(model.model))
+    assert outlines.caching._caching_enabled
+
+    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+    prompt = "What is the IP address of the Google DNS servers? "
+
+    cache = outlines.caching.get_cache()
+
+    # Returns (hits, misses)
+    _ = cache.stats(enable=True)
+    assert cache.statistics
+
+    assert create_states_mapping.__memory__ is cache
+
+    model = models.transformers(
+        "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
+    )
+    generator = generate.regex(model, regex, sampler=samplers.greedy())
+    assert cache.stats() == (0, 1)
+
+    model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")
+    generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy())
+    assert cache.stats() == (0, 2)
+
+    # These two different models and tokenizers should not have the same state
+    # mapping results
+    assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps
+
+    generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy())
+    assert cache.stats() == (1, 2)
+    assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps
+
+    # Just for fun...
+    structured = generator(prompt, max_tokens=30)
+    structured_2 = generator_2(prompt, max_tokens=30)
+
+    assert re.fullmatch(regex, structured)
+    assert re.fullmatch(regex, structured_2)
+    assert structured != structured_2
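
A minimal usage sketch of the relocated LlamaCppTokenizer follows. It is illustrative only: it assumes llama-cpp-python is installed, and the GGUF model path is hypothetical.

from llama_cpp import Llama

from outlines.models.llamacpp import LlamaCppTokenizer

# Hypothetical local model file; any GGUF model accepted by llama-cpp-python works.
llm = Llama("path/to/model.gguf")
tokenizer = LlamaCppTokenizer(llm)

# encode() returns token ids plus an attention mask derived from pad_token_id.
prompt = "What is the IP address of the Google DNS servers? "
token_ids, attention_mask = tokenizer.encode(prompt)
print(tokenizer.decode(token_ids)[0])

# The hash is computed from the pickled __getstate__ tuple, so tokenizers built
# from the same model hash identically and can share cached state mappings.
print(hash(tokenizer) == hash(LlamaCppTokenizer(llm)))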