From eadd0bba97e8b2fac99202672948b2bf8ba07f00 Mon Sep 17 00:00:00 2001
From: Xin Yang
Date: Tue, 1 Oct 2024 11:31:07 -0700
Subject: [PATCH] [Model] Support Gemma2 embedding model

---
 .../embedding/language/test_embedding.py      |  1 +
 vllm/model_executor/models/__init__.py        |  1 +
 vllm/model_executor/models/gemma2.py          |  6 +-
 .../model_executor/models/gemma2_embedding.py | 80 +++++++++++++++++++
 4 files changed, 87 insertions(+), 1 deletion(-)
 create mode 100644 vllm/model_executor/models/gemma2_embedding.py

diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py
index 6556998b68a74..78dd8419dbaba 100644
--- a/tests/models/embedding/language/test_embedding.py
+++ b/tests/models/embedding/language/test_embedding.py
@@ -8,6 +8,7 @@
 
 MODELS = [
     "intfloat/e5-mistral-7b-instruct",
+    "BAAI/bge-multilingual-gemma2",
 ]
 
 
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 682a2e71a1dbf..23d4646ebe468 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -75,6 +75,7 @@
 _EMBEDDING_MODELS = {
     "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
+    "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
 }
 
 _MULTIMODAL_MODELS = {
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index f9d9f9e7567c8..db169525cafe0 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -271,8 +271,12 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.embed_tokens(input_ids)
         hidden_states *= self.normalizer
 
         residual = None
diff --git a/vllm/model_executor/models/gemma2_embedding.py b/vllm/model_executor/models/gemma2_embedding.py
new file mode 100644
index 0000000000000..6a5f734aa5da6
--- /dev/null
+++ b/vllm/model_executor/models/gemma2_embedding.py
@@ -0,0 +1,80 @@
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from vllm.attention import AttentionMetadata
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.gemma2 import Gemma2Model
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import PoolerOutput
+
+
+class Gemma2EmbeddingModel(nn.Module):
+    """A model that uses Gemma2 with additional embedding functionalities.
+
+    This class encapsulates the Gemma2Model and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of Gemma2Model used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.model = Gemma2Model(**kwargs)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.model.forward(input_ids, positions, kv_caches,
+                                  attn_metadata, inputs_embeds)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.model.named_parameters())
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)