[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Dec 1, 2024
1 parent 7e4bbda commit 1337071
Showing 32 changed files with 387 additions and 323 deletions.
4 changes: 2 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -334,7 +334,6 @@ steps:
commands:
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model
- pytest -v -s models/embedding/vision_language -m core_model

- label: Language Models Test (Extended) # 50min
optional: true
@@ -346,7 +345,6 @@
commands:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'

- label: Multi-Modal Models Test (Standard) # 26min
#mirror_hardwares: [amd]
@@ -359,6 +357,7 @@
commands:
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model

@@ -376,6 +375,7 @@ steps:
# https://github.com/huggingface/transformers/issues/34307
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- pytest -v -s models/encoder_decoder/language -m 'not core_model'
- pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'

15 changes: 14 additions & 1 deletion docs/source/models/supported_models.rst
@@ -357,7 +357,7 @@ Text Embedding
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- ✅︎
- ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -378,6 +378,10 @@ Text Embedding
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.

.. note::
:code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.

.. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
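
Not part of the diff: a minimal offline-inference sketch of the two overrides documented above, assuming the `LLM` entrypoint forwards `override_pooler_config` (and `hf_overrides`) to the engine the same way the updated embedding test in this commit passes them to the test runner; the prompt and the printed field are illustrative only.

from vllm import LLM
from vllm.config import PoolerConfig

# ssmits/Qwen2-7B-Instruct-embed-base ships a broken Sentence Transformers
# config, so mean pooling is set explicitly (assumed kwarg passthrough).
llm = LLM(
    model="ssmits/Qwen2-7B-Instruct-embed-base",
    task="embedding",
    override_pooler_config=PoolerConfig(pooling_type="MEAN"),
)
outputs = llm.encode(["Hello, my name is"])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality (assumed output shape)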
@@ -397,12 +401,21 @@ Reward Modeling
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`LlamaForCausalLM`
- Llama-based
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2ForRewardModel`
- Qwen2-based
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
- ✅︎
- ✅︎

.. important::
For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.

.. note::
    As an interim measure, these models are supported in both offline and online inference via the Embeddings API.
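
Not part of the diff: a hedged offline sketch of the STEP pooling configuration described above; the `step_tag_id` and `returned_token_ids` values are the documentation's placeholders, not real token ids for the model.

from vllm import LLM
from vllm.config import PoolerConfig

llm = LLM(
    model="peiyi9979/math-shepherd-mistral-7b-prm",
    task="embedding",  # reward models are served via the Embeddings API for now
    override_pooler_config=PoolerConfig(
        pooling_type="STEP",
        step_tag_id=123,                # placeholder: id of the step-tag token
        returned_token_ids=[456, 789],  # placeholder: ids of the label tokens
    ),
)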

1 change: 0 additions & 1 deletion tests/conftest.py
@@ -263,7 +263,6 @@ def __init__(
dtype: str = "half",
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_sentence_transformer: bool = False,
is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False,
5 changes: 5 additions & 0 deletions tests/models/embedding/language/test_embedding.py
@@ -4,6 +4,8 @@
"""
import pytest

from vllm.config import PoolerConfig

from ..utils import check_embeddings_close


@@ -33,6 +35,9 @@ def test_models(
dtype: str,
) -> None:
vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["override_pooler_config"] = \
PoolerConfig(pooling_type="MEAN")
if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}

31 changes: 14 additions & 17 deletions tests/models/test_registry.py
@@ -6,11 +6,8 @@
from vllm.model_executor.models import (is_embedding_model,
is_text_generation_model,
supports_multimodal)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
_EMBEDDING_MODELS,
_MULTIMODAL_MODELS,
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
ModelRegistry)
@@ -26,18 +23,18 @@ def test_registry_imports(model_arch):
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)

if model_arch in _SPECULATIVE_DECODING_MODELS:
pass # Ignore these models which do not have a unified format
else:
assert is_text_generation_model(model_cls) is (
model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS)

embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS}
assert is_embedding_model(model_cls) is (model_arch
in embedding_models)

assert supports_multimodal(model_cls) is (model_arch
in _MULTIMODAL_MODELS)
return # Ignore these models which do not have a unified format

if (model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS):
assert is_text_generation_model(model_cls)

# All vLLM models should be convertible to an embedding model
embed_model = as_embedding_model(model_cls)
assert is_embedding_model(embed_model)

if model_arch in _MULTIMODAL_MODELS:
assert supports_multimodal(model_cls)


@fork_new_process_for_each_test
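Not part of the diff: a short sketch of the conversion exercised by the updated registry test above, assuming any registered generation architecture (here LlamaForCausalLM) resolves as in that test.

from vllm.model_executor.models import ModelRegistry, is_embedding_model
from vllm.model_executor.models.adapters import as_embedding_model

# Resolve a generation model class, then wrap it with the pooling adapter.
model_cls, _ = ModelRegistry.resolve_model_cls(["LlamaForCausalLM"])
embed_cls = as_embedding_model(model_cls)
assert is_embedding_model(embed_cls)
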
@@ -1,13 +1,34 @@
from typing import List, Optional, Union
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel
from vllm.sequence import IntermediateTensors
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class MyGemma2Embedding(Gemma2EmbeddingModel):
class MyGemma2Embedding(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

self._pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
)

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
@@ -18,7 +39,7 @@ def forward(
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = super().forward(
hidden_states = self.model(
input_ids,
positions,
kv_caches,
@@ -32,3 +53,17 @@ def forward(

# Return all-zero embeddings
return torch.zeros_like(hidden_states)

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
return self.model.load_weights(weights)
3 changes: 1 addition & 2 deletions tests/test_config.py
@@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task):


@pytest.mark.parametrize(("model_id", "bad_task"), [
("facebook/opt-125m", "embedding"),
("intfloat/e5-mistral-7b-instruct", "generate"),
("Qwen/Qwen2.5-Math-RM-72B", "generate"),
])
def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"):
25 changes: 25 additions & 0 deletions vllm/config.py
@@ -370,6 +370,31 @@ def _resolve_task(
selected_task = next(iter(supported_tasks_lst))

if len(supported_tasks) > 1:
suffix_to_preferred_task: List[Tuple[str, _Task]] = [
# Hardcode the models that are exceptions
("AquilaModel", "generate"),
("ChatGLMModel", "generate"),
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embedding"),
("RewardModel", "embedding"),
("ForSequenceClassification", "embedding"),
]
info, arch = ModelRegistry.inspect_model_cls(architectures)

for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
selected_task = pref_task
break
else:
if (arch.endswith("Model")
and info.architecture.endswith("ForCausalLM")
and "embedding" in supported_tasks):
selected_task = "embedding"

logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
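Not part of the diff: a self-contained sketch of the suffix-based tie-break added to `_resolve_task` above, with a toy helper name; the real code also falls back through `ModelRegistry.inspect_model_cls` for `*Model` variants of `*ForCausalLM` architectures.

from typing import List, Set, Tuple

def pick_task(arch: str, supported_tasks: Set[str]) -> str:
    """Toy reimplementation of the suffix preference table above."""
    suffix_to_preferred_task: List[Tuple[str, str]] = [
        # Hardcoded exceptions first, then the generic suffix patterns.
        ("AquilaModel", "generate"),
        ("ChatGLMModel", "generate"),
        ("ForCausalLM", "generate"),
        ("ForConditionalGeneration", "generate"),
        ("ChatModel", "generate"),
        ("LMHeadModel", "generate"),
        ("EmbeddingModel", "embedding"),
        ("RewardModel", "embedding"),
        ("ForSequenceClassification", "embedding"),
    ]
    for suffix, pref_task in suffix_to_preferred_task:
        if arch.endswith(suffix) and pref_task in supported_tasks:
            return pref_task
    return next(iter(supported_tasks))  # same default as the surrounding code

print(pick_task("Qwen2ForRewardModel", {"generate", "embedding"}))  # -> embedding
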
16 changes: 8 additions & 8 deletions vllm/inputs/registry.py
@@ -11,8 +11,8 @@
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
resolve_mm_processor_kwargs)
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
print_warning_once, resolve_mm_processor_kwargs)

from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
@@ -136,12 +136,12 @@ class InputRegistry:
"""

def __init__(self) -> None:
self._dummy_factories_by_model_type: Dict[Type[nn.Module],
DummyDataFactory] = {}
self._dummy_encoder_factories_by_model_type: Dict[
Type[nn.Module], DummyDataFactory] = {}
self._input_processors_by_model_type: Dict[Type[nn.Module],
InputProcessor] = {}
self._dummy_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._dummy_encoder_factories_by_model_type = \
ClassRegistry[nn.Module, DummyDataFactory]()
self._input_processors_by_model_type = \
ClassRegistry[nn.Module, InputProcessor]()

def _default_dummy_data_factory(
self,
4 changes: 1 addition & 3 deletions vllm/model_executor/layers/pooler.py
@@ -60,9 +60,7 @@ def from_config_with_defaults(
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
) -> Optional["Pooler"]:
if pooler_config is None:
return None
) -> "Pooler":
return cls(
pooling_type=PoolingType[pooler_config.pooling_type]
if pooler_config.pooling_type is not None else pooling_type,
18 changes: 14 additions & 4 deletions vllm/model_executor/model_loader/loader.py
@@ -9,6 +9,7 @@
import json
import math
import os
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -97,22 +98,31 @@ def device_loading_context(module: torch.nn.Module,
logger = init_logger(__name__)


def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
def _initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
architectures: Optional[list[str]] = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)
model_class, _ = get_model_architecture(model_config,
architectures=architectures)

signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config):
return model_class(vllm_config=vllm_config, prefix=prefix)

msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly.")
logger.warning(msg)
warnings.warn(msg, DeprecationWarning, stacklevel=2)

logger.warning(
"Trying to guess the arguments for old-style model class %s",
model_class,
@@ -356,7 +366,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
# We only enable strict check for non-quantiized models
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
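Not part of the diff: a minimal sketch of the new-style constructor that the deprecation warning above asks for (keyword-only `vllm_config` and `prefix`), mirroring the MyGemma2Embedding plugin earlier in this commit; the class name is hypothetical.

import torch.nn as nn

from vllm.config import VllmConfig

class MyNewStyleModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # Build submodules from the bundled configs rather than loose arguments.
        self.model_config = vllm_config.model_config
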
18 changes: 14 additions & 4 deletions vllm/model_executor/model_loader/utils.py
@@ -1,12 +1,13 @@
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
from typing import Optional, Tuple, Type

import torch
from torch import nn

from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_embedding_model


@contextlib.contextmanager
@@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype):


def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
model_config: ModelConfig,
*,
architectures: Optional[list[str]] = None,
) -> Tuple[Type[nn.Module], str]:
if architectures is None:
architectures = getattr(model_config.hf_config, "architectures", [])

# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = [
@@ -32,7 +38,11 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

return ModelRegistry.resolve_model_cls(architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embedding":
model_cls = as_embedding_model(model_cls)

return model_cls, arch


def get_architecture_class_name(model_config: ModelConfig) -> str:
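Not part of the diff: a heavily simplified, hypothetical sketch of what the pooling adapter returned by `as_embedding_model` might look like, modeled on the MyGemma2Embedding plugin above; the real implementation in vllm/model_executor/models/adapters.py may differ.

from typing import Optional, Type

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput


def as_embedding_model_sketch(cls: Type[nn.Module]) -> Type[nn.Module]:

    class ModelForEmbedding(cls):  # type: ignore[valid-type, misc]
        """Reuses the generation model's forward/load_weights unchanged."""

        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__(vllm_config=vllm_config, prefix=prefix)
            # Attach a default pooler driven by the model's pooler config.
            self._pooler = Pooler.from_config_with_defaults(
                vllm_config.model_config.pooler_config,
                pooling_type=PoolingType.LAST,
                normalize=True,
                softmax=False,
            )

        def pooler(
            self,
            hidden_states: torch.Tensor,
            pooling_metadata: PoolingMetadata,
        ) -> Optional[PoolerOutput]:
            return self._pooler(hidden_states, pooling_metadata)

    return ModelForEmbedding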