diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 4e5b10967e3bb..ec5509e20187b 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -174,6 +174,11 @@ Text Generation
     - :code:`ibm-granite/granite-3.0-1b-a400m-base`, :code:`ibm-granite/granite-3.0-3b-a800m-instruct`, :code:`ibm/PowerMoE-3b`, etc.
     - ✅︎
     - ✅︎
+  * - :code:`GritLM`
+    - GritLM
+    - :code:`parasail-ai/GritLM-7B-vllm`.
+    - ✅︎
+    - ✅︎
   * - :code:`InternLMForCausalLM`
     - InternLM
     - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
@@ -350,6 +355,11 @@ Text Embedding
     - :code:`BAAI/bge-multilingual-gemma2`, etc.
     -
     - ✅︎
+  * - :code:`GritLM`
+    - GritLM
+    - :code:`parasail-ai/GritLM-7B-vllm`.
+    - ✅︎
+    - ✅︎
   * - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc.
     - Llama-based
     - :code:`intfloat/e5-mistral-7b-instruct`, etc.
diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py
new file mode 100644
index 0000000000000..b947265be9e9d
--- /dev/null
+++ b/tests/models/embedding/language/test_gritlm.py
@@ -0,0 +1,200 @@
+import importlib.util
+import math
+from array import array
+from typing import List
+
+import openai
+import pytest
+import pytest_asyncio
+from scipy.spatial.distance import cosine
+
+import vllm
+import vllm.config
+
+from ....utils import RemoteOpenAIServer
+
+# GritLM embedding implementation is only supported by XFormers backend.
+pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
+                                reason="GritLM requires XFormers")
+
+MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
+MAX_MODEL_LEN = 4000
+
+
+def _arr(arr):
+    """
+    Convert a list of integers to an array of integers.
+    """
+    return array("i", arr)
+
+
+def test_find_array(monkeypatch):
+    # GritLM embedding implementation is only supported by XFormers backend.
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+
+    from vllm.model_executor.models.gritlm import GritLMPooler
+
+    # Create an LLM object to get the model config.
+    llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
+    pooler = GritLMPooler(model_config=llm.llm_engine.model_config)
+
+    arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+
+    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
+    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
+    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
+    assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
+
+    with pytest.raises(ValueError):
+        pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
+
+
+@pytest.fixture(scope="module")
+def server_embedding():
+    # GritLM embedding implementation is only supported by XFormers backend.
+    with pytest.MonkeyPatch.context() as mp:
+        mp.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+
+        args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)]
+        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+            yield remote_server
+
+
+@pytest.fixture(scope="module")
+def server_generate():
+    args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client_embedding(server_embedding: RemoteOpenAIServer):
+    async with server_embedding.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest_asyncio.fixture
+async def client_generate(server_generate: RemoteOpenAIServer):
+    async with server_generate.get_async_client() as async_client:
+        yield async_client
+
+
+def run_llm_encode(llm: vllm.LLM, queries: List[str],
+                   instruction: str) -> List[List[float]]:
+    outputs = llm.encode([instruction + q for q in queries], )
+    return [output.outputs.embedding for output in outputs]
+
+
+async def run_client_embeddings(client: openai.AsyncOpenAI, queries: List[str],
+                                instruction: str) -> List[List[float]]:
+    outputs = await client.embeddings.create(
+        model=MODEL_NAME,
+        input=[instruction + q for q in queries],
+    )
+    return [data.embedding for data in outputs.data]
+
+
+def gritlm_instruction(instruction):
+    return ("<|user|>\n" + instruction +
+            "\n<|embed|>\n" if instruction else "<|embed|>\n")
+
+
+def get_test_data():
+    """
+    Grabbed this test data and the expected values from
+    README.md in https://github.com/ContextualAI/gritlm
+    """
+    q_instruction = gritlm_instruction(
+        "Given a scientific paper title, retrieve the paper's abstract")
+    queries = [
+        "Bitcoin: A Peer-to-Peer Electronic Cash System",
+        "Generative Representational Instruction Tuning",
+    ]
+
+    d_instruction = gritlm_instruction("")
+    documents = [
+        # ruff: noqa: E501
+        "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
+        "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
+    ]
+
+    return queries, q_instruction, documents, d_instruction
+
+
+def validate_embed_output(q_rep: List[List[float]], d_rep: List[List[float]]):
+    cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0])
+    assert math.isclose(cosine_sim_q0_d0, 0.609, abs_tol=0.001)
+
+    cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1])
+    assert math.isclose(cosine_sim_q0_d1, 0.101, abs_tol=0.001)
+
+    cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0])
+    assert math.isclose(cosine_sim_q1_d0, 0.120, abs_tol=0.001)
+
+    cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1])
+    assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001)
+
+
+def test_gritlm_offline_embedding(monkeypatch):
+    # GritLM embedding implementation is only supported by XFormers backend.
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+
+    queries, q_instruction, documents, d_instruction = get_test_data()
+
+    llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN)
+
+    d_rep = run_llm_encode(
+        llm,
+        documents,
+        d_instruction,
+    )
+    q_rep = run_llm_encode(
+        llm,
+        queries,
+        q_instruction,
+    )
+
+    validate_embed_output(q_rep, d_rep)
+
+
+@pytest.mark.asyncio
+async def test_gritlm_api_server_embedding(
+        client_embedding: openai.AsyncOpenAI):
+    queries, q_instruction, documents, d_instruction = get_test_data()
+
+    d_rep = await run_client_embeddings(
+        client_embedding,
+        documents,
+        d_instruction,
+    )
+    q_rep = await run_client_embeddings(
+        client_embedding,
+        queries,
+        q_instruction,
+    )
+
+    validate_embed_output(q_rep, d_rep)
+
+
+def test_gritlm_offline_gen():
+    input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
+
+    llm = vllm.LLM(MODEL_NAME, max_model_len=MAX_MODEL_LEN)
+    sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=256)
+    outputs = llm.generate(input, sampling_params=sampling_params)
+
+    assert outputs[0].outputs[0].text == "The capital of France is Paris."
+
+
+@pytest.mark.asyncio
+async def test_gritlm_api_server_gen(client_generate: openai.AsyncOpenAI):
+    input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
+
+    outputs = await client_generate.completions.create(
+        model=MODEL_NAME,
+        prompt=input,
+        max_tokens=256,
+        temperature=0.0,
+    )
+
+    assert outputs.choices[0].text == "The capital of France is Paris."
diff --git a/tests/models/registry.py b/tests/models/registry.py
index a89518820045f..6a8b1742ceae3 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -135,6 +135,7 @@ class _HfExamplesInfo:
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
+    "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
new file mode 100644
index 0000000000000..ec01a07c16a62
--- /dev/null
+++ b/vllm/model_executor/models/gritlm.py
@@ -0,0 +1,245 @@
+from array import array
+from typing import List, Optional, Union
+
+import torch
+from torch import nn
+from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+
+from vllm.attention import AttentionMetadata
+from vllm.attention.backends.xformers import XFormersImpl
+from vllm.config import ModelConfig, VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.pooling_metadata import (PoolingMetadata,
+                                                  PoolingTensors)
+from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors,
+                           PoolerOutput)
+
+logger = init_logger(__name__)
+
+
+class GritLMPooler(nn.Module):
+
+    def __init__(self, model_config: ModelConfig):
+        super().__init__()
+
+        self.model_config = model_config
+
+        tokenizer = cached_get_tokenizer(
+            self.model_config.tokenizer,
+            tokenizer_mode=self.model_config.tokenizer_mode,
+            tokenizer_revision=self.model_config.tokenizer_revision,
+            trust_remote_code=self.model_config.trust_remote_code,
+        )
+
+        # Collect the tokens needed for pattern matching.
+        # "▁<" is different from "_<". The former uses "▁" to indicate that
+        # the next token is the start of a word.
+        # "<0x0A>" is the newline token (i.e. "\n").
+        self.token_ids = {
+            tok: tokenizer.convert_tokens_to_ids([tok])[0]
+            for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
+        }
+
+        def tokens_to_ids(tokens: list[str]) -> array:
+            return array("i", [self.token_ids[token] for token in tokens])
+
+        self.user_pattern_ids = tokens_to_ids(
+            ["▁<", "|", "user", "|", ">", "<0x0A>"])
+        self.embed_newline_pattern_ids = tokens_to_ids(
+            ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"])
+        self.embed_pattern_ids = tokens_to_ids(
+            ["▁<", "|", "embed", "|", ">", "<0x0A>"])
+
+    def _find_array(self, arr: array, target: array, start_idx: int) -> int:
+        """
+        Find the first occurrence of target in arr starting from start_idx.
+
+        Args:
+            arr: The array to search within
+            target: The consecutive subsequence to find
+            start_idx: The starting index to search from
+
+        Returns:
+            int: Index of the first match of target in arr, or -1 if none.
+        """
+        if start_idx < 0:
+            raise ValueError("start_idx must be non-negative")
+        if not target or not arr:
+            raise ValueError("Empty arr or target not allowed")
+
+        target_len = len(target)
+        for i in range(start_idx, len(arr) - target_len + 1):
+            if arr[i:i + target_len] == target:
+                return i
+        return -1
+
+    def _get_instruction_len(self, prompt_token_ids: array) -> int:
+        """
+        Get the length of the instruction in the prompt.
+
+        We do a pattern matching to find the instruction in the prompt,
+        and then return the length of the instruction.
+
+        The pattern matching is done using integers instead of strings
+        because the prompt is given as a list of token IDs.
+        """
+
+        instruction_len = 0
+
+        # Return no instruction in case of missing BOS token.
+        if prompt_token_ids[0] != self.token_ids["<s>"]:
+            logger.warning("BOS token not found in prompt, "
+                           "thus using empty string for instruction. "
+                           "GritLM requires BOS token in prompt.")
+            return instruction_len
+
+        # If user pattern is found in the prompt, that means there should be
+        # a newline token before the embed pattern.
+        embed_pattern_ids = self.embed_pattern_ids
+        if self._find_array(prompt_token_ids,
+                            self.user_pattern_ids,
+                            start_idx=1) == 1:
+            embed_pattern_ids = self.embed_newline_pattern_ids
+
+        # Find the embed pattern in the prompt.
+        found_embed_pattern_idx = self._find_array(prompt_token_ids,
+                                                   embed_pattern_ids,
+                                                   start_idx=1)
+
+        if found_embed_pattern_idx != -1:
+            instruction_len = found_embed_pattern_idx + len(embed_pattern_ids)
+        else:
+            logger.warning("Query instruction not found in prompt, "
+                           "thus using BOS token as instruction instead. "
+                           "GritLM requires query instruction in prompt.")
+            instruction_len = 1
+
+        return instruction_len
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        """
+        Pool the hidden states by mean-pooling the embeddings of
+        non-instruction tokens, then L2-normalizing the result.
+        """
+        prompts_token_ids = [
+            token_ids.prompt_token_ids_array
+            for _, token_ids in pooling_metadata.seq_data.items()
+        ]
+
+        instruction_lens = torch.tensor(
+            [
+                self._get_instruction_len(prompt_token_ids)
+                for prompt_token_ids in prompts_token_ids
+            ],
+            device=hidden_states.device,
+        )
+
+        prompt_lens = PoolingTensors.from_pooling_metadata(
+            pooling_metadata, hidden_states.device).prompt_lens
+
+        mask = torch.zeros_like(hidden_states, dtype=torch.bool)
+
+        start_idx = 0
+        for prompt_len, instruction_len in zip(prompt_lens, instruction_lens):
+            end_idx = start_idx + prompt_len
+            mask[start_idx + instruction_len:end_idx] = True
+            start_idx = end_idx
+
+        masked_hidden_states = hidden_states.masked_fill(~mask, 0.0)
+
+        sum_embeddings = torch.zeros(len(prompt_lens),
+                                     hidden_states.size(1),
+                                     device=hidden_states.device)
+
+        start_idx = 0
+        for i, prompt_len in enumerate(prompt_lens):
+            end_idx = start_idx + prompt_len
+            sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum(
+                dim=0)
+            start_idx = end_idx
+
+        num_non_instruction_tokens = prompt_lens - instruction_lens
+        mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
+            1)
+
+        pooled_data = nn.functional.normalize(mean_embeddings, p=2, dim=1)
+
+        pooled_outputs = [
+            EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
+        ]
+
+        return PoolerOutput(outputs=pooled_outputs)
+
+
+class GritLM(LlamaForCausalLM):
+    """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
+
+    The class inherits from LlamaForCausalLM and provides a custom pooling
+    layer.
+
+    The main difference between the pooling layer in GritLM and the one in
+    LlamaForCausalLM is that GritLM ignores the query instruction in the prompt
+    when pooling the hidden states.
+
+    Embedding prompts should be in the following format:
+    - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
+    - Without instruction: "<|embed|>\nPROMPT".
+
+    Generation prompts should be in the following format:
+    - "<|user|>\nPROMPT\n<|assistant|>\n"
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        **kwargs,
+    ) -> None:
+        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
+
+        self.task = vllm_config.model_config.task
+
+        self._pooler = GritLMPooler(vllm_config.model_config)
+
+        for layer in self.model.layers:
+            if self.task == "embedding" and hasattr(layer, "self_attn"):
+                assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
+                    "GritLM embedding is only supported by XFormers backend, "
+                    "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        **kwargs,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+
+        # Change attention to non-causal for embedding task.
+        if self.task == "embedding":
+            assert attn_metadata.prefill_metadata.attn_bias is None
+            attn_metadata.prefill_metadata.attn_bias = [
+                BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
+            ]
+
+        return super().forward(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            **kwargs,
+        )
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index e69596aa915b5..64f2f5c5a646e 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -56,6 +56,7 @@
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
     "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
     "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
+    "GritLM": ("gritlm", "GritLM"),
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
@@ -110,6 +111,7 @@
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
+    "GritLM": ("gritlm", "GritLM"),
     "LlamaModel": ("llama", "LlamaForCausalLM"),
     **{
         # Multiple models share the same architecture, so we include them all