Transformers use logits processor #31

Open

lapp0 wants to merge 1 commit into main from transformers-use-logits-processor
Conversation

lapp0 (Owner) commented Jun 12, 2024

Fixes dottxt-ai#806

Fixes dottxt-ai#789

Closes dottxt-ai#910

Problem

For outlines.models.transformers, SequenceGenerator manages the automata directly rather than delegating to a logits processor that encapsulates automata management. This divergent implementation is what caused the bug in dottxt-ai#789.

Solution

  • Implement Transformers.generate and Transformers.stream, which apply outlines.processors.OutlinesLogitsProcessor via HF transformers' logits_processor argument (see the sketch after this list)
  • Use SequenceGeneratorAdapter for transformers instead of SequenceGenerator
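
A minimal sketch of the intended wiring, not the actual implementation: it assumes a plain HF model and tokenizer, and generate_with_processor, hf_model, and hf_tokenizer are illustrative names:

from transformers import LogitsProcessorList

def generate_with_processor(hf_model, hf_tokenizer, prompt, logits_processor, max_new_tokens=None):
    # Tokenize the prompt with the underlying HF tokenizer
    inputs = hf_tokenizer(prompt, return_tensors="pt")
    # HF's generate() invokes the processor at every decoding step, so the
    # automaton state lives inside the Outlines processor, not the generator
    output_ids = hf_model.generate(
        **inputs,
        logits_processor=LogitsProcessorList([logits_processor]),
        max_new_tokens=max_new_tokens,
    )
    return hf_tokenizer.batch_decode(output_ids, skip_special_tokens=True)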

TODO:

  • implement Transformers.generate and Transformers.stream
  • implement SequenceGeneratorAdapter version of outlines.models.transformers
  • unit tests
  • await mlx merge and rebase onto main
  • update transformers integration documentation
  • revert llamacpp and vllm changes, these will be in a separate PR
  • ~~logits processor profiling in benchmarks~~
    • will do in logits processor unification PR
  • ping people who've requested this

Bonus

This new structure allows us to easily integrate multi-modal models by subclassing models.Transformers. Additionally, we can make models.mamba a Transformers model and simply pass model_class=MambaLMHeadModel.
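
A hypothetical sketch of the Mamba case (the model_class keyword and the checkpoint name are assumptions for illustration, not the merged API):

import outlines
from transformers import MambaLMHeadModel

# Reuse the Transformers integration, swapping in Mamba's HF model class
model = outlines.models.transformers(
    "state-spaces/mamba-130m-hf",  # assumed checkpoint, for illustration
    model_class=MambaLMHeadModel,  # assumed keyword per this PR's proposal
)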

Multi-modal model example:

from outlines.processors import RegexLogitsProcessor
from outlines.models.transformers import TransformerTokenizer

from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration, LogitsProcessorList, AutoTokenizer


model_uri = "llava-hf/llava-1.5-7b-hf"

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
output_pattern = r"This is like, totally an image of .*"


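# Load the LLaVA model in 4-bit along with its multi-modal processor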
model = LlavaForConditionalGeneration.from_pretrained(model_uri, load_in_4bit=True)
llava_processor = AutoProcessor.from_pretrained(model_uri)
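# Constrain generation to output_pattern via Outlines' regex logits processor,
# built over the model's HF tokenizer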
regex_logits_processor = RegexLogitsProcessor(
    output_pattern,
    TransformerTokenizer(AutoTokenizer.from_pretrained(model_uri)),
)

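# Build the multi-modal inputs: the text prompt plus the fetched image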
inputs = llava_processor(
    text=prompt,
    images=Image.open(requests.get(url, stream=True).raw),
    return_tensors="pt"
)

# Generate, applying the regex logits processor at each decoding step
generate_ids = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([
        regex_logits_processor
    ]),
    max_new_tokens=30
)

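# Decode the generated token ids back to text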
result = llava_processor.batch_decode(
    generate_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)[0]
print(result)
# USER:\nWhat's the content of the image? ASSISTANT:This is like, totally an image of a stop sign on a street.
