Clean high-level imports
rlouf committed Nov 28, 2024
1 parent 7cdaeac · commit 11ae206
Showing 2 changed files with 6 additions and 5 deletions.
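
Both files apply the same deferred-import pattern: module-level imports of heavy third-party packages (`datasets`, `transformers`) move into the functions that actually use them, and the annotation-only import moves behind `TYPE_CHECKING`. Given the commit title, the apparent goal is that importing the high-level `outlines` package no longer eagerly loads these dependencies.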
outlines/models/transformers.py (2 additions, 2 deletions)
@@ -2,8 +2,6 @@
 import inspect
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 
-from datasets.fingerprint import Hasher
-
 from outlines.generate.api import GenerationParameters, SamplingParameters
 from outlines.models.tokenizer import Tokenizer
 
@@ -116,6 +114,8 @@ def __eq__(self, other):
         return NotImplemented
 
     def __hash__(self):
+        from datasets.fingerprint import Hasher
+
         return hash(Hasher.hash(self.tokenizer))
 
     def __getstate__(self):
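
A minimal sketch of the pattern this hunk applies, using a simplified stand-in for the real tokenizer wrapper (only the `tokenizer` attribute and the `__hash__` body are taken from the diff; the rest is illustrative):

class TransformerTokenizer:
    # Simplified stand-in: the real class wraps a `transformers` tokenizer
    # and also defines __eq__, __getstate__, etc.
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __hash__(self):
        # Deferred import: `datasets` is loaded only when a tokenizer is
        # actually hashed, not when this module is imported.
        from datasets.fingerprint import Hasher

        return hash(Hasher.hash(self.tokenizer))

The import statement runs on every call, but after the first call it is just a dictionary lookup in `sys.modules`, so the overhead is negligible.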
outlines/models/vllm.py (4 additions, 3 deletions)
@@ -1,11 +1,10 @@
 import dataclasses
 from typing import TYPE_CHECKING, List, Optional, Union
 
-from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase
-
 from outlines.generate.api import GenerationParameters, SamplingParameters
 
 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
     from vllm import LLM
     from vllm.sampling_params import SamplingParams
@@ -188,7 +187,7 @@ def vllm(model_name: str, **vllm_model_params):
     return VLLM(model)
 
 
-def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
+def adapt_tokenizer(tokenizer: "PreTrainedTokenizerBase") -> "PreTrainedTokenizerBase":
     """Adapt a tokenizer to use to compile the FSM.
 
     The API of Outlines tokenizers is slightly different to that of `transformers`. In
@@ -205,6 +204,8 @@ def adapt_tokenizer(tokenizer: PreTrainedTokenizerBa
     PreTrainedTokenizerBase
         The adapted tokenizer.
     """
+    from transformers import SPIECE_UNDERLINE
+
     tokenizer.vocabulary = tokenizer.get_vocab()
     tokenizer.special_tokens = set(tokenizer.all_special_tokens)
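
The vllm.py hunks combine two standard tricks, sketched below with the function body elided: an import under `TYPE_CHECKING`, which type checkers evaluate but the interpreter never executes, paired with quoted ("forward reference") annotations that stay plain strings at runtime.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright only; never executed, so `transformers` is not
    # imported when this module loads.
    from transformers import PreTrainedTokenizerBase


def adapt_tokenizer(tokenizer: "PreTrainedTokenizerBase") -> "PreTrainedTokenizerBase":
    # Quoted annotations are not evaluated at runtime, so the name only has
    # to resolve during static type checking.
    from transformers import SPIECE_UNDERLINE  # deferred runtime import

    ...  # tokenizer adaptation elided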
