make LlamaCppTokenizer an outlines Tokenizer
lapp0 committed May 30, 2024
1 parent cb16b16 commit 88d8eaa
Showing 3 changed files with 123 additions and 42 deletions.
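At a glance, the commit moves LlamaCppTokenizer from outlines.integrations.llamacpp into outlines.models.llamacpp, makes it subclass the outlines.models.tokenizer.Tokenizer interface, and gives it encode/decode plus hashing support for outlines' caching. A rough usage sketch of the relocated class follows; the GGUF path is a placeholder and not part of the diff:

# Rough usage sketch, not part of the diff; the GGUF path below is a placeholder.
from llama_cpp import Llama

from outlines.models.llamacpp import LlamaCppTokenizer

llm = Llama("./model.gguf")          # placeholder local model path
tokenizer = LlamaCppTokenizer(llm)

token_ids, attention_mask = tokenizer.encode("What is the capital of France?")
text = tokenizer.decode(token_ids)[0]  # decode() returns a single-element list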
39 changes: 2 additions & 37 deletions outlines/integrations/llamacpp.py
@@ -26,7 +26,7 @@
"""

import math
from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union
from typing import TYPE_CHECKING, Optional, Type, Union

import numpy as np
import torch
@@ -36,47 +36,12 @@
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import convert_json_schema_to_str
from outlines.models.llamacpp import LlamaCppTokenizer

if TYPE_CHECKING:
    from llama_cpp import Llama


class LlamaCppTokenizer:
    def __init__(self, model: "Llama"):
        self.eos_token_id = model.token_eos()
        self.eos_token = model.tokenizer().decode([self.eos_token_id])
        self.pad_token_id = self.eos_token_id
        self.special_tokens: Set[int] = set()

        self.vocabulary: Dict[str, int] = dict()

        tokenizer = model.tokenizer()

        self.decode = tokenizer.decode

        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
        try:
            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
        except AttributeError:
            # ###
            for t in range(model.n_vocab()):
                token_piece = model.tokenizer().decode([t])
                self.vocabulary[token_piece] = t

    def convert_token_to_string(self, token: str) -> str:
        return token

    def __getstate__(self):
        """Allow tokenizer to be used as hash key by excluding self.decode"""
        return (
            self.vocabulary.items(),
            self.eos_token_id,
            self.eos_token,
            self.pad_token_id,
            sorted(self.special_tokens),
        )


class LogitsProcessor:
    """Bias LlamaCpp generation using a finite state machine.
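Because the integrations module now re-imports the class from its new home (see the added import above), the old import path keeps resolving to the same object. A small sketch of that backward-compatibility property, assuming both modules are importable:

# Backward-compatibility sketch (illustrative, not part of the diff):
# integrations/llamacpp.py now simply re-exports the class from outlines.models.llamacpp.
from outlines.integrations.llamacpp import LlamaCppTokenizer as OldPathTokenizer
from outlines.models.llamacpp import LlamaCppTokenizer as NewPathTokenizer

assert OldPathTokenizer is NewPathTokenizer  # same class object after the move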
82 changes: 81 additions & 1 deletion outlines/models/llamacpp.py
@@ -1,15 +1,95 @@
import dataclasses
import pickle
import warnings
from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union
from typing import (
    TYPE_CHECKING,
    Dict,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    TypedDict,
    Union,
)

from typing_extensions import Unpack

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
    from llama_cpp import Llama, LogitsProcessorList


class LlamaCppTokenizer(Tokenizer):
    def __init__(self, model: "Llama"):
        self.eos_token_id = model.token_eos()
        self.eos_token = model.tokenizer().decode([self.eos_token_id])
        self.pad_token_id = self.eos_token_id
        self.special_tokens: Set[int] = set()

        self.vocabulary: Dict[str, int] = dict()

        self.tokenizer = model.tokenizer()

        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
        try:
            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
        except AttributeError:
            # ###
            for t in range(model.n_vocab()):
                token_piece = model.tokenizer().decode([t])
                self.vocabulary[token_piece] = t

        self._hash = None

    def decode(self, token_ids: List[int]) -> List[str]:
        decoded_bytes = self.tokenizer.detokenize(token_ids)
        return [decoded_bytes.decode("utf-8", errors="ignore")]

    def encode(
        self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True
    ) -> Tuple[List[int], List[int]]:
        if isinstance(prompt, list):
            raise NotImplementedError(
                "llama-cpp-python tokenizer doesn't support batch tokenization"
            )
        token_ids = self.tokenizer.tokenize(
            prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
        )
        # generate attention mask, missing from llama-cpp-python
        attention_mask = [
            1 if token_id != self.pad_token_id else 0 for token_id in token_ids
        ]
        return token_ids, attention_mask

    def convert_token_to_string(self, token: str) -> str:
        return token

    def __eq__(self, other):
        return hash(self) == hash(other)

    def __hash__(self):
        # cache object hash
        if self._hash is None:
            self._hash = hash(pickle.dumps(self))
        return self._hash

    def __getstate__(self):
        """Create a stable representation for outlines.caching"""
        return (
            self.vocabulary.items(),
            self.eos_token_id,
            self.eos_token,
            self.pad_token_id,
            sorted(self.special_tokens),
        )

    def __setstate__(self, state):
        raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")


class LlamaCppParams(TypedDict, total=False):
    suffix: Optional[str]
    temperature: float
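The __getstate__/__hash__ pair above exists so that compiled regex guides can be cached and reused by outlines.caching. An end-to-end sketch of the structured-generation workflow this feeds into; the models.llamacpp call and the GGUF repo/filename are assumptions for illustration, not taken from the diff:

# End-to-end sketch (placeholder repo/filename, not part of the diff):
from outlines import generate, models, samplers

model = models.llamacpp(
    repo_id="someorg/some-model-GGUF",       # placeholder Hub repo
    filename="some-model.Q4_K_M.gguf",       # placeholder GGUF file
)
generator = generate.regex(model, r"[0-9]{3}", sampler=samplers.greedy())
print(generator("Pick a three-digit number: ", max_tokens=5))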
44 changes: 40 additions & 4 deletions tests/generate/test_integration_llamacpp.py
@@ -281,9 +281,45 @@ def test_llama_cpp_pre_tokenizer_remains_broken():
    generate.choice(model, ["skirt", "dress", "pen", "jacket"])


def test_create_states_mapping_llamacpp_tokenizer_regression(model):
    """Minimal reproducer for #922, error passing llamacpp tokenizer to create_states_mapping"""
def test_RegexGuide_caching(temp_cache_dir):
    import outlines.caching
    from outlines.fsm.guide import create_states_mapping
    from outlines.integrations.llamacpp import LlamaCppTokenizer

    create_states_mapping("a", LlamaCppTokenizer(model.model))
    assert outlines.caching._caching_enabled

    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
    prompt = "What is the IP address of the Google DNS servers? "

    cache = outlines.caching.get_cache()

    # Returns (hits, misses)
    _ = cache.stats(enable=True)
    assert cache.statistics

    assert create_states_mapping.__memory__ is cache

    model = models.transformers(
        "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
    )
    generator = generate.regex(model, regex, sampler=samplers.greedy())
    assert cache.stats() == (0, 1)

    model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")
    generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy())
    assert cache.stats() == (0, 2)

    # These two different models and tokenizers should not have the same state
    # mapping results
    assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps

    generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy())
    assert cache.stats() == (1, 2)
    assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps

    # Just for fun...
    structured = generator(prompt, max_tokens=30)
    structured_2 = generator_2(prompt, max_tokens=30)

    assert re.fullmatch(regex, structured)
    assert re.fullmatch(regex, structured_2)
    assert structured != structured_2
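For reference, the (hits, misses) bookkeeping the new test leans on appears to come from diskcache, which backs outlines' cache object here. A minimal standalone sketch of that statistics API, independent of outlines:

# Minimal sketch of the diskcache statistics used above (standalone, illustrative):
import diskcache

cache = diskcache.Cache()        # temporary on-disk cache
cache.stats(enable=True)         # start tracking; returns the current (hits, misses)
cache.set("key", "value")
cache.get("key")                 # counts as a hit
cache.get("missing")             # counts as a miss
print(cache.stats())             # -> (1, 1)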
