Skip to content

Commit

Permalink
Introduce outlines.models.transformers_multimodal
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Jul 19, 2024
1 parent f6a6c29 commit 9cc0775
Show file tree
Hide file tree
Showing 10 changed files with 605 additions and 79 deletions.
93 changes: 93 additions & 0 deletions docs/reference/models/multimodal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Transformers MultiModal

Outlines allows seamless use of [multimodal models](https://huggingface.co/learn/computer-vision-course/en/unit4/multimodal-models/tasks-models-part1).

Tasks supported include
- image + text -> text
- video + text -> text
- TODO: look into other models which can be used with no code changes


## Example: Using [Llava-Next](https://huggingface.co/docs/transformers/en/model_doc/llava_next) Vision Models

Install dependencies
`pip install torchvision pillow flash-attn`

Create the model
```python
import outlines

model = outlines.models.transformers_multimodal(
"llava-hf/llava-v1.6-mistral-7b-hf",
device="cuda",
)
```

Create convenience function to load a `PIL.Image` from URL
```python
from PIL import Image
from io import BytesIO
from urllib.request import urlopen
def img_from_url(url):
img_byte_stream = BytesIO(urlopen(url).read())
return Image.open(img_byte_stream).convert("RGB")
```

### Describing an image

```python
description_generator = outlines.generate.text(model)
description_generator(
"<image> detailed description:",
[img_from_url("https://upload.wikimedia.org/wikipedia/commons/2/25/Siam_lilacpoint.jpg")]
)
```

> This is a color photograph featuring a Siamese cat with striking blue eyes. The cat has a creamy coat and a light eye color, which is typical for the Siamese breed. Its features include elongated ears, a long, thin tail, and a striking coat pattern. The cat is sitting in an indoor setting, possibly on a cat tower or a similar raised platform, which is covered with a beige fabric, providing a comfortable and soft surface for the cat to rest or perch. The surface of the wall behind the cat appears to be a light-colored stucco or plaster.

### Classifying an Image

```python
pattern = "Mercury|Venus|Earth|Mars|Saturn|Jupiter|Neptune|Uranus|Pluto"
planet_generator = outlines.generate.regex(model, pattern)

planet_generator(
"What planet is this: <image>",
[img_from_url("https://upload.wikimedia.org/wikipedia/commons/e/e3/Saturn_from_Cassini_Orbiter_%282004-10-06%29.jpg")]
)
```

> Saturn

### Extracting Structured Image data

```python
from pydantic import BaseModel
from typing import List, Optional

# reuses the img_from_url helper defined earlier

class ImageData(BaseModel):
caption: str
tags_list: List[str]
object_list: List[str]
is_photo: bool

image_data_generator = outlines.generate.json(model, ImageData)

image_data_generator(
"<image> detailed JSON metadata:",
[img_from_url("https://upload.wikimedia.org/wikipedia/commons/9/98/Aldrin_Apollo_11_original.jpg")]
)
```

> `ImageData(caption='An astronaut on the moon', tags_list=['moon', 'space', 'nasa', 'americanflag'], object_list=['moon', 'moon_surface', 'space_suit', 'americanflag'], is_photo=True)`

## Resources

### Choosing a model
- https://mmbench.opencompass.org.cn/leaderboard
- https://huggingface.co/spaces/WildVision/vision-arena
109 changes: 100 additions & 9 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union

from outlines.generate.generator import sequence_generator
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
Expand Down Expand Up @@ -479,6 +479,13 @@ def format_sequence(self, sequence: str) -> FormattedOutput:
"""
return sequence

def _format(self, sequences):
"""Apply formatting to every string in a completion."""
if isinstance(sequences, list):
return [self._format(sequence) for sequence in sequences]
else:
return self.format_sequence(sequences)

def __call__(
self,
prompts: Union[str, List[str]],
Expand All @@ -489,13 +496,6 @@ def __call__(
):
"""Generate text from a prompt of list of prompts."""

def format(sequences):
"""Apply formatting to every string in a completion."""
if isinstance(sequences, list):
return [format(sequence) for sequence in sequences]
else:
return self.format_sequence(sequences)

generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)
Expand All @@ -508,7 +508,7 @@ def format(sequences):
**model_specific_params,
)

return format(completions)
return self._format(completions)

def stream(
self,
Expand All @@ -529,3 +529,94 @@ def stream(
self.sampling_params,
**model_specific_params,
)


class MultiModalSequenceGeneratorAdapter(SequenceGeneratorAdapter):
def __call__( # type: ignore
self,
prompts: Union[str, List[str]],
media: Union[str, Any],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
seed: Optional[int] = None,
**model_specific_params,
):
"""
Generate text from a prompt of list of prompts.
Media: A URI to construct media or media object itself. Used as AutoProcessor argument.
"""
prompts, media = self._prepare_prompts_and_media(prompts, media)

generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)

completions = self.model.generate(
prompts,
media,
generation_params,
self.logits_processor,
self.sampling_params,
**model_specific_params,
)

return self._format(completions)

def stream( # type: ignore
self,
prompts: Union[str, List[str]],
media: List[Union[str, Any, List[Union[str, Any]]]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
seed: Optional[int] = None,
**model_specific_params,
):
"""Return a text generator from a prompt or a list of prompts."""
prompts, media = self._prepare_prompts_and_media(prompts, media)
generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)
return self.model.stream(
prompts,
media,
generation_params,
self.logits_processor,
self.sampling_params,
**model_specific_params,
)

@classmethod
def _prepare_prompts_and_media(
cls,
prompts: Union[str, List[str]],
media: Union[str, Any, List[Union[str, Any]]],
) -> Union[Any, List[Any]]:
"""
Prepare media as PIL.Image and ensure for every prompt str there is one List[PIL.Image]
"""

def valid_types(prompts, media):
from PIL import Image # type: ignore

if isinstance(prompts, list):
if not isinstance(media, list):
return False
for subprompt, submedia in zip(prompts, media):
if not isinstance(subprompt, str) or not all(
isinstance(m, Image.Image) for m in submedia
):
return False
elif isinstance(prompts, str):
if not all(isinstance(m, Image.Image) for m in media):
return False
return True

if not valid_types(prompts, media):
raise TypeError(
"Expected (prompts, media) to be of type "
"(str, List[Image])), or (List[str], List[List[Image]]) "
f"instead got prompts={prompts}, media={media}"
)

return prompts, media
17 changes: 15 additions & 2 deletions outlines/generate/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
import interegular

from outlines.fsm.guide import RegexGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, LlamaCpp, Transformers
from outlines.generate.api import (
MultiModalSequenceGeneratorAdapter,
SequenceGenerator,
SequenceGeneratorAdapter,
)
from outlines.models import MLXLM, LlamaCpp, Transformers, TransformersMultiModal
from outlines.samplers import Sampler, multinomial


Expand All @@ -29,3 +33,12 @@ def fsm_unified(
fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@fsm.register(TransformersMultiModal)
def fsm_multimodal(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
from outlines.processors import FSMLogitsProcessor

fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm)
return MultiModalSequenceGeneratorAdapter(model, logits_processor, sampler)
31 changes: 25 additions & 6 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from functools import singledispatch

from outlines.fsm.guide import RegexGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.transformers import Transformers
from outlines.models.vllm import VLLM
from outlines.generate.api import (
MultiModalSequenceGeneratorAdapter,
SequenceGenerator,
SequenceGeneratorAdapter,
)
from outlines.models import (
MLXLM,
VLLM,
LlamaCpp,
OpenAI,
Transformers,
TransformersMultiModal,
)
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -53,6 +60,18 @@ def regex_unified(
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(TransformersMultiModal)
def regex_multimodal(
model,
regex_str: str,
sampler: Sampler = multinomial(),
):
from outlines.processors import RegexLogitsProcessor

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return MultiModalSequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(VLLM)
def regex_vllm(
model: VLLM,
Expand Down
20 changes: 18 additions & 2 deletions outlines/generate/text.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from functools import singledispatch

from outlines.fsm.guide import StopAtEOSGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers
from outlines.generate.api import (
MultiModalSequenceGeneratorAdapter,
SequenceGenerator,
SequenceGeneratorAdapter,
)
from outlines.models import (
MLXLM,
VLLM,
LlamaCpp,
OpenAI,
Transformers,
TransformersMultiModal,
)
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -43,6 +54,11 @@ def text_unified(model, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)


@text.register(TransformersMultiModal)
def text_multimodal(model, sampler: Sampler = multinomial()):
return MultiModalSequenceGeneratorAdapter(model, None, sampler)


@text.register(VLLM)
def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)
Expand Down
2 changes: 2 additions & 0 deletions outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
codebase.
"""

from typing import Union

from .exllamav2 import ExLlamaV2Model, exl2
from .llamacpp import LlamaCpp, llamacpp
from .mlxlm import MLXLM, mlxlm
from .openai import OpenAI, azure_openai, openai
from .transformers import Transformers, TransformerTokenizer, mamba, transformers
from .transformers_multimodal import TransformersMultiModal, transformers_multimodal
from .vllm import VLLM, vllm

LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, MLXLM, VLLM]
Loading

0 comments on commit 9cc0775

Please sign in to comment.