From 954dd061cb4402a32dffd492df37a1412dd50fb0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 08:55:40 +0000 Subject: [PATCH 01/24] Replace embedding models with generic adapter Signed-off-by: DarkLight1337 --- docs/source/models/supported_models.rst | 6 +- .../embedding/language/test_embedding.py | 5 + vllm/inputs/registry.py | 16 ++-- vllm/model_executor/model_loader/loader.py | 5 +- vllm/model_executor/model_loader/utils.py | 7 +- vllm/model_executor/models/adapters.py | 94 +++++++++++++++++++ vllm/model_executor/models/gemma2.py | 58 +----------- vllm/model_executor/models/llama.py | 1 + vllm/model_executor/models/llava_next.py | 19 +--- vllm/model_executor/models/phi3v.py | 19 +--- vllm/model_executor/models/qwen2.py | 28 +++--- vllm/model_executor/models/qwen2_vl.py | 18 +--- vllm/model_executor/models/registry.py | 14 ++- vllm/multimodal/base.py | 6 +- vllm/multimodal/registry.py | 5 +- vllm/utils.py | 19 +++- 16 files changed, 175 insertions(+), 145 deletions(-) create mode 100644 vllm/model_executor/models/adapters.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 7b7a83f20871b..f4cab81b3d20b 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -357,7 +357,7 @@ Text Embedding - ✅︎ * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM` - Qwen2-based - - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. + - :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. - ✅︎ - ✅︎ * - :code:`RobertaModel`, :code:`RobertaForMaskedLM` @@ -378,6 +378,10 @@ Text Embedding .. tip:: You can override the model's pooling method by passing :code:`--override-pooler-config`. +.. note:: + :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. + You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`. + .. note:: Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention. You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly. 
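For reference, the CLI overrides documented above map onto offline-inference keyword arguments, as exercised by the test change below. The following is only a rough sketch: the `task`, `override_pooler_config`, and `hf_overrides` keyword names and the `encode` call are assumed to match this version of vLLM and may differ in other releases.

    from vllm import LLM
    from vllm.config import PoolerConfig

    # Equivalent of --override-pooler-config '{"pooling_type": "MEAN"}':
    llm = LLM(
        model="ssmits/Qwen2-7B-Instruct-embed-base",
        task="embedding",
        override_pooler_config=PoolerConfig(pooling_type="MEAN"),
    )
    outputs = llm.encode(["Example sentence to embed."])

    # Equivalent of --hf-overrides '{"is_causal": false}' for gte-Qwen2:
    #   LLM(model="Alibaba-NLP/gte-Qwen2-7B-instruct", task="embedding",
    #       hf_overrides={"is_causal": False})
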
diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index 36b1e5887981c..5ef8540265d14 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -4,6 +4,8 @@ """ import pytest +from vllm.config import PoolerConfig + from ..utils import check_embeddings_close @@ -33,6 +35,9 @@ def test_models( dtype: str, ) -> None: vllm_extra_kwargs = {} + if model == "ssmits/Qwen2-7B-Instruct-embed-base": + vllm_extra_kwargs["override_pooler_config"] = \ + PoolerConfig(pooling_type="MEAN") if model == "Alibaba-NLP/gte-Qwen2-7B-instruct": vllm_extra_kwargs["hf_overrides"] = {"is_causal": False} diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 68b4756331e6d..874290bc94ed1 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -11,8 +11,8 @@ from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, - resolve_mm_processor_kwargs) +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, + print_warning_once, resolve_mm_processor_kwargs) from .data import ProcessorInputs, SingletonInputs from .parse import is_encoder_decoder_inputs @@ -136,12 +136,12 @@ class InputRegistry: """ def __init__(self) -> None: - self._dummy_factories_by_model_type: Dict[Type[nn.Module], - DummyDataFactory] = {} - self._dummy_encoder_factories_by_model_type: Dict[ - Type[nn.Module], DummyDataFactory] = {} - self._input_processors_by_model_type: Dict[Type[nn.Module], - InputProcessor] = {} + self._dummy_factories_by_model_type = \ + ClassRegistry[nn.Module,DummyDataFactory]() + self._dummy_encoder_factories_by_model_type = \ + ClassRegistry[nn.Module, DummyDataFactory]() + self._input_processors_by_model_type = \ + ClassRegistry[nn.Module, InputProcessor]() def _default_dummy_data_factory( self, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 37c2d789030b6..6c0b6a3c7ccab 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -9,6 +9,7 @@ import json import math import os +import warnings from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast @@ -107,12 +108,14 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: # new-style model class with set_current_vllm_config(vllm_config): return model_class(vllm_config=vllm_config, prefix=prefix) + msg = ("vLLM model class should accept `vllm_config` and `prefix` as " "input arguments. Possibly you have an old-style model class" " registered from out of tree and it is used for new vLLM version. 
" "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " "for the design and update the model class accordingly.") - logger.warning(msg) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + logger.warning( "Trying to guess the arguments for old-style model class %s", model_class, diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index b95c0b7cd0612..1975f1e53e506 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -7,6 +7,7 @@ from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.adapters import for_embedding @contextlib.contextmanager @@ -32,7 +33,11 @@ def get_model_architecture( and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] - return ModelRegistry.resolve_model_cls(architectures) + model_cls, arch = ModelRegistry.resolve_model_cls(architectures) + if model_config.task == "embedding": + model_cls = for_embedding(model_cls) + + return model_cls, arch def get_architecture_class_name(model_config: ModelConfig) -> str: diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py new file mode 100644 index 0000000000000..b529dcb5dd3b8 --- /dev/null +++ b/vllm/model_executor/models/adapters.py @@ -0,0 +1,94 @@ +from collections.abc import Iterable +from typing import Any, TypeVar + +import torch +import torch.nn as nn + +from .interfaces_base import VllmModelForEmbedding, is_embedding_model + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def for_embedding(cls: _T) -> _T: + """Subclass an existing vLLM model to support embeddings.""" + # Avoid modifying existing embedding models + if is_embedding_model(cls): + return cls + + # Lazy import + from vllm.config import VllmConfig + from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput, + PoolingType) + from vllm.model_executor.pooling_metadata import PoolingMetadata + + from .utils import AutoWeightsLoader, WeightsMapper + + class ModelForEmbedding(cls, VllmModelForEmbedding): + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + # These are not used in embedding models + if hasattr(self, "lm_head"): + del self.lm_head + if hasattr(self, "logits_processor"): + del self.logits_processor + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + # If the model already defines a pooler instance, don't overwrite it + if not getattr(self, "_pooler", None): + pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False, + ) + assert pooler is not None + self._pooler = pooler + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + # We have deleted this attribute, so don't load it + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + + + # If `*ForCausalLM` defines `load_weights` on the inner model + # and there are no other inner modules with parameters, + # we support loading from both `*Model` and `*ForCausalLM` + if (hasattr(self, "model") and hasattr(self.model, "load_weights") + and all( + name == "model" or all(False for _ in 
child.parameters()) + for name, child in self.named_children() + )): + mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = mapper.apply(weights) + + self.model.load_weights(weights) + # For most other models + elif hasattr(cls, "load_weights"): + cls.load_weights(self, weights) # type: ignore + # Fallback + else: + loader = AutoWeightsLoader(self) + loader.load_weights(weights) + + ModelForEmbedding.__name__ = cls.__name__ \ + .removesuffix("ForCausalLM") \ + .removesuffix("ForConditionalGeneration") + "ForEmbedding" + + return ModelForEmbedding # type: ignore + \ No newline at end of file diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index d35fcb012e166..4664aa53ea092 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -30,19 +30,17 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, +from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -455,55 +453,3 @@ def load_weights(self, weights: Iterable[Tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -class Gemma2EmbeddingModel(nn.Module, SupportsPP): - """ - A model that uses Gemma2 with additional embedding functionalities. - - This class encapsulates the Gemma2Model and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of Gemma2Model used for forward operations. - _pooler: An instance of Pooler used for pooling operations. 
- """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - self.model = Gemma2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( - vllm_config.model_config.pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - weights = hf_to_vllm_mapper.apply(weights) - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) - self.model.load_weights(weights) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index fe94bb352961b..4daaf5ff3d37e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -627,6 +627,7 @@ def permute(w: torch.Tensor, n_heads: int): return name, loaded_weight +# TODO: Remove this once reward modeling is separated from LlamaForCausalLM class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): """ A model that uses Llama with additional embedding functionalities. 
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index e113f5862830d..42c190811eba4 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -14,13 +14,11 @@ from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext) -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -286,7 +284,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config vision_feature_layer = config.vision_feature_layer @@ -325,13 +322,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")) - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -678,13 +668,6 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4cb874a13e0c1..a725590914533 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -29,19 +29,17 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import dummy_image_for_clip, 
dummy_seq_data_for_clip @@ -536,7 +534,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -561,13 +558,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.language_model = LlamaForCausalLM(vllm_config=vllm_config, prefix="") - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -739,13 +729,6 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 87943e53d861c..7d4cc4b69e614 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -31,6 +31,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -55,6 +56,8 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +logger = init_logger(__name__) + class Qwen2MLP(nn.Module): @@ -433,7 +436,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config @@ -454,14 +456,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -499,13 +493,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( @@ -553,6 +540,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + # TODO: Replace this model 
class with for_embedding(Qwen2ForCausalLM), + # after changing the default pooling method + if pooler_config.pooling_type is None: + logger.warning( + "This embedding model will default to last-token pooling in " + "an upcoming version. To avoid breaking changes, you should " + "pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`" + " explicitly.") + self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.MEAN, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 7956a98b21569..27175dbae7483 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -50,7 +50,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -59,14 +58,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer from vllm.platforms import _Backend -from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData +from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import cached_get_processor @@ -1070,7 +1068,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config assert not cache_config.enable_prefix_caching, \ "Qwen2-VL currently does not support prefix caching" @@ -1102,11 +1099,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -1361,13 +1354,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c400c7d59828c..8b606f0d2844e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -20,6 +20,7 @@ from vllm.logger import init_logger 
from vllm.platforms import current_platform +from .adapters import for_embedding from .interfaces import (has_inner_state, is_attention_free, supports_cross_encoding, supports_multimodal, supports_pp) @@ -107,7 +108,7 @@ "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), - "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), + "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "LlamaModel": ("llama", "LlamaEmbeddingModel"), **{ @@ -218,9 +219,18 @@ class _ModelInfo: @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": + is_embedding_model_ = is_embedding_model(model) + if not is_embedding_model_: + try: + for_embedding(model) + except Exception: + pass + else: + is_embedding_model_ = True + return _ModelInfo( is_text_generation_model=is_text_generation_model(model), - is_embedding_model=is_embedding_model(model), + is_embedding_model=is_embedding_model_, supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 6eec660e42ac4..bbb8fb4bc1cd1 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -7,7 +7,7 @@ from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import (get_allowed_kwarg_only_overrides, +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, resolve_mm_processor_kwargs) if TYPE_CHECKING: @@ -54,8 +54,8 @@ class MultiModalPlugin(ABC): """ def __init__(self) -> None: - self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {} - self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {} + self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]() + self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]() @abstractmethod def get_data_key(self) -> str: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index b992442d3b314..b73daee98bd80 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -9,6 +9,7 @@ from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import ClassRegistry from .audio import AudioPlugin from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc @@ -62,8 +63,8 @@ def __init__( plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: self._plugins = {p.get_data_key(): p for p in plugins} - self._processor_factories: Dict[Type[nn.Module], - MultiModalProcessorFactory] = {} + self._processor_factories = ClassRegistry[nn.Module, + MultiModalProcessorFactory]() # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} diff --git a/vllm/utils.py b/vllm/utils.py index 6f7a6f8c54e47..83fbefd755870 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -20,7 +20,7 @@ import warnings import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task -from collections import defaultdict +from collections import UserDict, defaultdict from collections.abc import Iterable, Mapping from functools import lru_cache, partial, wraps from platform import uname @@ -1517,13 +1517,13 @@ def value(self): # Adapted from: https://stackoverflow.com/a/47212782/5082708 -class LazyDict(Mapping, Generic[T]): +class 
LazyDict(Mapping[str, T], Generic[T]): def __init__(self, factory: Dict[str, Callable[[], T]]): self._factory = factory self._dict: Dict[str, T] = {} - def __getitem__(self, key) -> T: + def __getitem__(self, key: str) -> T: if key not in self._dict: if key not in self._factory: raise KeyError(key) @@ -1540,6 +1540,19 @@ def __len__(self): return len(self._factory) +class ClassRegistry(UserDict[type[T], _V]): + + def __getitem__(self, key: type[T]) -> _V: + for cls in key.mro(): + if cls in self.data: + return self.data[cls] + + raise KeyError(key) + + def __contains__(self, key: type[T]) -> bool: + return any(cls in self.data for cls in key.mro()) + + def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: """ Create a weak reference to a tensor. From 769507a119df5870b739c10999cd8931dd8ec165 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 09:03:10 +0000 Subject: [PATCH 02/24] Fix lint errors Signed-off-by: DarkLight1337 --- vllm/model_executor/models/adapters.py | 12 +++++------- vllm/utils.py | 5 ++++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index b529dcb5dd3b8..04de6d2661c43 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -24,6 +24,7 @@ def for_embedding(cls: _T) -> _T: from .utils import AutoWeightsLoader, WeightsMapper class ModelForEmbedding(cls, VllmModelForEmbedding): + def __init__( self, *, @@ -63,17 +64,15 @@ def pooler( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # We have deleted this attribute, so don't load it weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) - + if not name.startswith("lm_head.")) # If `*ForCausalLM` defines `load_weights` on the inner model # and there are no other inner modules with parameters, # we support loading from both `*Model` and `*ForCausalLM` if (hasattr(self, "model") and hasattr(self.model, "load_weights") - and all( - name == "model" or all(False for _ in child.parameters()) - for name, child in self.named_children() - )): + and all(name == "model" or all(False + for _ in child.parameters()) + for name, child in self.named_children())): mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) weights = mapper.apply(weights) @@ -91,4 +90,3 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): .removesuffix("ForConditionalGeneration") + "ForEmbedding" return ModelForEmbedding # type: ignore - \ No newline at end of file diff --git a/vllm/utils.py b/vllm/utils.py index 83fbefd755870..0165a22582e7b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1549,7 +1549,10 @@ def __getitem__(self, key: type[T]) -> _V: raise KeyError(key) - def __contains__(self, key: type[T]) -> bool: + def __contains__(self, key: object) -> bool: + if not isinstance(key, type): + return False + return any(cls in self.data for cls in key.mro()) From 1d1c7b30c8eb9b6e37f5147f941db0542d885150 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 10:08:01 +0000 Subject: [PATCH 03/24] Rename Signed-off-by: DarkLight1337 --- vllm/model_executor/model_loader/utils.py | 4 ++-- vllm/model_executor/models/adapters.py | 2 +- vllm/model_executor/models/registry.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 1975f1e53e506..3f15762d18bfd 100644 --- a/vllm/model_executor/model_loader/utils.py +++ 
b/vllm/model_executor/model_loader/utils.py @@ -7,7 +7,7 @@ from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.models.adapters import for_embedding +from vllm.model_executor.models.adapters import as_embedding_model @contextlib.contextmanager @@ -35,7 +35,7 @@ def get_model_architecture( model_cls, arch = ModelRegistry.resolve_model_cls(architectures) if model_config.task == "embedding": - model_cls = for_embedding(model_cls) + model_cls = as_embedding_model(model_cls) return model_cls, arch diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 04de6d2661c43..efcf1c10c0061 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -9,7 +9,7 @@ _T = TypeVar("_T", bound=type[nn.Module]) -def for_embedding(cls: _T) -> _T: +def as_embedding_model(cls: _T) -> _T: """Subclass an existing vLLM model to support embeddings.""" # Avoid modifying existing embedding models if is_embedding_model(cls): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8b606f0d2844e..a935ab991c3fe 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -20,7 +20,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -from .adapters import for_embedding +from .adapters import as_embedding_model from .interfaces import (has_inner_state, is_attention_free, supports_cross_encoding, supports_multimodal, supports_pp) @@ -222,7 +222,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": is_embedding_model_ = is_embedding_model(model) if not is_embedding_model_: try: - for_embedding(model) + as_embedding_model(model) except Exception: pass else: From 666cc19bfae9ae6ced030d7a5d20cc7ef86d70b1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 10:18:44 +0000 Subject: [PATCH 04/24] Split up the condition Signed-off-by: DarkLight1337 --- vllm/model_executor/models/adapters.py | 28 +++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index efcf1c10c0061..bc93808ac5718 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -9,6 +9,11 @@ _T = TypeVar("_T", bound=type[nn.Module]) +def _is_paramless(module: nn.Module): + # NOTE: all([]) returns True + return all(False for _ in module.parameters()) + + def as_embedding_model(cls: _T) -> _T: """Subclass an existing vLLM model to support embeddings.""" # Avoid modifying existing embedding models @@ -69,16 +74,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # If `*ForCausalLM` defines `load_weights` on the inner model # and there are no other inner modules with parameters, # we support loading from both `*Model` and `*ForCausalLM` - if (hasattr(self, "model") and hasattr(self.model, "load_weights") - and all(name == "model" or all(False - for _ in child.parameters()) - for name, child in self.named_children())): - mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - weights = mapper.apply(weights) - - self.model.load_weights(weights) + if hasattr(self, "model") and hasattr(self.model, "load_weights"): + # Whether only `self.model` contains parameters + model_is_only_param = all( + name == "model" or _is_paramless(child) + for name, child in self.named_children()) + + if model_is_only_param: + mapper = 
WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = mapper.apply(weights) + + self.model.load_weights(weights) + return + # For most other models - elif hasattr(cls, "load_weights"): + if hasattr(cls, "load_weights"): cls.load_weights(self, weights) # type: ignore # Fallback else: From f73282ecaf8a81e74eb66c39d6c21c939275ff98 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 13:38:39 +0000 Subject: [PATCH 05/24] Remove unnecessary `LlamaEmbeddingModel` Signed-off-by: DarkLight1337 --- docs/source/models/supported_models.rst | 9 +++ vllm/model_executor/models/llama.py | 98 +------------------------ vllm/model_executor/models/registry.py | 4 +- 3 files changed, 14 insertions(+), 97 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index f4cab81b3d20b..f571b8bf6735e 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -401,12 +401,21 @@ Reward Modeling - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`LlamaForCausalLM` + - Llama-based + - :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc. + - ✅︎ + - ✅︎ * - :code:`Qwen2ForRewardModel` - Qwen2-based - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc. - ✅︎ - ✅︎ +.. important:: + For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, + e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. + .. note:: As an interim measure, these models are supported in both offline and online inference via Embeddings API. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4daaf5ff3d37e..4c69fdf74894d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,7 +37,6 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) @@ -47,14 +46,12 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -497,7 +494,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config @@ -530,13 +526,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + 
self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.STEP, - normalize=False, - softmax=False) def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): return LlamaModel(vllm_config=vllm_config, prefix=prefix) @@ -567,14 +559,6 @@ def compute_logits( sampling_metadata) return logits - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - logits = self.compute_logits(hidden_states, None) - return self._pooler(logits, pooling_metadata) - def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) @@ -625,79 +609,3 @@ def permute(w: torch.Tensor, n_heads: int): name = name.replace(item, mapping[item]) return name, loaded_weight - - -# TODO: Remove this once reward modeling is separated from LlamaForCausalLM -class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): - """ - A model that uses Llama with additional embedding functionalities. - - This class encapsulates the LlamaModel and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of LlamaModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens" - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - } - embedding_padding_modules = [] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - pooler_config = vllm_config.model_config.pooler_config - - self.model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - weights = hf_to_vllm_mapper.apply(weights) - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) - self.model.load_weights(weights) - - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - self.model.load_kv_cache_scales(quantization_param_path) - - # LRUCacheWorkerLoRAManager instantiation requires model config. 
- @property - def config(self): - return self.model.config diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a935ab991c3fe..36486150c767a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -110,13 +110,13 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), - "LlamaModel": ("llama", "LlamaEmbeddingModel"), + "LlamaModel": ("llama", "LlamaForCausalLM"), **{ # Multiple models share the same architecture, so we include them all k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() if arch == "LlamaForCausalLM" }, - "MistralModel": ("llama", "LlamaEmbeddingModel"), + "MistralModel": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), From ec0fbf740ce9d8cc9a80ac61b91a8bb8eeda299b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 14:01:09 +0000 Subject: [PATCH 06/24] Simplify code Signed-off-by: DarkLight1337 --- vllm/model_executor/layers/pooler.py | 4 +--- vllm/model_executor/models/adapters.py | 18 +++++------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index f9437b4112ceb..e0d42e30ebef3 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -60,9 +60,7 @@ def from_config_with_defaults( softmax: bool, step_tag_id: Optional[int] = None, returned_token_ids: Optional[List[int]] = None, - ) -> Optional["Pooler"]: - if pooler_config is None: - return None + ) -> "Pooler": return cls( pooling_type=PoolingType[pooler_config.pooling_type] if pooler_config.pooling_type is not None else pooling_type, diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index bc93808ac5718..389a2ffc28884 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -9,11 +9,6 @@ _T = TypeVar("_T", bound=type[nn.Module]) -def _is_paramless(module: nn.Module): - # NOTE: all([]) returns True - return all(False for _ in module.parameters()) - - def as_embedding_model(cls: _T) -> _T: """Subclass an existing vLLM model to support embeddings.""" # Avoid modifying existing embedding models @@ -40,24 +35,21 @@ def __init__( super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) # These are not used in embedding models - if hasattr(self, "lm_head"): - del self.lm_head - if hasattr(self, "logits_processor"): - del self.logits_processor + for attr in ("lm_head", "logits_processor"): + if hasattr(self, attr): + delattr(self, attr) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None # If the model already defines a pooler instance, don't overwrite it if not getattr(self, "_pooler", None): - pooler = Pooler.from_config_with_defaults( + self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.LAST, normalize=True, softmax=False, ) - assert pooler is not None - self._pooler = pooler def pooler( self, @@ -77,7 +69,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): if hasattr(self, "model") and hasattr(self.model, "load_weights"): # Whether only `self.model` contains parameters model_is_only_param = all( - name == "model" or _is_paramless(child) + name == "model" or 
next(child.parameters(), None) is None for name, child in self.named_children()) if model_is_only_param: From 8301a0859b0b5310d819f93a61db148db44244f7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 14:01:38 +0000 Subject: [PATCH 07/24] Fix typo Signed-off-by: DarkLight1337 --- vllm/model_executor/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 36486150c767a..b14a47a29cbf3 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -126,7 +126,7 @@ # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), - "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501, + "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 } _CROSS_ENCODER_MODELS = { From f52414ee572c925088739627f8f8b5bf4933b464 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 14:25:10 +0000 Subject: [PATCH 08/24] Remove unused parameter Signed-off-by: DarkLight1337 --- tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index d56942d8912af..36f1d477fab59 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -263,7 +263,6 @@ def __init__( dtype: str = "half", *, model_kwargs: Optional[Dict[str, Any]] = None, - is_embedding_model: bool = False, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, From b0e0f148a86010958bfc2d29a1827aabe4161f0e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 15:10:58 +0000 Subject: [PATCH 09/24] Fix OOT embedding test Signed-off-by: DarkLight1337 --- .../my_gemma_embedding.py | 43 +++++++++++++++++-- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index 21958b1640204..acf56ad7d00af 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -1,13 +1,34 @@ -from typing import List, Optional, Union +from typing import Iterable, List, Optional, Tuple, Union import torch +import torch.nn as nn from vllm.attention import AttentionMetadata -from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel -from vllm.sequence import IntermediateTensors +from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.models.gemma2 import Gemma2Model +from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput -class MyGemma2Embedding(Gemma2EmbeddingModel): +class MyGemma2Embedding(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.model = Gemma2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self._pooler = Pooler.from_config_with_defaults( + vllm_config.model_config.pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False, + ) + + self.make_empty_intermediate_tensors = ( + 
self.model.make_empty_intermediate_tensors) def forward( self, @@ -32,3 +53,17 @@ def forward( # Return all-zero embeddings return torch.zeros_like(hidden_states) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + self.model.load_weights(weights) From 29869a3e8084ec13fbdca7b8f38331484ff9589a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 15:16:10 +0000 Subject: [PATCH 10/24] Fix config tests Signed-off-by: DarkLight1337 --- tests/test_config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 3cf90297ce177..0f1ee4fabfb56 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize(("model_id", "expected_task"), [ ("facebook/opt-125m", "generate"), - ("intfloat/e5-mistral-7b-instruct", "embedding"), + ("Qwen/Qwen2.5-Math-RM-72B", "embedding"), ]) def test_auto_task(model_id, expected_task): config = ModelConfig( @@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task): @pytest.mark.parametrize(("model_id", "bad_task"), [ - ("facebook/opt-125m", "embedding"), - ("intfloat/e5-mistral-7b-instruct", "generate"), + ("Qwen/Qwen2.5-Math-RM-72B", "generate"), ]) def test_incorrect_task(model_id, bad_task): with pytest.raises(ValueError, match=r"does not support the .* task"): From 12a88eec867b5014335c299870bd27897f7c7ade Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 15:57:21 +0000 Subject: [PATCH 11/24] Fix PP tests Signed-off-by: DarkLight1337 --- tests/distributed/test_pipeline_parallel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 386877e0e0a2c..a21ff3bda5ffc 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -191,9 +191,9 @@ def iter_params(self, model_name: str): EMBEDDING_MODELS = { # type: ignore[var-annotated] # [Text-only] - "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(), - "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(), - "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501 + "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(task="embedding"), + "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(task="embedding"), + "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(task="embedding", tp_base=4, trust_remote_code=True), # noqa: E501 } MULTIMODAL_MODELS = { From 8c0ed5ced562186c526889ccf515bf38f30ee214 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 15:57:30 +0000 Subject: [PATCH 12/24] Fix model tests Signed-off-by: DarkLight1337 --- vllm/model_executor/models/adapters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 389a2ffc28884..09fc586e3c0e4 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -81,11 +81,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # For most other models if 
hasattr(cls, "load_weights"): - cls.load_weights(self, weights) # type: ignore + return cls.load_weights(self, weights) # type: ignore # Fallback else: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) ModelForEmbedding.__name__ = cls.__name__ \ .removesuffix("ForCausalLM") \ From fddca2519407abfc436ef89a18fe261f0f79cb54 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 17:32:01 +0000 Subject: [PATCH 13/24] Fix Signed-off-by: DarkLight1337 --- .../vllm_add_dummy_model/my_gemma_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index acf56ad7d00af..316a6561adcd2 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -39,7 +39,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = super().forward( + hidden_states = self.model( input_ids, positions, kv_caches, From cd84a3c8047b5a10684e77e992a93bf3282d4516 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 17:42:55 +0000 Subject: [PATCH 14/24] Improve auto task selection Signed-off-by: DarkLight1337 --- vllm/config.py | 14 ++++++++++ vllm/model_executor/models/registry.py | 37 +++++++++++++++++--------- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index cd24e9ffdf598..51f8cca0a0f53 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -370,6 +370,20 @@ def _resolve_task( selected_task = next(iter(supported_tasks_lst)) if len(supported_tasks) > 1: + suffix_to_preferred_task: List[Tuple[str, _Task]] = [ + ("ForCausalLM", "generate"), + ("ForConditionalGeneration", "generate"), + ("Model", "embedding"), + ("RewardModel", "embedding"), + ("ForSequenceClassification", "embedding"), + ] + _, arch = ModelRegistry.inspect_model_cls(architectures) + + for suffix, pref_task in suffix_to_preferred_task: + if arch.endswith(suffix) and pref_task in supported_tasks: + selected_task = pref_task + break + logger.info( "This model supports multiple tasks: %s. 
" "Defaulting to '%s'.", supported_tasks, selected_task) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b14a47a29cbf3..237201681f6de 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -409,13 +409,13 @@ def _normalize_archs( def inspect_model_cls( self, architectures: Union[str, List[str]], - ) -> _ModelInfo: + ) -> Tuple[_ModelInfo, str]: architectures = self._normalize_archs(architectures) for arch in architectures: model_info = self._try_inspect_model_cls(arch) if model_info is not None: - return model_info + return (model_info, arch) return self._raise_for_unsupported(architectures) @@ -436,39 +436,50 @@ def is_text_generation_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).is_text_generation_model + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_text_generation_model def is_embedding_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).is_embedding_model + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_embedding_model def is_cross_encoder_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).supports_cross_encoding + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_cross_encoding def is_multimodal_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).supports_multimodal + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_multimodal def is_pp_supported_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).supports_pp + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_pp - def model_has_inner_state(self, architectures: Union[str, - List[str]]) -> bool: - return self.inspect_model_cls(architectures).has_inner_state + def model_has_inner_state( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.has_inner_state - def is_attention_free_model(self, architectures: Union[str, - List[str]]) -> bool: - return self.inspect_model_cls(architectures).is_attention_free + def is_attention_free_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_attention_free ModelRegistry = _ModelRegistry({ From 5dc0f10e0f3736cf322437c1a1456d2e2ca19ac5 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 29 Nov 2024 17:44:12 +0000 Subject: [PATCH 15/24] Revert PP test change to verify auto task selection Signed-off-by: DarkLight1337 --- tests/distributed/test_pipeline_parallel.py | 6 +++--- tests/test_config.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index a21ff3bda5ffc..386877e0e0a2c 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -191,9 +191,9 @@ def iter_params(self, model_name: str): EMBEDDING_MODELS = { # type: ignore[var-annotated] # [Text-only] - "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(task="embedding"), - "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(task="embedding"), - "Qwen/Qwen2.5-Math-RM-72B": 
PPTestSettings.fast(task="embedding", tp_base=4, trust_remote_code=True), # noqa: E501 + "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(), + "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(), + "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501 } MULTIMODAL_MODELS = { diff --git a/tests/test_config.py b/tests/test_config.py index 0f1ee4fabfb56..45b0b938af215 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize(("model_id", "expected_task"), [ ("facebook/opt-125m", "generate"), - ("Qwen/Qwen2.5-Math-RM-72B", "embedding"), + ("intfloat/e5-mistral-7b-instruct", "embedding"), ]) def test_auto_task(model_id, expected_task): config = ModelConfig( From bc5f9fb7661a60ee4db495b2b7e73e0c709809ca Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 30 Nov 2024 04:08:33 +0000 Subject: [PATCH 16/24] Fix task selection Signed-off-by: DarkLight1337 --- vllm/config.py | 10 ++++++++-- vllm/model_executor/models/registry.py | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 51f8cca0a0f53..08e5d98c22894 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -373,16 +373,22 @@ def _resolve_task( suffix_to_preferred_task: List[Tuple[str, _Task]] = [ ("ForCausalLM", "generate"), ("ForConditionalGeneration", "generate"), - ("Model", "embedding"), + ("LMHeadModel", "generate"), + ("EmbeddingModel", "embedding"), ("RewardModel", "embedding"), ("ForSequenceClassification", "embedding"), ] - _, arch = ModelRegistry.inspect_model_cls(architectures) + info, arch = ModelRegistry.inspect_model_cls(architectures) for suffix, pref_task in suffix_to_preferred_task: if arch.endswith(suffix) and pref_task in supported_tasks: selected_task = pref_task break + else: + if (arch.endswith("Model") + and info.architecture.endswith("ForCausalLM") + and "embedding" in supported_tasks): + selected_task = "embedding" logger.info( "This model supports multiple tasks: %s. 
" diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 237201681f6de..7d2bfce9ba264 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -209,6 +209,7 @@ @dataclass(frozen=True) class _ModelInfo: + architecture: str is_text_generation_model: bool is_embedding_model: bool supports_cross_encoding: bool @@ -229,6 +230,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": is_embedding_model_ = True return _ModelInfo( + architecture=model.__name__, is_text_generation_model=is_text_generation_model(model), is_embedding_model=is_embedding_model_, supports_cross_encoding=supports_cross_encoding(model), From ebffd705b5e5b95985a8ccd61bd99210087390e7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 30 Nov 2024 04:13:02 +0000 Subject: [PATCH 17/24] Fix registry test Signed-off-by: DarkLight1337 --- tests/models/test_registry.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 289ea66b5ebc5..127d81ce6ea09 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -6,11 +6,8 @@ from vllm.model_executor.models import (is_embedding_model, is_text_generation_model, supports_multimodal) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS, - _EMBEDDING_MODELS, - _MULTIMODAL_MODELS, +from vllm.model_executor.models.adapters import as_embedding_model +from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS, _SPECULATIVE_DECODING_MODELS, _TEXT_GENERATION_MODELS, ModelRegistry) @@ -32,9 +29,9 @@ def test_registry_imports(model_arch): model_arch in _TEXT_GENERATION_MODELS or model_arch in _MULTIMODAL_MODELS) - embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS} - assert is_embedding_model(model_cls) is (model_arch - in embedding_models) + # All vLLM models should be convertible to an embedding model + embed_model = as_embedding_model(model_cls) + assert is_embedding_model(embed_model) assert supports_multimodal(model_cls) is (model_arch in _MULTIMODAL_MODELS) From d8ef4ae99c8c7cfd71950e957a5adad7eb533784 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 30 Nov 2024 04:15:32 +0000 Subject: [PATCH 18/24] Fix test grouping Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fc23c9cff0d87..46692506f01d4 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -334,7 +334,6 @@ steps: commands: - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model - - pytest -v -s models/embedding/vision_language -m core_model - label: Language Models Test (Extended) # 50min optional: true @@ -346,7 +345,6 @@ steps: commands: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' - - pytest -v -s models/embedding/vision_language -m 'not core_model' - label: Multi-Modal Models Test (Standard) # 26min #mirror_hardwares: [amd] @@ -359,6 +357,7 @@ steps: commands: - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model 
or quant_model' + - pytest -v -s models/embedding/vision_language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model @@ -376,6 +375,7 @@ steps: # https://github.com/huggingface/transformers/issues/34307 - pytest -v -s models/decoder_only/vision_language/test_phi3v.py - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' + - pytest -v -s models/embedding/vision_language -m 'not core_model' - pytest -v -s models/encoder_decoder/language -m 'not core_model' - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' From 3a4d4682ab3b33ef292df0110fa1c06f0352c340 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 30 Nov 2024 04:17:43 +0000 Subject: [PATCH 19/24] Format Signed-off-by: DarkLight1337 --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 874290bc94ed1..85ab4355cc2e4 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -137,7 +137,7 @@ class InputRegistry: def __init__(self) -> None: self._dummy_factories_by_model_type = \ - ClassRegistry[nn.Module,DummyDataFactory]() + ClassRegistry[nn.Module, DummyDataFactory]() self._dummy_encoder_factories_by_model_type = \ ClassRegistry[nn.Module, DummyDataFactory]() self._input_processors_by_model_type = \ From d4d7ad7b779db62e5b28b63eb35b02c706d0ef2c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 30 Nov 2024 06:45:14 +0000 Subject: [PATCH 20/24] Fix task detection Signed-off-by: DarkLight1337 --- vllm/config.py | 5 +++++ vllm/model_executor/models/adapters.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 08e5d98c22894..a111d7aca9eed 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -371,8 +371,13 @@ def _resolve_task( if len(supported_tasks) > 1: suffix_to_preferred_task: List[Tuple[str, _Task]] = [ + # Hardcode the models that are exceptions + ("AquilaModel", "generate"), + ("ChatGLMModel", "generate"), + # Other models follow this pattern ("ForCausalLM", "generate"), ("ForConditionalGeneration", "generate"), + ("ChatModel", "generate"), ("LMHeadModel", "generate"), ("EmbeddingModel", "embedding"), ("RewardModel", "embedding"), diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 09fc586e3c0e4..1356f9056d06f 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -89,6 +89,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ModelForEmbedding.__name__ = cls.__name__ \ .removesuffix("ForCausalLM") \ - .removesuffix("ForConditionalGeneration") + "ForEmbedding" + .removesuffix("ForConditionalGeneration") \ + .removesuffix("ChatModel") \ + .removesuffix("LMHeadModel") + "ForEmbedding" return ModelForEmbedding # type: ignore From 7d8757ea1789d92880003360b7da5ea5140268d2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 30 Nov 2024 07:09:23 +0000 Subject: [PATCH 21/24] Fix registry tests Signed-off-by: DarkLight1337 --- tests/models/test_registry.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 127d81ce6ea09..1886b1f9898ad 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -23,18 +23,18 @@ def 
test_registry_imports(model_arch): model_cls, _ = ModelRegistry.resolve_model_cls(model_arch) if model_arch in _SPECULATIVE_DECODING_MODELS: - pass # Ignore these models which do not have a unified format - else: - assert is_text_generation_model(model_cls) is ( - model_arch in _TEXT_GENERATION_MODELS - or model_arch in _MULTIMODAL_MODELS) - - # All vLLM models should be convertible to an embedding model - embed_model = as_embedding_model(model_cls) - assert is_embedding_model(embed_model) - - assert supports_multimodal(model_cls) is (model_arch - in _MULTIMODAL_MODELS) + return # Ignore these models which do not have a unified format + + if (model_arch in _TEXT_GENERATION_MODELS + or model_arch in _MULTIMODAL_MODELS): + assert is_text_generation_model(model_cls) + + # All vLLM models should be convertible to an embedding model + embed_model = as_embedding_model(model_cls) + assert is_embedding_model(embed_model) + + if model_arch in _MULTIMODAL_MODELS: + assert supports_multimodal(model_cls) @fork_new_process_for_each_test From f7d8c059427aeeee43479686040dd4949c16b81d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 30 Nov 2024 10:37:42 +0000 Subject: [PATCH 22/24] Fix weight loading Signed-off-by: DarkLight1337 --- .../my_gemma_embedding.py | 2 +- vllm/model_executor/model_loader/loader.py | 13 +++++++--- vllm/model_executor/model_loader/utils.py | 11 ++++++--- vllm/model_executor/models/adapters.py | 6 +++-- vllm/model_executor/models/blip2.py | 5 ++-- vllm/model_executor/models/internvl.py | 5 ++-- vllm/model_executor/models/llava.py | 5 ++-- vllm/model_executor/models/llava_next.py | 5 ++-- .../model_executor/models/llava_next_video.py | 5 ++-- vllm/model_executor/models/llava_onevision.py | 5 ++-- vllm/model_executor/models/paligemma.py | 5 ++-- vllm/model_executor/models/phi3v.py | 18 +++++++++----- vllm/model_executor/models/pixtral.py | 5 ++-- vllm/model_executor/models/ultravox.py | 5 ++-- vllm/model_executor/models/utils.py | 24 +++++++++++++++---- 15 files changed, 81 insertions(+), 38 deletions(-) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index 316a6561adcd2..d676eacffb056 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -66,4 +66,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weights = hf_to_vllm_mapper.apply(weights) weights = ((name, data) for name, data in weights if not name.startswith("lm_head.")) - self.model.load_weights(weights) + return self.model.load_weights(weights) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 6c0b6a3c7ccab..0e12bc5691538 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -98,10 +98,17 @@ def device_loading_context(module: torch.nn.Module, logger = init_logger(__name__) -def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: +def _initialize_model( + vllm_config: VllmConfig, + *, + prefix: str = "", + architectures: Optional[list[str]] = None, +) -> nn.Module: """Initialize a model with the given configurations.""" model_config = vllm_config.model_config - model_class, _ = get_model_architecture(model_config) + model_class, _ = get_model_architecture(model_config, + architectures=architectures) + signatures = 
inspect.signature(model_class.__init__) all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: @@ -359,7 +366,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( self._get_all_weights(model_config, model)) - # We only enable strict check for non-quantiized models + # We only enable strict check for non-quantized models # that have loaded weights tracking currently. if model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 3f15762d18bfd..864dd04e79921 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -1,6 +1,6 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Tuple, Type +from typing import Optional, Tuple, Type import torch from torch import nn @@ -20,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype): def get_model_architecture( - model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: - architectures = getattr(model_config.hf_config, "architectures", []) + model_config: ModelConfig, + *, + architectures: Optional[list[str]] = None, +) -> Tuple[Type[nn.Module], str]: + if architectures is None: + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. mixtral_supported = [ diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 1356f9056d06f..360433a07c5b8 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -59,6 +59,8 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + # TODO: Support uninitialized params tracking + # We have deleted this attribute, so don't load it weights = ((name, data) for name, data in weights if not name.startswith("lm_head.")) @@ -81,11 +83,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # For most other models if hasattr(cls, "load_weights"): - return cls.load_weights(self, weights) # type: ignore + cls.load_weights(self, weights) # type: ignore # Fallback else: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + loader.load_weights(weights) ModelForEmbedding.__name__ = cls.__name__ \ .removesuffix("ForCausalLM") \ diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index d2592016aff34..76b8505ee1c2a 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -512,9 +512,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b1c0065afbf30..86aab38032450 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -474,9 +474,10 @@ def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.mlp1 = self._init_mlp1(config) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index e7757b3c7d405..7fd4b32774798 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -319,9 +319,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 42c190811eba4..a39f2f4124d05 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -318,9 +318,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index b130791808924..0de9d8c5ea572 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -275,9 +275,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 3166737d61582..0bebc1c745e2b 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -422,9 +422,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: prefix=maybe_prefix(prefix, "vision_tower")) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 2e5b6bee784e7..253e689e50a3b 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -151,9 +151,10 @@ def __init__(self, *, 
vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config config.text_config.architectures = ["GemmaForCausalLM"] self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) logit_scale = getattr(config, "logit_scale", 1.0) self.language_model.logits_processor.scale *= logit_scale diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index a725590914533..eef23029a2aca 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors, PlaceholderRange @@ -44,7 +43,8 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -553,10 +553,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config, prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) - # The prefix is empty intentionally because default prefix of - # LlamaForCausalLM is "model" - self.language_model = LlamaForCausalLM(vllm_config=vllm_config, - prefix="") + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + # The prefix is empty intentionally because default prefix of + # LlamaForCausalLM is "model" + prefix="", + # We don't directly initialize vLLM's LlamaForCausalLM so we + # can automatically apply embedding wrapper if this model is + # initialized as an embedding model + architectures=["LlamaForCausalLM"], + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 45171c1a04b17..215727cadd954 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -172,9 +172,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # init MistralForCausalLM self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.vision_encoder = VisionTransformer(self.vision_args) self.vision_language_adapter = VisionLanguageAdapter( diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b61deccde45b7..ea1e5401d42c0 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -360,9 +360,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): )) self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + 
prefix=maybe_prefix(prefix, "language_model"), + ) if config.text_model_id is not None: # this prefix is not for initialization, but for loading weights # note the trailing dot diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index a6b40a233439b..7a1e1f9bf2be4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -173,8 +173,15 @@ def _load_module( module_load_weights = getattr(module, "load_weights", None) if callable(module_load_weights): loaded_params = module_load_weights(weights) - yield from map(lambda x: self._get_qualname(base_prefix, x), - loaded_params) + if loaded_params is None: + logger.warning( + "Unable to collect loaded parameters " + "for module %s", module) + else: + yield from map( + lambda x: self._get_qualname(base_prefix, x), + loaded_params, + ) child_modules = dict(module.named_children()) child_params = dict(module.named_parameters(recurse=False)) @@ -232,17 +239,24 @@ def load_weights( def init_vllm_registered_model( - hf_config: PretrainedConfig, vllm_config: VllmConfig, + *, prefix: str = "", + hf_config: Optional[PretrainedConfig] = None, + architectures: Optional[list[str]] = None, ) -> nn.Module: """ Helper function to initialize an inner model registered to vLLM, based on the arguments passed to the outer vLLM model. """ from vllm.model_executor.model_loader.loader import _initialize_model - vllm_config = vllm_config.with_hf_config(hf_config) - return _initialize_model(vllm_config, prefix) + + if hf_config is not None: + vllm_config = vllm_config.with_hf_config(hf_config) + + return _initialize_model(vllm_config=vllm_config, + prefix=prefix, + architectures=architectures) @overload From 7d3631f983aa84cd98b23fa8883868acd1e2bc06 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 30 Nov 2024 15:16:16 +0000 Subject: [PATCH 23/24] Try adding back prefix Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 36f1f40b23b16..fb1f1fce7fd03 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -512,7 +512,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.lora_config = lora_config - self.model = self._init_model(vllm_config=vllm_config, prefix=prefix) + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: From 69d8d6cf3c8b892b59292fb93ee1da70895e336a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 30 Nov 2024 15:46:39 +0000 Subject: [PATCH 24/24] format Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index fb1f1fce7fd03..31dfb235ae877 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -514,7 +514,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = self._init_model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - + if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config:
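A minimal, self-contained sketch of the suffix-based default-task heuristic that patches 14, 16, and 20 above build up in vllm/config.py. The suffix table below mirrors the final state of the series; the resolve_preferred_task helper, its name, and its signature are illustrative assumptions for this sketch rather than vLLM's actual API.

# Sketch only: the suffix table follows the patches above; the helper
# function itself is an illustrative stand-in, not vLLM's real code.
from typing import List, Set, Tuple

_SUFFIX_TO_PREFERRED_TASK: List[Tuple[str, str]] = [
    # Hardcoded exceptions first
    ("AquilaModel", "generate"),
    ("ChatGLMModel", "generate"),
    # Other models follow these suffix patterns
    ("ForCausalLM", "generate"),
    ("ForConditionalGeneration", "generate"),
    ("ChatModel", "generate"),
    ("LMHeadModel", "generate"),
    ("EmbeddingModel", "embedding"),
    ("RewardModel", "embedding"),
    ("ForSequenceClassification", "embedding"),
]


def resolve_preferred_task(arch: str, registered_arch: str,
                           supported_tasks: Set[str]) -> str:
    """Pick a default task from the architecture name.

    `arch` is the name requested in the HF config; `registered_arch` is the
    class it resolves to in the model registry (e.g. "LlamaForCausalLM").
    Falls back to an arbitrary supported task if no suffix matches.
    """
    selected = next(iter(sorted(supported_tasks)))
    for suffix, pref_task in _SUFFIX_TO_PREFERRED_TASK:
        if arch.endswith(suffix) and pref_task in supported_tasks:
            return pref_task
    # A bare "*Model" alias of a causal-LM class defaults to embedding,
    # matching the for-else branch added in patch 16.
    if (arch.endswith("Model") and registered_arch.endswith("ForCausalLM")
            and "embedding" in supported_tasks):
        return "embedding"
    return selected


# Example: "Qwen2Model" resolves to Qwen2ForCausalLM in the registry, so it
# defaults to embedding; the causal-LM architecture itself keeps "generate".
assert resolve_preferred_task(
    "Qwen2Model", "Qwen2ForCausalLM", {"generate", "embedding"}) == "embedding"
assert resolve_preferred_task(
    "Qwen2ForCausalLM", "Qwen2ForCausalLM",
    {"generate", "embedding"}) == "generate"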