Skip to content

Commit

Permalink
[V1] VLM - preprocessor hashing
Browse files Browse the repository at this point in the history
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Alexander Matveev <[email protected]>
  • Loading branch information
3 people committed Dec 11, 2024
1 parent cad5c0a commit adda9d4
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 48 deletions.
126 changes: 109 additions & 17 deletions examples/offline_inference_vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import random

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
Expand All @@ -23,7 +25,9 @@ def run_llava(question: str, modality: str):

prompt = f"USER: <image>\n{question}\nASSISTANT:"

llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand All @@ -33,7 +37,9 @@ def run_llava_next(question: str, modality: str):
assert modality == "image"

prompt = f"[INST] <image>\n{question} [/INST]"
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand All @@ -44,7 +50,9 @@ def run_llava_next_video(question: str, modality: str):
assert modality == "video"

prompt = f"USER: <video>\n{question} ASSISTANT:"
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand All @@ -61,7 +69,8 @@ def run_llava_onevision(question: str, modality: str):
<|im_start|>assistant\n"

llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=16384)
max_model_len=16384,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand All @@ -71,7 +80,10 @@ def run_fuyu(question: str, modality: str):
assert modality == "image"

prompt = f"{question}\n"
llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
llm = LLM(model="adept/fuyu-8b",
max_model_len=2048,
max_num_seqs=2,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand Down Expand Up @@ -107,6 +119,7 @@ def run_phi3v(question: str, modality: str):
max_num_seqs=2,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"num_crops": 16},
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
stop_token_ids = None
return llm, prompt, stop_token_ids
Expand All @@ -118,7 +131,8 @@ def run_paligemma(question: str, modality: str):

# PaliGemma has special prompt format for VQA
prompt = "caption en"
llm = LLM(model="google/paligemma-3b-mix-224")
llm = LLM(model="google/paligemma-3b-mix-224",
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand All @@ -128,7 +142,9 @@ def run_chameleon(question: str, modality: str):
assert modality == "image"

prompt = f"{question}<image>"
llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)
llm = LLM(model="facebook/chameleon-7b",
max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand All @@ -154,6 +170,7 @@ def run_minicpmv(question: str, modality: str):
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
# 2.0
Expand Down Expand Up @@ -186,6 +203,7 @@ def run_h2ovl(question: str, modality: str):
model=model_name,
trust_remote_code=True,
max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
Expand All @@ -211,6 +229,7 @@ def run_internvl(question: str, modality: str):
model=model_name,
trust_remote_code=True,
max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
Expand Down Expand Up @@ -241,6 +260,7 @@ def run_nvlm_d(question: str, modality: str):
trust_remote_code=True,
max_model_len=4096,
tensor_parallel_size=4,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
Expand All @@ -260,7 +280,8 @@ def run_blip2(question: str, modality: str):
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt = f"Question: {question} Answer:"
llm = LLM(model="Salesforce/blip2-opt-2.7b")
llm = LLM(model="Salesforce/blip2-opt-2.7b",
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand All @@ -274,6 +295,7 @@ def run_qwen_vl(question: str, modality: str):
trust_remote_code=True,
max_model_len=1024,
max_num_seqs=2,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)

prompt = f"{question}Picture 1: <img></img>\n"
Expand All @@ -296,6 +318,7 @@ def run_qwen2_vl(question: str, modality: str):
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
},
mm_cache_preprocessor=args.mm_cache_preprocessor,
)

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
Expand All @@ -315,6 +338,7 @@ def run_pixtral_hf(question: str, modality: str):
llm = LLM(
model=model_name,
max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)

prompt = f"<s>[INST]{question}\n[IMG][/INST]"
Expand All @@ -338,6 +362,7 @@ def run_mllama(question: str, modality: str):
max_model_len=4096,
max_num_seqs=16,
enforce_eager=True,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)

prompt = f"<|image|><|begin_of_text|>{question}"
Expand All @@ -355,6 +380,7 @@ def run_molmo(question, modality):
model=model_name,
trust_remote_code=True,
dtype="bfloat16",
mm_cache_preprocessor=args.mm_cache_preprocessor,
)

prompt = question
Expand All @@ -371,7 +397,8 @@ def run_glm4v(question: str, modality: str):
max_model_len=2048,
max_num_seqs=2,
trust_remote_code=True,
enforce_eager=True)
enforce_eager=True,
mm_cache_preprocessor=args.mm_cache_preprocessor)
prompt = question
stop_token_ids = [151329, 151336, 151338]
return llm, prompt, stop_token_ids
Expand All @@ -394,6 +421,7 @@ def run_idefics3(question: str, modality: str):
"longest_edge": 3 * 364
},
},
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
prompt = (
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
Expand All @@ -410,7 +438,8 @@ def run_aria(question: str, modality: str):
llm = LLM(model=model_name,
tokenizer_mode="slow",
trust_remote_code=True,
dtype="bfloat16")
dtype="bfloat16",
mm_cache_preprocessor=args.mm_cache_preprocessor)

prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
"<|im_end|>\n<|im_start|>assistant\n")
Expand All @@ -430,6 +459,7 @@ def run_mantis(question: str, modality: str):
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
stop_token_ids = [128009]
return llm, prompt, stop_token_ids
Expand Down Expand Up @@ -494,6 +524,35 @@ def get_multi_modal_input(args):
raise ValueError(msg)


def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
"""Repeats images with provided probability of "image_repeat_prob".
Used to simulate hit/miss for the MM preprocessor cache.
"""
assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
no_yes = [0, 1]
probs = [1.0 - image_repeat_prob, image_repeat_prob]

inputs = []
cur_image = data
for i in range(num_prompts):
if image_repeat_prob is not None:
res = random.choices(no_yes, probs)[0]
if res == 0:
# No repeat => Modify one pixel
cur_image = cur_image.copy()
new_val = (i // 256 // 256, i // 256, i % 256)
cur_image.putpixel((0, 0), new_val)

inputs.append({
"prompt": prompt,
"multi_modal_data": {
modality: cur_image
}
})

return inputs


def main(args):
model = args.model_type
if model not in model_example_map:
Expand Down Expand Up @@ -524,14 +583,29 @@ def main(args):

else:
# Batch inference
inputs = [{
"prompt": prompt,
"multi_modal_data": {
modality: data
},
} for _ in range(args.num_prompts)]
if args.image_repeat_prob is not None:
# Repeat images with specified probability of "image_repeat_prob"
inputs = apply_image_repeat(args.image_repeat_prob,
args.num_prompts, data, prompt,
modality)
else:
# Use the same image for all prompts
inputs = [{
"prompt": prompt,
"multi_modal_data": {
modality: data
},
} for _ in range(args.num_prompts)]

if args.time_generate:
import time
start_time = time.time()
outputs = llm.generate(inputs, sampling_params=sampling_params)
elapsed_time = time.time() - start_time
print("-- generate time = {}".format(elapsed_time))

outputs = llm.generate(inputs, sampling_params=sampling_params)
else:
outputs = llm.generate(inputs, sampling_params=sampling_params)

for o in outputs:
generated_text = o.outputs[0].text
Expand Down Expand Up @@ -561,5 +635,23 @@ def main(args):
type=int,
default=16,
help='Number of frames to extract from the video.')

parser.add_argument(
'--image-repeat-prob',
type=float,
default=None,
help='Simulates the hit-ratio for multi-modal preprocessor cache'
' (if enabled)')

parser.add_argument(
'--mm-cache-preprocessor',
action='store_true',
help='If True, enable caching of multi-modal preprocessor/mapper.')

parser.add_argument(
'--time-generate',
action='store_true',
help='If True, then print the total generate() call time')

args = parser.parse_args()
main(args)
1 change: 1 addition & 0 deletions requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0
requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
tokenizers >= 0.19.1 # Required for Llama 3.
Expand Down
10 changes: 8 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ class ModelConfig:
HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
mm_cache_preprocessor: If true, then enables caching of the multi-modal
preprocessor/mapper. Otherwise, the mapper executes each time, and
for better performance consider enabling frontend process.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
Expand Down Expand Up @@ -185,6 +188,7 @@ def __init__(
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_cache_preprocessor: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None) -> None:
self.model = model
Expand Down Expand Up @@ -251,6 +255,7 @@ def __init__(
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
self.mm_cache_preprocessor = mm_cache_preprocessor

# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
Expand Down Expand Up @@ -2684,9 +2689,10 @@ def __str__(self):
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"use_async_output_proc={self.model_config.use_async_output_proc}, "
f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, " # noqa
f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
f"pooler_config={self.model_config.pooler_config!r},"
f" compilation_config={self.compilation_config!r}")
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}")


_current_vllm_config: Optional[VllmConfig] = None
Expand Down
8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class EngineArgs:
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
mm_cache_preprocessor: bool = False
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
Expand Down Expand Up @@ -593,6 +594,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=json.loads,
help=('Overrides for the multimodal input mapping/processing, '
'e.g., image processor. For example: {"num_crops": 4}.'))
parser.add_argument(
'--mm-cache-preprocessor',
action='store_true',
help='If true, then enables caching of the multi-modal '
'preprocessor/mapper. Otherwise, the mapper executes each time'
', and for better performance consider enabling frontend process.')

# LoRA related configs
parser.add_argument('--enable-lora',
Expand Down Expand Up @@ -965,6 +972,7 @@ def create_model_config(self) -> ModelConfig:
use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_cache_preprocessor=self.mm_cache_preprocessor,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
)
Expand Down
3 changes: 2 additions & 1 deletion vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class EngineCoreRequest:
# always be tokenized?
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[MultiModalKwargs]]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
mm_hashes: Optional[List[Optional[str]]]
mm_placeholders: Optional[MultiModalPlaceholderDict]
sampling_params: SamplingParams
eos_token_id: Optional[int]
Expand Down
Loading

0 comments on commit adda9d4

Please sign in to comment.