Skip to content

Commit

Permalink
create OutlinesExLlamaV2Tokenizer rather than monkey patching
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Oct 5, 2024
1 parent e12cc7e commit c957265
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions outlines/models/exllamav2.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import dataclasses
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union

import torch
from typing_extensions import Unpack

from outlines.generate.api import GenerationParameters, SamplingParameters

if TYPE_CHECKING:
from exllamav2 import ExLlamaV2Tokenizer
import torch.LongTensor
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler


Expand All @@ -18,13 +19,30 @@ class ExllamaV2Params(TypedDict, total=False):
max_new_tokens: List[int]


class OutlinesExLlamaV2Tokenizer:
    """Adapter that exposes the tokenizer interface Outlines expects on top
    of an ExLlamaV2 tokenizer, without monkey-patching the original object.

    Attributes:
        exl2_tokenizer: the wrapped ExLlamaV2 tokenizer.
        vocabulary: mapping of token piece -> token id.
        special_tokens: set of special-token pieces (extended pieces).
        eos_token_id: id of the end-of-sequence token.
    """

    def __init__(self, tokenizer):
        self.exl2_tokenizer = tokenizer
        self.vocabulary = tokenizer.get_piece_to_id_dict()
        self.special_tokens = set(tokenizer.extended_piece_to_id)
        self.eos_token_id = tokenizer.eos_token_id

    def convert_token_to_string(self, token):
        # ExLlamaV2 pieces are already plain strings; no mapping is needed.
        return token

    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
        """Decode each token-id sequence into a string (special tokens excluded)."""
        sequences = [torch.tensor(seq) for seq in token_ids]
        return self.exl2_tokenizer.decode(sequences, decode_special_tokens=False)


class ExLlamaV2Model:
"""Represents a `exl2` model."""

def __init__(
self,
generator: "ExLlamaV2DynamicGenerator",
tokenizer: "ExLlamaV2Tokenizer",
tokenizer: "OutlinesExLlamaV2Tokenizer",
max_seq_len: int,
):
self.generator = generator
Expand Down Expand Up @@ -220,13 +238,6 @@ def token_generator() -> Iterator[str]:
return token_generator()


def patch_tokenizer(tokenizer):
    """Attach the attributes Outlines requires onto *tokenizer* in place.

    Adds `vocabulary`, `special_tokens`, and `convert_token_to_string`
    to the given ExLlamaV2 tokenizer and returns the same object.
    """
    tokenizer.vocabulary = tokenizer.get_piece_to_id_dict()
    tokenizer.special_tokens = set(tokenizer.extended_piece_to_id)

    def _identity(token):
        # Pieces are already plain strings; pass them through unchanged.
        return token

    tokenizer.convert_token_to_string = _identity
    return tokenizer


def exl2(
model_path: str,
draft_model_path: Optional[str] = None,
Expand Down Expand Up @@ -305,7 +316,6 @@ def exl2(

print("Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
tokenizer = patch_tokenizer(tokenizer)
max_batch_size = 4 if paged else 1

draft_model = None
Expand Down Expand Up @@ -336,4 +346,7 @@ def exl2(
paged=paged,
)
max_seq_len = cache.max_seq_len
return ExLlamaV2Model(generator, tokenizer, max_seq_len)

outlines_tokenizer = OutlinesExLlamaV2Tokenizer(tokenizer)
outlines_exl2_model = ExLlamaV2Model(generator, outlines_tokenizer, max_seq_len)
return outlines_exl2_model

0 comments on commit c957265

Please sign in to comment.