[V1] VLM - preprocessor hashing
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
3 people committed Dec 9, 2024
1 parent d1c2e15 commit 3554439
Showing 8 changed files with 220 additions and 32 deletions.
56 changes: 48 additions & 8 deletions examples/offline_inference_vision_language.py
@@ -5,6 +5,8 @@
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import random

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
@@ -23,7 +25,11 @@ def run_llava(question: str, modality: str):

    prompt = f"USER: <image>\n{question}\nASSISTANT:"

-    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
+    llm = LLM(
+        model="llava-hf/llava-1.5-7b-hf",
+        max_model_len=4096,
+        # TODO: Fix this!
+        mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -524,14 +530,35 @@ def main(args):

    else:
        # Batch inference
-        inputs = [{
-            "prompt": prompt,
-            "multi_modal_data": {
-                modality: data
-            },
-        } for _ in range(args.num_prompts)]
+        if args.image_repeat_ratio is not None:
+            assert (args.image_repeat_ratio <= 1.0
+                    and args.image_repeat_ratio >= 0)
+            no_yes = [0, 1]
+            probs = [1.0 - args.image_repeat_ratio,
+                     args.image_repeat_ratio]
+
+        inputs = []
+        cur_image = data
+        for i in range(args.num_prompts):
+            if args.image_repeat_ratio is not None:
+                res = random.choices(no_yes, probs)[0]
+                if res == 0:
+                    # No repeat => modify one pixel so the image (and its
+                    # hash) differs, forcing a preprocessor-cache miss.
+                    cur_image = cur_image.copy()
+                    new_val = (i // 256 // 256, i // 256, i % 256)
+                    cur_image.putpixel((0, 0), new_val)
+
+            inputs.append({
+                "prompt": prompt,
+                "multi_modal_data": {
+                    modality: cur_image
+                }
+            })
+
+    import time
+    start_time = time.time()
    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    elapsed_time = time.time() - start_time
+    print("-- generate time = {}".format(elapsed_time))

    for o in outputs:
        generated_text = o.outputs[0].text
@@ -561,5 +588,18 @@

        type=int,
        default=16,
        help='Number of frames to extract from the video.')

+    parser.add_argument(
+        '--image-repeat-ratio',
+        type=float,
+        default=None,
+        help='Simulates the hit-ratio for multi-modal preprocessor cache'
+        ' (if enabled)')
+
+    parser.add_argument(
+        '--mm-cache-preprocessor',
+        action='store_true',
+        help='If True, enable caching of multi-modal preprocessor/mapper.')
+
    args = parser.parse_args()
    main(args)
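
A note on the repeat-ratio mechanism above: to simulate a target cache hit rate, the script reuses the previous image with probability image_repeat_ratio and otherwise rewrites a single pixel, which changes the image bytes and therefore its hash. A minimal standalone sketch of that idea, assuming RGB PIL images (the function name and signature are illustrative, not part of the commit):

    import random

    from PIL import Image


    def simulate_images(base: Image.Image, n: int, repeat_ratio: float):
        """Yield n images, of which roughly repeat_ratio are exact repeats."""
        cur = base
        for i in range(n):
            if random.random() >= repeat_ratio:
                # Simulated cache miss: copy the image and encode the loop
                # index into one pixel so each "new" image hashes differently.
                cur = cur.copy()
                cur.putpixel((0, 0), (i // 256 // 256, i // 256, i % 256))
            # Otherwise reuse the previous image, which hashes identically
            # and should hit the preprocessor cache.
            yield cur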
9 changes: 7 additions & 2 deletions vllm/config.py
@@ -133,6 +133,8 @@ class ModelConfig:

            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor.
+        mm_cache_preprocessor: If True, enable caching of multi-modal
+            preprocessor/mapper.
        override_neuron_config: Initialize non-default neuron config or
            override default neuron config that are specific to Neuron devices;
            this argument will be used to configure the neuron config that
@@ -171,6 +173,7 @@ def __init__(
        config_format: ConfigFormat = ConfigFormat.AUTO,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        mm_cache_preprocessor: bool = False,
        override_neuron_config: Optional[Dict[str, Any]] = None,
        override_pooler_config: Optional["PoolerConfig"] = None) -> None:
        self.model = model
@@ -237,6 +240,7 @@ def __init__(
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.use_async_output_proc = use_async_output_proc
        self.mm_processor_kwargs = mm_processor_kwargs
+        self.mm_cache_preprocessor = mm_cache_preprocessor

        # Set enforce_eager to False if the value is unset.
        if self.enforce_eager is None:
@@ -2610,9 +2614,10 @@ def __str__(self):
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"use_async_output_proc={self.model_config.use_async_output_proc}, "
f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, " # noqa
f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
f"pooler_config={self.model_config.pooler_config!r},"
f" compilation_config={self.compilation_config!r}")
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}")


_current_vllm_config: Optional[VllmConfig] = None
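
Since mm_cache_preprocessor defaults to False, the cache is strictly opt-in. A small usage sketch (the model name is illustrative; keyword arguments to LLM are forwarded to EngineArgs and from there into ModelConfig as shown above):

    from vllm import LLM

    # Opt in to caching of the multi-modal preprocessor/mapper.
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        mm_cache_preprocessor=True,
    )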
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -143,6 +143,7 @@ class EngineArgs:

    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
+    mm_cache_preprocessor: bool = False
    enable_lora: bool = False
    enable_lora_bias: bool = False
    max_loras: int = 1
@@ -590,6 +591,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

            type=json.loads,
            help=('Overrides for the multimodal input mapping/processing, '
                  'e.g., image processor. For example: {"num_crops": 4}.'))
+        parser.add_argument(
+            '--mm-cache-preprocessor',
+            action='store_true',
+            help='If True, enable caching of multi-modal preprocessor/mapper.')

        # LoRA related configs
        parser.add_argument('--enable-lora',
@@ -962,6 +967,7 @@ def create_model_config(self) -> ModelConfig:

            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
+            mm_cache_preprocessor=self.mm_cache_preprocessor,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
        )
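
So the knob is reachable both from the CLI (--mm-cache-preprocessor) and programmatically. A quick sketch of the propagation path this hunk adds (create_model_config is the method touched above; the assertion is illustrative):

    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model="llava-hf/llava-1.5-7b-hf",
        mm_cache_preprocessor=True,
    )
    model_config = engine_args.create_model_config()
    assert model_config.mm_cache_preprocessor is True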
3 changes: 2 additions & 1 deletion vllm/v1/engine/__init__.py
@@ -35,7 +35,8 @@ class EngineCoreRequest:

    # always be tokenized?
    prompt: Optional[str]
    prompt_token_ids: List[int]
-    mm_inputs: Optional[List[MultiModalKwargs]]
+    mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
+    mm_hashes: Optional[List[Optional[str]]]
    mm_placeholders: Optional[MultiModalPlaceholderDict]
    sampling_params: SamplingParams
    eos_token_id: Optional[int]
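
The two new parallel lists encode the client/server caching protocol: for each multi-modal item, the client sends either the processed kwargs (a miss) or None plus a hash the server can resolve from its own cache (a hit). A toy illustration of the invariant, using plain dicts and hypothetical digests in place of real MultiModalKwargs and blake3 output:

    from typing import Dict, List, Optional

    # Three images; the first two are identical, and the second arrival hit
    # the client-side cache, so only its hash travels to the server.
    kwargs_a: Dict[str, str] = {"pixel_values": "<tensor>"}
    kwargs_b: Dict[str, str] = {"pixel_values": "<tensor>"}

    mm_inputs: List[Optional[Dict[str, str]]] = [kwargs_a, None, kwargs_b]
    mm_hashes: List[Optional[str]] = ["9f3b...", "9f3b...", "71ac..."]

    # Invariants maintained by MMInputMapperClient.process_inputs:
    assert len(mm_inputs) == len(mm_hashes)
    assert all(h is not None
               for i, h in zip(mm_inputs, mm_hashes) if i is None)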
14 changes: 10 additions & 4 deletions vllm/v1/engine/core.py
@@ -19,7 +19,7 @@

from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                            EngineCoreProfile, EngineCoreRequest,
                            EngineCoreRequestType)
-from vllm.v1.engine.mm_input_mapper import MMInputMapper
+from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
@@ -55,16 +55,15 @@ def __init__(

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

-        # Set up multimodal input mapper (e.g., convert PIL images to tensors).
-        self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
-
        # Setup scheduler.
        self.scheduler = Scheduler(vllm_config.scheduler_config,
                                   vllm_config.cache_config,
                                   vllm_config.lora_config)

        self._last_logging_time = time.time()

+        self.mm_input_mapper_server = MMInputMapperServer()

    def _initialize_kv_caches(self,
                              cache_config: CacheConfig) -> Tuple[int, int]:
        start = time.time()
@@ -88,7 +87,14 @@ def _initialize_kv_caches(self,

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""

+        # Re-materialize multi-modal inputs that the client replaced with
+        # None (cache hits) by looking up their hashes in the server-side
+        # cache; fresh inputs are inserted into the cache along the way.
+        if request.mm_hashes is not None:
+            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
+                request.mm_inputs, request.mm_hashes)
+
        req = Request.from_engine_core_request(request)

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: List[str]):
119 changes: 109 additions & 10 deletions vllm/v1/engine/mm_input_mapper.py
@@ -1,11 +1,18 @@
from typing import Any, Dict, List, Optional, Tuple

+import PIL
+from blake3 import blake3
+
from vllm.config import ModelConfig
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalKwargs, MultiModalRegistry)
+from vllm.v1.utils import LRUDictCache
+
+# Both Client and Server must use the same cache size
+MM_CACHE_SIZE = 128


-class MMInputMapper:
+class MMInputMapperClient:

    def __init__(
        self,
@@ -18,23 +25,115 @@ def __init__(

            model_config)
        self.mm_registry.init_mm_limits_per_prompt(model_config)

+        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+
+        # Set to None to disable (TODO: Disable!)
+        self.mm_debug_cache_hit_ratio_steps = 32
+        self.mm_cache_hits = 0
+        self.mm_cache_misses = 0
+
+    def cache_hit_ratio(self, steps) -> None:
+        total_steps = self.mm_cache_hits + self.mm_cache_misses
+
+        if total_steps > 0 and total_steps % steps == 0:
+            print("[debug] MMInputMapper: cache_hit_ratio = {}".format(
+                self.mm_cache_hits / total_steps))

    def process_inputs(
        self,
        mm_data: MultiModalDataDict,
+        mm_hashes: Optional[List[str]],
        mm_processor_kwargs: Optional[Dict[str, Any]],
-    ) -> List[MultiModalKwargs]:
+    ) -> Tuple[List[Optional[MultiModalKwargs]], Optional[List[str]]]:
        image_inputs = mm_data["image"]
        if not isinstance(image_inputs, list):
            image_inputs = [image_inputs]

+        use_hash = mm_hashes is not None
+        if use_hash:
+            assert len(image_inputs) == len(mm_hashes)  # Sanity check

        # Process each image input separately so that later we can schedule
        # them in a fine-grained manner.
-        mm_inputs: List[MultiModalKwargs] = []
-        num_images = len(image_inputs)
-        for i in range(num_images):
-            mm_input = self.multi_modal_input_mapper(
-                {"image": image_inputs[i]},
-                mm_processor_kwargs=mm_processor_kwargs,
-            )
-            mm_inputs.append(mm_input)
-        return mm_inputs
+        # Utilize caching (if enabled)
+        ret_hashes = [] if use_hash else None
+        ret_inputs: List[Optional[MultiModalKwargs]] = []
+        for i in range(len(image_inputs)):
+            if self.mm_debug_cache_hit_ratio_steps is not None:
+                self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
+
+            if use_hash:
+                mm_hash = mm_hashes[i]
+                mm_input = self.mm_cache.get(mm_hash)
+            else:
+                mm_hash = None
+                mm_input = None
+
+            if mm_input is None:
+                self.mm_cache_misses += 1
+                mm_input = self.multi_modal_input_mapper(
+                    {"image": [image_inputs[i]]},
+                    mm_processor_kwargs=mm_processor_kwargs,
+                )
+
+                if use_hash:
+                    self.mm_cache.put(mm_hash, mm_input)
+            else:
+                self.mm_cache_hits += 1
+                mm_input = None  # Avoids sending mm_input to the server
+
+            if use_hash:
+                ret_hashes.append(mm_hash)
+            ret_inputs.append(mm_input)
+
+        return ret_inputs, ret_hashes


+class MMInputMapperServer:
+
+    def __init__(self):
+        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+
+    def process_inputs(
+        self,
+        mm_inputs: List[Optional[MultiModalKwargs]],
+        mm_hashes: List[Optional[str]],
+    ) -> List[MultiModalKwargs]:
+        assert len(mm_inputs) == len(mm_hashes)
+
+        full_mm_inputs = []
+        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            if mm_input is None:
+                # Client-side cache hit: recover the input from the
+                # mirrored server-side cache.
+                mm_input = self.mm_cache.get(mm_hash)
+                assert mm_input is not None
+            else:
+                self.mm_cache.put(mm_hash, mm_input)
+
+            full_mm_inputs.append(mm_input)
+
+        return full_mm_inputs


+class MMHasher:
+
+    def __init__(self):
+        pass
+
+    def hash(self, mm_data: MultiModalDataDict) -> List[str]:
+        image_inputs = mm_data["image"]
+        if not isinstance(image_inputs, list):
+            image_inputs = [image_inputs]
+
+        ret = []
+        for image in image_inputs:
+            assert isinstance(image, PIL.Image.Image)
+
+            # Convert the image to raw bytes
+            data = image.tobytes()
+
+            # Hash the image bytes
+            hasher = blake3()
+            hasher.update(data)
+            ret.append(hasher.hexdigest())
+
+        return ret
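
LRUDictCache comes from vllm.v1.utils, whose diff did not render above. Its implementation is therefore not shown here, but given the get/put calls and the fixed MM_CACHE_SIZE, a minimal OrderedDict-based sketch that would satisfy this usage might look like the following (an assumption, not the commit's actual code):

    from collections import OrderedDict
    from typing import Any, Optional


    class LRUDictCache:
        """Sketch of the get/put LRU cache assumed by the mapper above."""

        def __init__(self, size: int):
            self.cache: OrderedDict = OrderedDict()
            self.size = size

        def get(self, key: str) -> Optional[Any]:
            if key not in self.cache:
                return None
            # Mark the entry as most recently used.
            self.cache.move_to_end(key)
            return self.cache[key]

        def put(self, key: str, value: Any) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            if len(self.cache) > self.size:
                # Evict the least recently used entry.
                self.cache.popitem(last=False)

Because both ends evict independently in LRU order over the same stream of hashes, using the same MM_CACHE_SIZE on client and server keeps the two caches consistent; that is what the "Both Client and Server must use the same cache size" comment above is guarding.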