automatically download exl2 model in tests
fix exl2 bug: sometimes piece_to_id is not populated, but get_piece_to_id_dict() still works

enable exl2 in generate.cfg

create OutlinesExLlamaV2Tokenizer rather than monkey-patching
lapp0 committed Oct 5, 2024
1 parent 6530d73 commit 6a7eb90
Showing 3 changed files with 37 additions and 21 deletions.
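The underlying bug: on some exl2 checkpoints the tokenizer's piece_to_id attribute is left unpopulated, while the get_piece_to_id_dict() accessor still returns the full vocabulary. A minimal sketch of the distinction, with a hypothetical helper name; the empty-attribute symptom is inferred from the commit message, not verified here:

```python
from exllamav2 import ExLlamaV2Tokenizer

def load_vocabulary(tokenizer: ExLlamaV2Tokenizer) -> dict:
    """Hypothetical helper: prefer the accessor over the raw attribute."""
    # tokenizer.piece_to_id is reportedly empty on some checkpoints, while
    # get_piece_to_id_dict() lazily builds and returns the full piece -> id
    # mapping. That is why the new wrapper class reads the vocabulary
    # through the accessor.
    return tokenizer.get_piece_to_id_dict()
```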
9 changes: 1 addition & 8 deletions outlines/generate/cfg.py
@@ -4,7 +4,7 @@
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import ExLlamaV2Model, LlamaCpp, OpenAI, TransformersVision
+from outlines.models import LlamaCpp, OpenAI, TransformersVision
 from outlines.samplers import Sampler, multinomial
 
 
@@ -41,13 +41,6 @@ def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()):
     return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)
 
 
-@cfg.register(ExLlamaV2Model)
-def cfg_exllamav2(model, cfg_str: str, sampler: Sampler = multinomial()):
-    raise NotImplementedError(
-        "Not yet available, track progress in https://github.com/dottxt-ai/outlines/pull/1010"
-    )
-
-
 @cfg.register(LlamaCpp)
 def cfg_llamacpp(model, cfg_str: str, sampler: Sampler = multinomial()):
     raise NotImplementedError("Not yet available due to bug in llama_cpp tokenizer")
39 changes: 27 additions & 12 deletions outlines/models/exllamav2.py
@@ -1,12 +1,13 @@
 import dataclasses
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union
 
+import torch
 from typing_extensions import Unpack
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
 
 if TYPE_CHECKING:
-    from exllamav2 import ExLlamaV2Tokenizer
+    import torch.LongTensor
     from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
 
 
@@ -18,13 +19,33 @@ class ExllamaV2Params(TypedDict, total=False):
     max_new_tokens: List[int]
 
 
+class OutlinesExLlamaV2Tokenizer:
+    def __init__(self, tokenizer):
+        self.exl2_tokenizer = tokenizer
+        self.vocabulary = self.exl2_tokenizer.get_piece_to_id_dict()
+        self.special_tokens = set(self.exl2_tokenizer.extended_piece_to_id)
+        self.eos_token_id = self.exl2_tokenizer.eos_token_id
+
+    def convert_token_to_string(self, token):
+        return token
+
+    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
+        decoded = self.exl2_tokenizer.decode(
+            torch.tensor(token_ids),
+            decode_special_tokens=False,
+        )
+        if isinstance(decoded, str):
+            return [decoded]
+        return decoded
+
+
 class ExLlamaV2Model:
     """Represents a `exl2` model."""
 
     def __init__(
         self,
         generator: "ExLlamaV2DynamicGenerator",
-        tokenizer: "ExLlamaV2Tokenizer",
+        tokenizer: "OutlinesExLlamaV2Tokenizer",
         max_seq_len: int,
     ):
         self.generator = generator
@@ -220,14 +241,6 @@ def token_generator() -> Iterator[str]:
         return token_generator()
 
 
-# Taken from https://github.com/lapp0/exllamav2/pull/1/files#diff-26f303de07c10aad998e33d3df52581643673a598162cc4b35ef051f52d7c60b
-def patch_tokenizer(tokenizer):
-    tokenizer.vocabulary = tokenizer.piece_to_id
-    tokenizer.special_tokens = set(tokenizer.extended_piece_to_id)
-    tokenizer.convert_token_to_string = lambda t: t
-    return tokenizer
-
-
 def exl2(
     model_path: str,
     draft_model_path: Optional[str] = None,
@@ -306,7 +319,6 @@ def exl2(
 
     print("Loading tokenizer...")
     tokenizer = ExLlamaV2Tokenizer(config)
-    tokenizer = patch_tokenizer(tokenizer)
     max_batch_size = 4 if paged else 1
 
     draft_model = None
@@ -337,4 +349,7 @@ def exl2(
         paged=paged,
     )
     max_seq_len = cache.max_seq_len
-    return ExLlamaV2Model(generator, tokenizer, max_seq_len)
+
+    outlines_tokenizer = OutlinesExLlamaV2Tokenizer(tokenizer)
+    outlines_exl2_model = ExLlamaV2Model(generator, outlines_tokenizer, max_seq_len)
+    return outlines_exl2_model
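The wrapper replaces the old in-place monkey patch: it keeps a reference to the exllamav2 tokenizer and exposes the interface Outlines expects (vocabulary, special_tokens, eos_token_id, convert_token_to_string, decode). A small sketch of the decode normalization, assuming a tokenizer already loaded as in exl2() above:

```python
# Sketch: OutlinesExLlamaV2Tokenizer.decode always returns a list of strings,
# even though exllamav2's decode returns a bare str for 1-D input.
wrapped = OutlinesExLlamaV2Tokenizer(tokenizer)  # `tokenizer` assumed loaded
texts = wrapped.decode([1, 2, 3])                # 1-D ids yield ["<decoded text>"]
assert isinstance(texts, list) and all(isinstance(t, str) for t in texts)
```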
10 changes: 9 additions & 1 deletion tests/generate/test_generate.py
@@ -22,8 +22,16 @@ def model_llamacpp(tmp_path_factory):
 
 @pytest.fixture(scope="session")
 def model_exllamav2(tmp_path_factory):
+    from huggingface_hub import snapshot_download
+
+    tmp_dir = tmp_path_factory.mktemp("model_download")
+    model_path = snapshot_download(
+        repo_id="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4.6-exl2",
+        cache_dir=tmp_dir,
+    )
+
     return models.exl2(
-        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
+        model_path=model_path,
         cache_q4=True,
         paged=False,
     )
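models.exl2 expects a local directory, so the fixture now materializes the repo with snapshot_download instead of passing a bare repo id. Outside pytest, the equivalent would be roughly as follows; only the repo id comes from the diff, and the default Hugging Face cache is used instead of a pytest tmp directory:

```python
from huggingface_hub import snapshot_download

from outlines import models

# Same download the fixture performs, minus the pytest tmp directory.
model_path = snapshot_download(
    repo_id="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4.6-exl2",
)
model = models.exl2(model_path=model_path, cache_q4=True, paged=False)
```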