[Model] Replace embedding models with pooling adapter #10769

Merged: 25 commits, Dec 1, 2024
6 changes: 5 additions & 1 deletion docs/source/models/supported_models.rst
@@ -357,7 +357,7 @@ Text Embedding
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
-  - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
+  - :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- ✅︎
- ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -378,6 +378,10 @@ Text Embedding
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.

+.. note::
+   :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
+   You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.

.. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
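For reference, the two notes above map one-to-one onto the offline API, as exercised by the test change below. A minimal sketch, assuming the `task="embedding"` engine argument and that `LLM` accepts `override_pooler_config` / `hf_overrides` the same way the test runner forwards them:

from vllm import LLM
from vllm.config import PoolerConfig

# Force MEAN pooling for the model whose Sentence Transformers config
# is improperly defined (CLI equivalent:
#   --override-pooler-config '{"pooling_type": "MEAN"}').
embed_llm = LLM(
    model="ssmits/Qwen2-7B-Instruct-embed-base",
    task="embedding",
    override_pooler_config=PoolerConfig(pooling_type="MEAN"),
)

# Enable bi-directional attention for the GTE variant (CLI equivalent:
#   --hf-overrides '{"is_causal": false}').
gte_llm = LLM(
    model="Alibaba-NLP/gte-Qwen2-7B-instruct",
    task="embedding",
    hf_overrides={"is_causal": False},
)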
5 changes: 5 additions & 0 deletions tests/models/embedding/language/test_embedding.py
@@ -4,6 +4,8 @@
"""
import pytest

+from vllm.config import PoolerConfig

from ..utils import check_embeddings_close


@@ -33,6 +35,9 @@ def test_models(
dtype: str,
) -> None:
    vllm_extra_kwargs = {}
+    if model == "ssmits/Qwen2-7B-Instruct-embed-base":
+        vllm_extra_kwargs["override_pooler_config"] = \
+            PoolerConfig(pooling_type="MEAN")
    if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
        vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}

16 changes: 8 additions & 8 deletions vllm/inputs/registry.py
@@ -11,8 +11,8 @@
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
-                        resolve_mm_processor_kwargs)
+from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
+                        print_warning_once, resolve_mm_processor_kwargs)

from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
@@ -136,12 +136,12 @@ class InputRegistry:
"""

def __init__(self) -> None:
-        self._dummy_factories_by_model_type: Dict[Type[nn.Module],
-                                                   DummyDataFactory] = {}
-        self._dummy_encoder_factories_by_model_type: Dict[
-            Type[nn.Module], DummyDataFactory] = {}
-        self._input_processors_by_model_type: Dict[Type[nn.Module],
-                                                   InputProcessor] = {}
+        self._dummy_factories_by_model_type = \
+            ClassRegistry[nn.Module, DummyDataFactory]()
+        self._dummy_encoder_factories_by_model_type = \
+            ClassRegistry[nn.Module, DummyDataFactory]()
+        self._input_processors_by_model_type = \
+            ClassRegistry[nn.Module, InputProcessor]()

def _default_dummy_data_factory(
self,
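The switch from plain dicts to ClassRegistry matters because the pooling adapter introduced in this PR creates model subclasses at runtime, and an exact-class dict lookup would miss registrations made on the base class. A rough sketch of the idea (hypothetical illustration; the real vllm.utils.ClassRegistry may differ):

from collections import UserDict

class ClassRegistrySketch(UserDict):
    """Class-keyed mapping whose lookup walks the MRO, so a generated
    subclass (e.g. LlamaForEmbedding) finds entries registered for its
    base class (LlamaForCausalLM)."""

    def __getitem__(self, key: type):
        for cls in key.mro():
            if cls in self.data:
                return self.data[cls]
        raise KeyError(key)

# usage: reg[Base] = factory; reg[DerivedFromBase] also returns factory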
5 changes: 4 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -9,6 +9,7 @@
import json
import math
import os
+import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -107,12 +108,14 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
# new-style model class
with set_current_vllm_config(vllm_config):
return model_class(vllm_config=vllm_config, prefix=prefix)

msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly.")
-    logger.warning(msg)
+    warnings.warn(msg, DeprecationWarning, stacklevel=2)

logger.warning(
"Trying to guess the arguments for old-style model class %s",
model_class,
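Since the notice is now raised through Python's warnings machinery instead of only being logged, out-of-tree maintainers can escalate it in their test suites. A hypothetical snippet (the model name is made up):

import warnings

from vllm import LLM

# Promote the new DeprecationWarning to an error so CI fails loudly
# while the old-style constructor still works via argument guessing.
with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    llm = LLM(model="my-org/old-style-model")  # raises if old-style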
7 changes: 6 additions & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -7,6 +7,7 @@

from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models.adapters import for_embedding


@contextlib.contextmanager
@@ -32,7 +33,11 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

-    return ModelRegistry.resolve_model_cls(architectures)
+    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
+    if model_config.task == "embedding":
+        model_cls = for_embedding(model_cls)
+
+    return model_cls, arch


def get_architecture_class_name(model_config: ModelConfig) -> str:
92 changes: 92 additions & 0 deletions vllm/model_executor/models/adapters.py
@@ -0,0 +1,92 @@
from collections.abc import Iterable
from typing import Any, TypeVar

import torch
import torch.nn as nn

from .interfaces_base import VllmModelForEmbedding, is_embedding_model

_T = TypeVar("_T", bound=type[nn.Module])


def for_embedding(cls: _T) -> _T:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if is_embedding_model(cls):
return cls

# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
PoolingType)
from vllm.model_executor.pooling_metadata import PoolingMetadata

from .utils import AutoWeightsLoader, WeightsMapper

class ModelForEmbedding(cls, VllmModelForEmbedding):

def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

# These are not used in embedding models
if hasattr(self, "lm_head"):
del self.lm_head
if hasattr(self, "logits_processor"):
del self.logits_processor

pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
)
assert pooler is not None
self._pooler = pooler

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# We have deleted this attribute, so don't load it
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))

# If `*ForCausalLM` defines `load_weights` on the inner model
# and there are no other inner modules with parameters,
# we support loading from both `*Model` and `*ForCausalLM`
if (hasattr(self, "model") and hasattr(self.model, "load_weights")
and all(name == "model" or all(False
for _ in child.parameters())
for name, child in self.named_children())):
mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = mapper.apply(weights)

self.model.load_weights(weights)
# For most other models
elif hasattr(cls, "load_weights"):
cls.load_weights(self, weights) # type: ignore
# Fallback
else:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

ModelForEmbedding.__name__ = cls.__name__ \
.removesuffix("ForCausalLM") \
.removesuffix("ForConditionalGeneration") + "ForEmbedding"

return ModelForEmbedding # type: ignore
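Combined with the get_model_architecture() change above, the adapter is applied automatically during model resolution. A quick sketch of the composition (class-level only, no engine construction needed):

from vllm.model_executor.models.adapters import for_embedding
from vllm.model_executor.models.llama import LlamaForCausalLM

# What model loading now does when model_config.task == "embedding":
# subclass the generative model, drop lm_head/logits_processor at
# __init__ time, and attach a default LAST-token pooler.
LlamaForEmbedding = for_embedding(LlamaForCausalLM)
print(LlamaForEmbedding.__name__)  # -> "LlamaForEmbedding"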
58 changes: 2 additions & 56 deletions vllm/model_executor/models/gemma2.py
@@ -30,19 +30,17 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
+from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -455,55 +453,3 @@ def load_weights(self, weights: Iterable[Tuple[str,
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)


-class Gemma2EmbeddingModel(nn.Module, SupportsPP):
-    """
-    A model that uses Gemma2 with additional embedding functionalities.
-
-    This class encapsulates the Gemma2Model and provides an interface for
-    embedding operations and customized pooling functions.
-
-    Attributes:
-        model: An instance of Gemma2Model used for forward operations.
-        _pooler: An instance of Pooler used for pooling operations.
-    """
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-
-        self.model = Gemma2Model(vllm_config=vllm_config,
-                                 prefix=maybe_prefix(prefix, "model"))
-        self._pooler = Pooler.from_config_with_defaults(
-            vllm_config.model_config.pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=True,
-            softmax=False)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        return self.model(input_ids, positions, kv_caches, attn_metadata,
-                          intermediate_tensors, inputs_embeds)
-
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-        weights = hf_to_vllm_mapper.apply(weights)
-        weights = ((name, data) for name, data in weights
-                   if not name.startswith("lm_head."))
-        self.model.load_weights(weights)
1 change: 1 addition & 0 deletions vllm/model_executor/models/llama.py
@@ -627,6 +627,7 @@ def permute(w: torch.Tensor, n_heads: int):
return name, loaded_weight


+# TODO: Remove this once reward modeling is separated from LlamaForCausalLM
class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
"""
A model that uses Llama with additional embedding functionalities.
19 changes: 1 addition & 18 deletions vllm/model_executor/models/llava_next.py
@@ -14,13 +14,11 @@
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
-from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import (CLIPVisionModel, dummy_image_for_clip,
@@ -286,7 +284,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
-        pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config

vision_feature_layer = config.vision_feature_layer
@@ -325,13 +322,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))

-        # The same model class supports both language generation and embedding
-        # because the architecture name is the same
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=True,
-            softmax=False)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

@@ -678,13 +668,6 @@ def sample(
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)

-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)