From 1d60ff6c0415ff7cc5a718e36ff7ce30529907e3 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Sun, 20 Oct 2024 07:28:31 +0900 Subject: [PATCH 01/18] Add support for embedding model parasail-ai/GritLM-7B-vllm This model is a fork of GritLM/GritLM-7B. The main change in the fork wrt the original repo is the name of the architecture to make vLLM adoption easier. Signed-off-by: Pooya Davoodi --- docs/source/models/supported_models.rst | 5 + .../models/embedding/language/test_gritlm.py | 144 +++++++++++ vllm/core/scheduler.py | 38 ++- vllm/inputs/data.py | 11 + vllm/model_executor/models/gritlm.py | 223 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 6 files changed, 413 insertions(+), 9 deletions(-) create mode 100644 tests/models/embedding/language/test_gritlm.py create mode 100644 vllm/model_executor/models/gritlm.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 9f3b6f59068e2..71b91f51d1362 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -350,6 +350,11 @@ Text Embedding - :code:`BAAI/bge-multilingual-gemma2`, etc. - - ✅︎ + * - :code:`GritLM` + - GritLM + - :code:`parasail-ai/GritLM-7B-vllm`. + - + - * - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc. - Llama-based - :code:`intfloat/e5-mistral-7b-instruct`, etc. diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py new file mode 100644 index 0000000000000..89c02e1a7951e --- /dev/null +++ b/tests/models/embedding/language/test_gritlm.py @@ -0,0 +1,144 @@ +import math +import os +from typing import List + +import openai +import pytest +import pytest_asyncio +from scipy.spatial.distance import cosine + +import vllm + +from ....utils import RemoteOpenAIServer + +MODEL_NAME = "parasail-ai/GritLM-7B-vllm" + +# GritLM implementation is only supported by XFormers backend. +os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", + "embedding", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +def run_llm_encode(llm: vllm.LLM, queries: List[str], instruction: str, + use_instruction_arg: bool) -> List[float]: + pooling_params = vllm.PoolingParams( + additional_data={"instruction_seq": instruction + }) if use_instruction_arg else None + outputs = llm.encode( + [instruction + q for q in queries], + pooling_params=pooling_params, + ) + return [output.outputs.embedding for output in outputs] + + +async def run_client_embeddings(client: vllm.LLM, queries: List[str], + instruction: str, + use_instruction_arg: bool) -> List[float]: + additional_data = { + "instruction_seq": instruction + } if use_instruction_arg else None + outputs = await client.embeddings.create( + model=MODEL_NAME, + input=[instruction + q for q in queries], + extra_body={"additional_data": additional_data}, + ) + return [data.embedding for data in outputs.data] + + +def gritlm_instruction(instruction): + return ("<|user|>\n" + instruction + + "\n<|embed|>\n" if instruction else "<|embed|>\n") + + +def get_test_data(): + """ + Grabbed this test data and the expected values from + README.md in https://github.com/ContextualAI/gritlm + """ + q_instruction = gritlm_instruction( + "Given a scientific paper title, retrieve the paper's abstract") + queries = [ + "Bitcoin: A Peer-to-Peer Electronic Cash System", + "Generative Representational Instruction Tuning", + ] + + d_instruction = gritlm_instruction("") + documents = [ + # ruff: noqa: E501 + "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.", + "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.", + ] + + return queries, q_instruction, documents, d_instruction + + +def validate_output(q_rep: List[float], d_rep: List[float]): + cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0]) + assert math.isclose(cosine_sim_q0_d0, 0.609, abs_tol=0.001) + + cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1]) + assert math.isclose(cosine_sim_q0_d1, 0.101, abs_tol=0.001) + + cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0]) + assert math.isclose(cosine_sim_q1_d0, 0.120, abs_tol=0.001) + + cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1]) + assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001) + + +@pytest.mark.parametrize("use_instruction_arg", [True, False]) +def test_gritlm_offline(use_instruction_arg: bool): + queries, q_instruction, documents, d_instruction = get_test_data() + + llm = vllm.LLM(MODEL_NAME, task="embedding") + + d_rep = run_llm_encode( + llm, + documents, + d_instruction, + use_instruction_arg=use_instruction_arg, + ) + q_rep = run_llm_encode( + llm, + queries, + q_instruction, + use_instruction_arg=use_instruction_arg, + ) + + validate_output(q_rep, d_rep) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_instruction_arg", [True, False]) +async def test_gritlm_api_server(client: openai.AsyncOpenAI, + use_instruction_arg: bool): + queries, q_instruction, documents, d_instruction = get_test_data() + + d_rep = await run_client_embeddings( + client, + documents, + d_instruction, + use_instruction_arg=use_instruction_arg, + ) + q_rep = await run_client_embeddings( + client, + queries, + q_instruction, + use_instruction_arg=use_instruction_arg, + ) + + validate_output(q_rep, d_rep) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d23009dae01ee..1a3d34187ec35 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -12,6 +12,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, @@ -523,7 +524,7 @@ def _schedule_running( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - + Returns: SchedulerRunningOutputs. """ @@ -841,10 +842,10 @@ def _schedule_priority_preemption( self._get_num_new_uncached_and_cached_tokens( seq_group, SequenceStatus.WAITING, False, budget)) - #Only preempt if priority inversion exists + # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): - #Only preempt if waiting sequence cannot be allocated + # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens_uncached > 0 and can_allocate == AllocStatus.OK @@ -854,7 +855,7 @@ def _schedule_priority_preemption( )): break - #Adjust budget to remove the victim sequence group + # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens_uncached, _ = ( self._get_num_new_uncached_and_cached_tokens( @@ -865,11 +866,11 @@ def _schedule_priority_preemption( budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) - #Preempt out the victim sequence group + # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 - #Put the sequence back into the waiting queue + # Put the sequence back into the waiting queue waiting_queue.appendleft(seq_group) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) @@ -1036,7 +1037,7 @@ def _schedule_prefills( def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can @@ -1141,7 +1142,7 @@ def _schedule_default(self) -> SchedulerOutputs: def _schedule_chunked_prefill(self) -> SchedulerOutputs: """Schedule queued requests. - + Chunked prefill allows to chunk prefill requests, batch them together with decode requests. This policy 1. schedule as many decoding requests as possible. 2. schedule chunked prefill requests that are not @@ -1350,6 +1351,25 @@ def schedule( seqs[0].data.get_len()): do_sample = False + pooling_params = seq_group.pooling_params + + # Store instruction_seq in pooling_params. + instruction_seq = seq.inputs.inputs.get("instruction_seq") + if instruction_seq is not None: + if pooling_params is None: + pooling_params = PoolingParams() + pooling_params.additional_data = { + "instruction_seq": instruction_seq + } + elif pooling_params.additional_data is None: + pooling_params.additional_data = { + "instruction_seq": instruction_seq + } + else: + pooling_params.additional_data[ + "instruction_seq"] = seq.inputs.inputs.get( + "instruction_seq") + # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. if is_first_prefill or not self.scheduler_config.send_delta_data: @@ -1360,7 +1380,7 @@ def schedule( sampling_params=seq_group.sampling_params, block_tables=block_tables, do_sample=do_sample, - pooling_params=seq_group.pooling_params, + pooling_params=pooling_params, token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index fb7dbbebd7b90..141aaf27307ae 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -163,6 +163,14 @@ class TokenInputs(TypedDict): to pass the mm_processor_kwargs to each of them. """ + instruction_seq: NotRequired[Optional[str]] + """ + The instruction sequence that is usually prepended to the original prompt + when passing to the model. Certain models need to extract this instruction + sequence from the prompt in order to adjust certain operations of the + model such as the attention mask. + """ + def token_inputs( prompt_token_ids: List[int], @@ -171,6 +179,7 @@ def token_inputs( multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, + instruction_seq: Optional[str] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) @@ -185,6 +194,8 @@ def token_inputs( inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: inputs["mm_processor_kwargs"] = mm_processor_kwargs + if instruction_seq is not None: + inputs["instruction_seq"] = instruction_seq return inputs diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py new file mode 100644 index 0000000000000..97d99c9057241 --- /dev/null +++ b/vllm/model_executor/models/gritlm.py @@ -0,0 +1,223 @@ +import re +from typing import List, Optional, Union + +import torch +from torch import nn +from xformers.ops.fmha.attn_bias import BlockDiagonalMask + +from vllm.attention import AttentionMetadata +from vllm.attention.backends.xformers import XFormersImpl +from vllm.config import ModelConfig, VllmConfig +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) +from vllm.logger import init_logger +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.pooling_params import PoolingParams +from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors, + PoolerOutput) + +logger = init_logger(__name__) + + +class GritLMPooler(nn.Module): + + def __init__( + self, + model_config: ModelConfig, + ): + super().__init__() + + self.model_config = model_config + + def _get_instruction_lens( + self, device: torch.device, + pooling_metadata: PoolingMetadata) -> torch.Tensor: + """ + Compute the number of tokens of each instruction using the tokenizer. + """ + self.tokenizer = cached_get_tokenizer( + self.model_config.tokenizer, + tokenizer_mode=self.model_config.tokenizer_mode, + tokenizer_revision=self.model_config.tokenizer_revision, + trust_remote_code=self.model_config.trust_remote_code, + truncation_side="left", + ) + + def query_instruction_missing(pooling_params: PoolingParams) -> bool: + return (pooling_params is None + or pooling_params.additional_data is None + or "instruction_seq" not in pooling_params.additional_data) + + for seq_group in pooling_metadata.seq_groups: + if query_instruction_missing(seq_group[1]): + logger.warning( + "Query instruction not found in prompt," + "thus using empty string instead. GritLM requires " + "query instruction in prompt.") + + instruction_lens = torch.tensor( + [ + len( + self.tokenizer( + ("" if query_instruction_missing(seq_group[1]) else + seq_group[1].additional_data["instruction_seq"]), + padding=False, + truncation=True, + add_special_tokens=True, + )["input_ids"]) + for seq_group in pooling_metadata.seq_groups + ], + device=device, + ) + + return instruction_lens + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """ + Pool the hidden states by summing the embeddings of + non-instruction tokens. + """ + instruction_lens = self._get_instruction_lens( + device=hidden_states.device, pooling_metadata=pooling_metadata) + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + mask = torch.zeros_like(hidden_states, dtype=torch.bool) + + start_idx = 0 + for prompt_len, instruction_len in zip(prompt_lens, instruction_lens): + end_idx = start_idx + prompt_len + mask[start_idx + instruction_len:end_idx] = True + start_idx = end_idx + + masked_hidden_states = hidden_states.masked_fill(~mask, 0.0) + + sum_embeddings = torch.zeros(len(prompt_lens), + hidden_states.size(1), + device=hidden_states.device) + + start_idx = 0 + for i, prompt_len in enumerate(prompt_lens): + end_idx = start_idx + prompt_len + sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum( + dim=0) + start_idx = end_idx + + num_non_instruction_tokens = prompt_lens - instruction_lens + mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( + 1) + + pooled_data = nn.functional.normalize(mean_embeddings, p=2, dim=1) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data + ] + + return PoolerOutput(outputs=pooled_outputs) + + +def input_processor_for_gritlm(ctx: InputContext, inputs: DecoderOnlyInputs): + """ + Extracts query instruction from prompt and adds it to token inputs. + """ + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + prompt = inputs.get("prompt", None) + instruction = "" + + if prompt is None and "prompt_token_ids" in inputs: + prompt = tokenizer.decode(inputs["prompt_token_ids"]) + + if prompt is not None: + match_instruction = re.match(r"( )?(<\|user\|>\n.*\n<\|embed\|>\n)", + prompt) + match_empty_instruction = re.match(r"( )?(<\|embed\|>\n)", prompt) + + if match_instruction and match_instruction.group(2): + instruction = match_instruction.group(2) + elif match_empty_instruction: + instruction = match_empty_instruction.group(2) + else: + logger.warning("Query instruction not found in prompt," + "thus using empty string instead. GritLM requires " + "query instruction in prompt.") + + return token_inputs( + prompt_token_ids=inputs["prompt_token_ids"], + prompt=prompt, + instruction_seq=instruction, + ) + + +@INPUT_REGISTRY.register_input_processor(input_processor_for_gritlm) +class GritLM(LlamaForCausalLM): + """This class implements the embedding model for parasail-ai/GritLM-7B-vllm. + + The class inherits from LlamaForCausalLM and provides a custom pooling + layer. + + The task "embedding" must be specified in the server arguments. + + The main difference between the pooling layer in GritLM and the one in + LlamaForCausalLM is that GritLM ignores the query instruction in the prompt + when pooling the hidden states. + + Instructions can be passed to the model in two ways: + 1. By prepending the instruction to the prompt. The instruction should be + in the format "<|user|>\n\n<|embed|>\n". + 2. By passing the instruction as additional data in the pooling parameters + (e.g. extra_body of client.embeddings.create). + """ + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + self._pooler = GritLMPooler(model_config=vllm_config.model_config) + + assert isinstance( + self.model.layers[0].self_attn.attn.impl, + XFormersImpl), "GritLM is only supported by XFormers backend, " + "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS" + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + # Change attention to non-causal. + assert attn_metadata.prefill_metadata.attn_bias is None + attn_metadata.prefill_metadata.attn_bias = [ + BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) + ] + + return super().forward( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + **kwargs, + ) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 2b7b69e8c3a95..1a0aec7b3188e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -110,6 +110,7 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), + "GritLM": ("gritlm", "GritLM"), "LlamaModel": ("llama", "LlamaForCausalLM"), **{ # Multiple models share the same architecture, so we include them all From 48f7947559c6a058ab526bfc8cebee50e44d709a Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 6 Dec 2024 02:17:17 +0900 Subject: [PATCH 02/18] Replace input processor by pattern matching on prompt_token_ids inside the pooler Signed-off-by: Pooya Davoodi --- .../models/embedding/language/test_gritlm.py | 51 ++--- vllm/model_executor/models/gritlm.py | 203 ++++++++++-------- 2 files changed, 143 insertions(+), 111 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index 89c02e1a7951e..a10db2bd07775 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -1,5 +1,6 @@ import math import os +from array import array from typing import List import openai @@ -8,6 +9,7 @@ from scipy.spatial.distance import cosine import vllm +from vllm.model_executor.models.gritlm import GritLMPooler from ....utils import RemoteOpenAIServer @@ -17,6 +19,25 @@ os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" +def _arr(arr): + """ + Convert a list of integers to an array of integers. + """ + return array("i", arr) + + +def test_find_list(): + arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + + assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=0) == 3 + assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=1) == 3 + assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=5) == -1 + assert GritLMPooler._find_list(arr, _arr([3, 5]), start_idx=0) == -1 + + with pytest.raises(ValueError): + GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=-1) + + @pytest.fixture(scope="module") def server(): args = [ @@ -33,28 +54,17 @@ async def client(server): yield async_client -def run_llm_encode(llm: vllm.LLM, queries: List[str], instruction: str, - use_instruction_arg: bool) -> List[float]: - pooling_params = vllm.PoolingParams( - additional_data={"instruction_seq": instruction - }) if use_instruction_arg else None - outputs = llm.encode( - [instruction + q for q in queries], - pooling_params=pooling_params, - ) +def run_llm_encode(llm: vllm.LLM, queries: List[str], + instruction: str) -> List[float]: + outputs = llm.encode([instruction + q for q in queries], ) return [output.outputs.embedding for output in outputs] async def run_client_embeddings(client: vllm.LLM, queries: List[str], - instruction: str, - use_instruction_arg: bool) -> List[float]: - additional_data = { - "instruction_seq": instruction - } if use_instruction_arg else None + instruction: str) -> List[float]: outputs = await client.embeddings.create( model=MODEL_NAME, input=[instruction + q for q in queries], - extra_body={"additional_data": additional_data}, ) return [data.embedding for data in outputs.data] @@ -100,8 +110,7 @@ def validate_output(q_rep: List[float], d_rep: List[float]): assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001) -@pytest.mark.parametrize("use_instruction_arg", [True, False]) -def test_gritlm_offline(use_instruction_arg: bool): +def test_gritlm_offline(): queries, q_instruction, documents, d_instruction = get_test_data() llm = vllm.LLM(MODEL_NAME, task="embedding") @@ -110,35 +119,29 @@ def test_gritlm_offline(use_instruction_arg: bool): llm, documents, d_instruction, - use_instruction_arg=use_instruction_arg, ) q_rep = run_llm_encode( llm, queries, q_instruction, - use_instruction_arg=use_instruction_arg, ) validate_output(q_rep, d_rep) @pytest.mark.asyncio -@pytest.mark.parametrize("use_instruction_arg", [True, False]) -async def test_gritlm_api_server(client: openai.AsyncOpenAI, - use_instruction_arg: bool): +async def test_gritlm_api_server(client: openai.AsyncOpenAI): queries, q_instruction, documents, d_instruction = get_test_data() d_rep = await run_client_embeddings( client, documents, d_instruction, - use_instruction_arg=use_instruction_arg, ) q_rep = await run_client_embeddings( client, queries, q_instruction, - use_instruction_arg=use_instruction_arg, ) validate_output(q_rep, d_rep) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 97d99c9057241..a5b5ca215434a 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -1,4 +1,4 @@ -import re +from array import array from typing import List, Optional, Union import torch @@ -8,14 +8,11 @@ from vllm.attention import AttentionMetadata from vllm.attention.backends.xformers import XFormersImpl from vllm.config import ModelConfig, VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) from vllm.logger import init_logger from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) from vllm.multimodal.utils import cached_get_tokenizer -from vllm.pooling_params import PoolingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors, PoolerOutput) @@ -24,56 +21,111 @@ class GritLMPooler(nn.Module): - def __init__( - self, - model_config: ModelConfig, - ): + def __init__(self, model_config: ModelConfig): super().__init__() self.model_config = model_config - def _get_instruction_lens( - self, device: torch.device, - pooling_metadata: PoolingMetadata) -> torch.Tensor: - """ - Compute the number of tokens of each instruction using the tokenizer. - """ - self.tokenizer = cached_get_tokenizer( + tokenizer = cached_get_tokenizer( self.model_config.tokenizer, tokenizer_mode=self.model_config.tokenizer_mode, tokenizer_revision=self.model_config.tokenizer_revision, trust_remote_code=self.model_config.trust_remote_code, - truncation_side="left", ) - def query_instruction_missing(pooling_params: PoolingParams) -> bool: - return (pooling_params is None - or pooling_params.additional_data is None - or "instruction_seq" not in pooling_params.additional_data) + # Collect the tokens needed for pattern matching. + self.token_ids = { + tok: tokenizer.convert_tokens_to_ids([tok])[0] + for tok in ["", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"] + } - for seq_group in pooling_metadata.seq_groups: - if query_instruction_missing(seq_group[1]): - logger.warning( - "Query instruction not found in prompt," - "thus using empty string instead. GritLM requires " - "query instruction in prompt.") + @staticmethod + def _find_list(arr: array, target: array, start_idx: int) -> int: + """ + Find the first starting index where the search_list appears + as a consecutive subsequence in main_list. - instruction_lens = torch.tensor( - [ - len( - self.tokenizer( - ("" if query_instruction_missing(seq_group[1]) else - seq_group[1].additional_data["instruction_seq"]), - padding=False, - truncation=True, - add_special_tokens=True, - )["input_ids"]) - for seq_group in pooling_metadata.seq_groups - ], - device=device, - ) + Args: + arr: The array to search within + target: The consecutive subsequence to find + start_idx: The starting index to search from + + Returns: + int: The index of the first occurrence of target in arr. + """ + if start_idx < 0: + raise ValueError("start_idx must be non-negative") + + found_index = -1 + + # Handle edge cases + if not target or not arr: + return found_index + + # Length of lists + arr_len = len(arr) + target_len = len(target) + + # Iterate through possible starting positions + for i in range(start_idx, arr_len - target_len + 1): + # Check if the subsequence matches + if arr[i:i + target_len] == target: + found_index = i + break + + return found_index + + def _get_instruction_len(self, prompt_token_ids: array) -> bool: + """ + Get the length of the instruction in the prompt. + + We do a pattern matching to find the instruction in the prompt, + and then return the length of the instruction. + + The pattern matching is done using integers instead of strings + because the prompt is given as a list of token IDs. + """ + + def tokens_to_ids(tokens: list[str]) -> List[int]: + return array("i", [self.token_ids[token] for token in tokens]) + + instruction_len = 0 + + found_bos_token = prompt_token_ids[0] == self.token_ids[""] - return instruction_lens + # Return no instruction in case of missing BOS token. + if not found_bos_token: + logger.warning("BOS token not found in prompt," + "thus using empty string for instruction." + "GritLM requires BOS token in prompt.") + return instruction_len + + # Find the user pattern in the prompt. + user_token_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"]) + found_user_pattern = (__class__._find_list(prompt_token_ids, + user_token_ids, + start_idx=1) == 1) + + # Find the embed pattern in the prompt. + if found_user_pattern: + embed_token_ids = tokens_to_ids( + ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) + else: + embed_token_ids = tokens_to_ids( + ["▁<", "|", "embed", "|", ">", "<0x0A>"]) + found_embed_pattern_idx = __class__._find_list(prompt_token_ids, + embed_token_ids, + start_idx=1) + + if found_embed_pattern_idx != -1: + instruction_len = found_embed_pattern_idx + len(embed_token_ids) + else: + logger.warning("Query instruction not found in prompt," + "thus using BOS token as instruction instead." + "GritLM requires query instruction in prompt.") + instruction_len = 1 + + return instruction_len def forward( self, @@ -84,8 +136,18 @@ def forward( Pool the hidden states by summing the embeddings of non-instruction tokens. """ - instruction_lens = self._get_instruction_lens( - device=hidden_states.device, pooling_metadata=pooling_metadata) + prompts_token_ids = [ + token_ids.prompt_token_ids_array + for _, token_ids in pooling_metadata.seq_data.items() + ] + + instruction_lens = torch.tensor( + [ + self._get_instruction_len(prompt_token_ids) + for prompt_token_ids in prompts_token_ids + ], + device=hidden_states.device, + ) prompt_lens = PoolingTensors.from_pooling_metadata( pooling_metadata, hidden_states.device).prompt_lens @@ -124,41 +186,6 @@ def forward( return PoolerOutput(outputs=pooled_outputs) -def input_processor_for_gritlm(ctx: InputContext, inputs: DecoderOnlyInputs): - """ - Extracts query instruction from prompt and adds it to token inputs. - """ - model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer) - - prompt = inputs.get("prompt", None) - instruction = "" - - if prompt is None and "prompt_token_ids" in inputs: - prompt = tokenizer.decode(inputs["prompt_token_ids"]) - - if prompt is not None: - match_instruction = re.match(r"( )?(<\|user\|>\n.*\n<\|embed\|>\n)", - prompt) - match_empty_instruction = re.match(r"( )?(<\|embed\|>\n)", prompt) - - if match_instruction and match_instruction.group(2): - instruction = match_instruction.group(2) - elif match_empty_instruction: - instruction = match_empty_instruction.group(2) - else: - logger.warning("Query instruction not found in prompt," - "thus using empty string instead. GritLM requires " - "query instruction in prompt.") - - return token_inputs( - prompt_token_ids=inputs["prompt_token_ids"], - prompt=prompt, - instruction_seq=instruction, - ) - - -@INPUT_REGISTRY.register_input_processor(input_processor_for_gritlm) class GritLM(LlamaForCausalLM): """This class implements the embedding model for parasail-ai/GritLM-7B-vllm. @@ -171,11 +198,9 @@ class GritLM(LlamaForCausalLM): LlamaForCausalLM is that GritLM ignores the query instruction in the prompt when pooling the hidden states. - Instructions can be passed to the model in two ways: - 1. By prepending the instruction to the prompt. The instruction should be - in the format "<|user|>\n\n<|embed|>\n". - 2. By passing the instruction as additional data in the pooling parameters - (e.g. extra_body of client.embeddings.create). + Prompt must be in the following format: + - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT". + - Without instruction: "<|embed|>\nPROMPT". """ def __init__( @@ -186,12 +211,16 @@ def __init__( ) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) - self._pooler = GritLMPooler(model_config=vllm_config.model_config) + if vllm_config.model_config.task != "embedding": + raise ValueError(f"Task must be 'embedding' for GritLM, but got " + f"'{vllm_config.model_config.task}'") + + self._pooler = GritLMPooler(vllm_config.model_config) assert isinstance( - self.model.layers[0].self_attn.attn.impl, - XFormersImpl), "GritLM is only supported by XFormers backend, " - "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS" + self.model.layers[0].self_attn.attn.impl, XFormersImpl), ( + "GritLM is only supported by XFormers backend, " + "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS") def forward( self, From 1aec4d37acdc6cfdde6c61a25ba2dce9c1563c87 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 6 Dec 2024 07:25:42 +0900 Subject: [PATCH 03/18] Revert changes in data.py and scheduler.py Signed-off-by: Pooya Davoodi --- vllm/core/scheduler.py | 38 +++++++++----------------------------- vllm/inputs/data.py | 11 ----------- 2 files changed, 9 insertions(+), 40 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 1a3d34187ec35..d23009dae01ee 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -12,7 +12,6 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, @@ -524,7 +523,7 @@ def _schedule_running( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - + Returns: SchedulerRunningOutputs. """ @@ -842,10 +841,10 @@ def _schedule_priority_preemption( self._get_num_new_uncached_and_cached_tokens( seq_group, SequenceStatus.WAITING, False, budget)) - # Only preempt if priority inversion exists + #Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): - # Only preempt if waiting sequence cannot be allocated + #Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens_uncached > 0 and can_allocate == AllocStatus.OK @@ -855,7 +854,7 @@ def _schedule_priority_preemption( )): break - # Adjust budget to remove the victim sequence group + #Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens_uncached, _ = ( self._get_num_new_uncached_and_cached_tokens( @@ -866,11 +865,11 @@ def _schedule_priority_preemption( budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) - # Preempt out the victim sequence group + #Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 - # Put the sequence back into the waiting queue + #Put the sequence back into the waiting queue waiting_queue.appendleft(seq_group) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) @@ -1037,7 +1036,7 @@ def _schedule_prefills( def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can @@ -1142,7 +1141,7 @@ def _schedule_default(self) -> SchedulerOutputs: def _schedule_chunked_prefill(self) -> SchedulerOutputs: """Schedule queued requests. - + Chunked prefill allows to chunk prefill requests, batch them together with decode requests. This policy 1. schedule as many decoding requests as possible. 2. schedule chunked prefill requests that are not @@ -1351,25 +1350,6 @@ def schedule( seqs[0].data.get_len()): do_sample = False - pooling_params = seq_group.pooling_params - - # Store instruction_seq in pooling_params. - instruction_seq = seq.inputs.inputs.get("instruction_seq") - if instruction_seq is not None: - if pooling_params is None: - pooling_params = PoolingParams() - pooling_params.additional_data = { - "instruction_seq": instruction_seq - } - elif pooling_params.additional_data is None: - pooling_params.additional_data = { - "instruction_seq": instruction_seq - } - else: - pooling_params.additional_data[ - "instruction_seq"] = seq.inputs.inputs.get( - "instruction_seq") - # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. if is_first_prefill or not self.scheduler_config.send_delta_data: @@ -1380,7 +1360,7 @@ def schedule( sampling_params=seq_group.sampling_params, block_tables=block_tables, do_sample=do_sample, - pooling_params=pooling_params, + pooling_params=seq_group.pooling_params, token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index aebfffe049df3..85aaaa776907f 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -170,14 +170,6 @@ class TokenInputs(TypedDict): to pass the mm_processor_kwargs to each of them. """ - instruction_seq: NotRequired[Optional[str]] - """ - The instruction sequence that is usually prepended to the original prompt - when passing to the model. Certain models need to extract this instruction - sequence from the prompt in order to adjust certain operations of the - model such as the attention mask. - """ - def token_inputs( prompt_token_ids: List[int], @@ -187,7 +179,6 @@ def token_inputs( multi_modal_inputs: Optional["MultiModalKwargs"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, - instruction_seq: Optional[str] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) @@ -204,8 +195,6 @@ def token_inputs( inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: inputs["mm_processor_kwargs"] = mm_processor_kwargs - if instruction_seq is not None: - inputs["instruction_seq"] = instruction_seq return inputs From 4941376af80c0f03e108c7e4f4137a149aaf7b05 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 6 Dec 2024 08:22:57 +0900 Subject: [PATCH 04/18] Improve _find_list Signed-off-by: Pooya Davoodi --- .../models/embedding/language/test_gritlm.py | 12 +++--- vllm/model_executor/models/gritlm.py | 38 +++++++------------ 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index a10db2bd07775..22574069b4c27 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -26,16 +26,16 @@ def _arr(arr): return array("i", arr) -def test_find_list(): +def test_find_array(): arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=0) == 3 - assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=1) == 3 - assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=5) == -1 - assert GritLMPooler._find_list(arr, _arr([3, 5]), start_idx=0) == -1 + assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 + assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 + assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 + assert GritLMPooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 with pytest.raises(ValueError): - GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=-1) + GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) @pytest.fixture(scope="module") diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index a5b5ca215434a..16a728295fbaa 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -40,10 +40,9 @@ def __init__(self, model_config: ModelConfig): } @staticmethod - def _find_list(arr: array, target: array, start_idx: int) -> int: + def _find_array(arr: array, target: array, start_idx: int) -> int: """ - Find the first starting index where the search_list appears - as a consecutive subsequence in main_list. + Find the first occurrence of target in arr starting from start_idx. Args: arr: The array to search within @@ -55,25 +54,14 @@ def _find_list(arr: array, target: array, start_idx: int) -> int: """ if start_idx < 0: raise ValueError("start_idx must be non-negative") - - found_index = -1 - - # Handle edge cases if not target or not arr: - return found_index + raise ValueError("Empty arr or target not allowed") - # Length of lists - arr_len = len(arr) target_len = len(target) - - # Iterate through possible starting positions - for i in range(start_idx, arr_len - target_len + 1): - # Check if the subsequence matches + for i in range(start_idx, len(arr) - target_len + 1): if arr[i:i + target_len] == target: - found_index = i - break - - return found_index + return i + return -1 def _get_instruction_len(self, prompt_token_ids: array) -> bool: """ @@ -102,20 +90,22 @@ def tokens_to_ids(tokens: list[str]) -> List[int]: # Find the user pattern in the prompt. user_token_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"]) - found_user_pattern = (__class__._find_list(prompt_token_ids, - user_token_ids, - start_idx=1) == 1) + found_user_pattern = (__class__._find_array(prompt_token_ids, + user_token_ids, + start_idx=1) == 1) # Find the embed pattern in the prompt. if found_user_pattern: + # If user pattern is found, that means there should be + # a newline token before the embed pattern. embed_token_ids = tokens_to_ids( ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) else: embed_token_ids = tokens_to_ids( ["▁<", "|", "embed", "|", ">", "<0x0A>"]) - found_embed_pattern_idx = __class__._find_list(prompt_token_ids, - embed_token_ids, - start_idx=1) + found_embed_pattern_idx = __class__._find_array(prompt_token_ids, + embed_token_ids, + start_idx=1) if found_embed_pattern_idx != -1: instruction_len = found_embed_pattern_idx + len(embed_token_ids) From 5f32d7cdc750279a290ca5b52a9580c7a013ed68 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Sat, 7 Dec 2024 05:24:39 +0900 Subject: [PATCH 05/18] Fix BOS check, move patterns to constructor Signed-off-by: Pooya Davoodi --- .../models/embedding/language/test_gritlm.py | 14 +++-- vllm/model_executor/models/gritlm.py | 51 ++++++++++--------- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index 22574069b4c27..b5ae5ee992765 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -27,15 +27,19 @@ def _arr(arr): def test_find_array(): + # Create an LLM object to get the model config. + llm = vllm.LLM(MODEL_NAME, task="embedding") + pooler = GritLMPooler(model_config=llm.llm_engine.model_config) + arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 - assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 - assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 - assert GritLMPooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 + assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 + assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 + assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 + assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 with pytest.raises(ValueError): - GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) + pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) @pytest.fixture(scope="module") diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 16a728295fbaa..55bfad2abea03 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -34,13 +34,25 @@ def __init__(self, model_config: ModelConfig): ) # Collect the tokens needed for pattern matching. + # "▁<" is different from "_<". The former uses "▁" to indicate that + # the next token is the start of a word. + # "<0x0A>" is the newline token (i.e. "\n")." self.token_ids = { tok: tokenizer.convert_tokens_to_ids([tok])[0] for tok in ["", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"] } - @staticmethod - def _find_array(arr: array, target: array, start_idx: int) -> int: + def tokens_to_ids(tokens: list[str]) -> array: + return array("i", [self.token_ids[token] for token in tokens]) + + self.user_pattern_ids = tokens_to_ids( + ["▁<", "|", "user", "|", ">", "<0x0A>"]) + self.embed_newline_pattern_ids = tokens_to_ids( + ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) + self.embed_pattern_ids = tokens_to_ids( + ["▁<", "|", "embed", "|", ">", "<0x0A>"]) + + def _find_array(self, arr: array, target: array, start_idx: int) -> int: """ Find the first occurrence of target in arr starting from start_idx. @@ -74,41 +86,30 @@ def _get_instruction_len(self, prompt_token_ids: array) -> bool: because the prompt is given as a list of token IDs. """ - def tokens_to_ids(tokens: list[str]) -> List[int]: - return array("i", [self.token_ids[token] for token in tokens]) - instruction_len = 0 - found_bos_token = prompt_token_ids[0] == self.token_ids[""] - # Return no instruction in case of missing BOS token. - if not found_bos_token: + if prompt_token_ids[0] != self.token_ids[""]: logger.warning("BOS token not found in prompt," "thus using empty string for instruction." "GritLM requires BOS token in prompt.") return instruction_len - # Find the user pattern in the prompt. - user_token_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"]) - found_user_pattern = (__class__._find_array(prompt_token_ids, - user_token_ids, - start_idx=1) == 1) + # If user pattern is found in the prompt, that means there should be + # a newline token before the embed pattern. + embed_pattern_ids = self.embed_pattern_ids + if self._find_array(prompt_token_ids, + self.user_pattern_ids, + start_idx=1) == 1: + embed_pattern_ids = self.embed_newline_pattern_ids # Find the embed pattern in the prompt. - if found_user_pattern: - # If user pattern is found, that means there should be - # a newline token before the embed pattern. - embed_token_ids = tokens_to_ids( - ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) - else: - embed_token_ids = tokens_to_ids( - ["▁<", "|", "embed", "|", ">", "<0x0A>"]) - found_embed_pattern_idx = __class__._find_array(prompt_token_ids, - embed_token_ids, - start_idx=1) + found_embed_pattern_idx = self._find_array(prompt_token_ids, + embed_pattern_ids, + start_idx=1) if found_embed_pattern_idx != -1: - instruction_len = found_embed_pattern_idx + len(embed_token_ids) + instruction_len = found_embed_pattern_idx + len(embed_pattern_ids) else: logger.warning("Query instruction not found in prompt," "thus using BOS token as instruction instead." From 070483ebb49c36d06725f85652aeb452a85d1f17 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 10 Dec 2024 04:44:50 +0900 Subject: [PATCH 06/18] Add PP and LORA support to docs Signed-off-by: Pooya Davoodi --- docs/source/models/supported_models.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 71b91f51d1362..dbc659093ea10 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -353,8 +353,8 @@ Text Embedding * - :code:`GritLM` - GritLM - :code:`parasail-ai/GritLM-7B-vllm`. - - - - + - ✅︎ + - ✅︎ * - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc. - Llama-based - :code:`intfloat/e5-mistral-7b-instruct`, etc. From e84c2a45f399be1f6c2bf559f47dee88f761e57f Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 10 Dec 2024 05:20:56 +0900 Subject: [PATCH 07/18] Improve xformers check In case of PP, some layers don't have self_attn attribute Signed-off-by: Pooya Davoodi --- vllm/model_executor/models/gritlm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 55bfad2abea03..a57442c8782cf 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -208,10 +208,12 @@ def __init__( self._pooler = GritLMPooler(vllm_config.model_config) - assert isinstance( - self.model.layers[0].self_attn.attn.impl, XFormersImpl), ( - "GritLM is only supported by XFormers backend, " - "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS") + for layer in self.model.layers: + if hasattr(layer, "self_attn"): + assert isinstance(layer.self_attn.attn.impl, XFormersImpl), ( + "GritLM is only supported by XFormers backend, " + "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS" + ) def forward( self, From 7a0652716a16c1b5033cb0294f18e51fe2e41556 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 10 Dec 2024 05:25:43 +0900 Subject: [PATCH 08/18] Fix format Signed-off-by: Pooya Davoodi --- vllm/model_executor/models/gritlm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index a57442c8782cf..28b550efb7049 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -212,8 +212,7 @@ def __init__( if hasattr(layer, "self_attn"): assert isinstance(layer.self_attn.attn.impl, XFormersImpl), ( "GritLM is only supported by XFormers backend, " - "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS" - ) + "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS") def forward( self, From 92557866cc018e22519b9e942cf3b238771908bb Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Tue, 10 Dec 2024 13:52:23 +0900 Subject: [PATCH 09/18] Add GritLM to tests/models/registry.py Signed-off-by: Pooya Davoodi --- tests/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index a89518820045f..6a8b1742ceae3 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -135,6 +135,7 @@ class _HfExamplesInfo: # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), + "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), From dddf3852269eb21bfcba7a0e8a47a04375e263ae Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 11 Dec 2024 04:03:30 +0900 Subject: [PATCH 10/18] Support generate task Signed-off-by: Pooya Davoodi --- .../models/embedding/language/test_gritlm.py | 93 ++++++++++++++++--- vllm/model_executor/models/gritlm.py | 16 ++-- 2 files changed, 87 insertions(+), 22 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index b5ae5ee992765..eb0360e341b37 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -9,6 +9,7 @@ from scipy.spatial.distance import cosine import vllm +import vllm.config from vllm.model_executor.models.gritlm import GritLMPooler from ....utils import RemoteOpenAIServer @@ -43,18 +44,28 @@ def test_find_array(): @pytest.fixture(scope="module") -def server(): - args = [ - "--task", - "embedding", - ] +def server_embedding(): + args = ["--task", "embedding"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server +@pytest.fixture(scope="module") +def server_generate(): + args = ["--task", "generate"] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_embedding(server_embedding: RemoteOpenAIServer): + async with server_embedding.get_async_client() as async_client: + yield async_client + + @pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: +async def client_generate(server_generate: RemoteOpenAIServer): + async with server_generate.get_async_client() as async_client: yield async_client @@ -100,7 +111,7 @@ def get_test_data(): return queries, q_instruction, documents, d_instruction -def validate_output(q_rep: List[float], d_rep: List[float]): +def validate_embed_output(q_rep: List[float], d_rep: List[float]): cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0]) assert math.isclose(cosine_sim_q0_d0, 0.609, abs_tol=0.001) @@ -114,7 +125,7 @@ def validate_output(q_rep: List[float], d_rep: List[float]): assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001) -def test_gritlm_offline(): +def test_gritlm_offline_embedding(): queries, q_instruction, documents, d_instruction = get_test_data() llm = vllm.LLM(MODEL_NAME, task="embedding") @@ -130,22 +141,76 @@ def test_gritlm_offline(): q_instruction, ) - validate_output(q_rep, d_rep) + validate_embed_output(q_rep, d_rep) @pytest.mark.asyncio -async def test_gritlm_api_server(client: openai.AsyncOpenAI): +async def test_gritlm_api_server_embedding( + client_embedding: openai.AsyncOpenAI): queries, q_instruction, documents, d_instruction = get_test_data() d_rep = await run_client_embeddings( - client, + client_embedding, documents, d_instruction, ) q_rep = await run_client_embeddings( - client, + client_embedding, queries, q_instruction, ) - validate_output(q_rep, d_rep) + validate_embed_output(q_rep, d_rep) + + +def validate_gen_output(output: str): + expected_output = """Oh, Mt. Fuji, mountain grand, +A sight to see, a climb to command, +At midnight, in the dark of night, +I climbed your slopes, with all my might. + +The stars above, they shone so bright, +A beacon in the darkness, guiding light, +The wind did blow, with a gentle sigh, +As I climbed higher, with a steady eye. + +The path was steep, the climb was tough, +But I pressed on, with a steadfast rough, +For the summit, I longed to see, +The view from the top, a sight to be. + +At last, I reached the peak, and stood, +With awe and wonder, I gazed aloud, +The world below, a sight to see, +A view that's worth the climb, you'll agree. + +Mt. Fuji, mountain grand, +A sight to see, a climb to command, +At midnight, in the dark of night, +I climbed your slopes, with all my might.""" + + assert output == expected_output + + +def test_gritlm_offline_gen(): + input = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n" + + llm = vllm.LLM(MODEL_NAME) + sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=256) + outputs = llm.generate(input, sampling_params=sampling_params) + + validate_gen_output(outputs[0].outputs[0].text) + + +@pytest.mark.asyncio +async def test_gritlm_api_server_gen(client_generate: openai.AsyncOpenAI): + input = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n" + + outputs = await client_generate.completions.create( + model=MODEL_NAME, + prompt=input, + max_tokens=256, + temperature=0.0, + ) + + validate_gen_output(outputs.choices[0].text) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 28b550efb7049..f9ebdb00f99ee 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -202,9 +202,7 @@ def __init__( ) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) - if vllm_config.model_config.task != "embedding": - raise ValueError(f"Task must be 'embedding' for GritLM, but got " - f"'{vllm_config.model_config.task}'") + self.task = vllm_config.model_config.task self._pooler = GritLMPooler(vllm_config.model_config) @@ -222,11 +220,13 @@ def forward( attn_metadata: AttentionMetadata, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - # Change attention to non-causal. - assert attn_metadata.prefill_metadata.attn_bias is None - attn_metadata.prefill_metadata.attn_bias = [ - BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) - ] + + # Change attention to non-causal for embedding task. + if self.task == "embedding": + assert attn_metadata.prefill_metadata.attn_bias is None + attn_metadata.prefill_metadata.attn_bias = [ + BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) + ] return super().forward( input_ids=input_ids, From f7fdd77c6f5d87e0ef895bb4f76f00b7dc3adb50 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 11 Dec 2024 04:48:01 +0900 Subject: [PATCH 11/18] Skip tests if xformers is not available Signed-off-by: Pooya Davoodi --- tests/models/embedding/language/test_gritlm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index eb0360e341b37..eed04d4b84bc5 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -1,3 +1,4 @@ +import importlib.util import math import os from array import array @@ -14,11 +15,13 @@ from ....utils import RemoteOpenAIServer -MODEL_NAME = "parasail-ai/GritLM-7B-vllm" - # GritLM implementation is only supported by XFormers backend. +pytest.mark.skipif(not importlib.util.find_spec("xformers"), + reason="GritLM requires XFormers") os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" +MODEL_NAME = "parasail-ai/GritLM-7B-vllm" + def _arr(arr): """ From 8d2c4d3f5fccf14242dfdab26f8036bccc583bba Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 11 Dec 2024 05:01:53 +0900 Subject: [PATCH 12/18] Add GritLM to _TEXT_GENERATION_MODELS Signed-off-by: Pooya Davoodi --- vllm/model_executor/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index e8114717a50b7..64f2f5c5a646e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -56,6 +56,7 @@ "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), + "GritLM": ("gritlm", "GritLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"), From 9579e372014bcbab74c4deda422909e30303b9d3 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 11 Dec 2024 05:02:04 +0900 Subject: [PATCH 13/18] Reduce context length in tests And make test prompt shorter Signed-off-by: Pooya Davoodi --- .../models/embedding/language/test_gritlm.py | 48 ++++--------------- 1 file changed, 10 insertions(+), 38 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index eed04d4b84bc5..23b5f31e328e7 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -22,6 +22,7 @@ MODEL_NAME = "parasail-ai/GritLM-7B-vllm" +MAX_MODEL_LEN = 4000 def _arr(arr): """ @@ -32,7 +33,7 @@ def _arr(arr): def test_find_array(): # Create an LLM object to get the model config. - llm = vllm.LLM(MODEL_NAME, task="embedding") + llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN) pooler = GritLMPooler(model_config=llm.llm_engine.model_config) arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) @@ -48,14 +49,14 @@ def test_find_array(): @pytest.fixture(scope="module") def server_embedding(): - args = ["--task", "embedding"] + args = ["--task", "embedding", "--max_model_len", str(MAX_MODEL_LEN)] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @pytest.fixture(scope="module") def server_generate(): - args = ["--task", "generate"] + args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -131,7 +132,7 @@ def validate_embed_output(q_rep: List[float], d_rep: List[float]): def test_gritlm_offline_embedding(): queries, q_instruction, documents, d_instruction = get_test_data() - llm = vllm.LLM(MODEL_NAME, task="embedding") + llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN) d_rep = run_llm_encode( llm, @@ -166,48 +167,19 @@ async def test_gritlm_api_server_embedding( validate_embed_output(q_rep, d_rep) -def validate_gen_output(output: str): - expected_output = """Oh, Mt. Fuji, mountain grand, -A sight to see, a climb to command, -At midnight, in the dark of night, -I climbed your slopes, with all my might. - -The stars above, they shone so bright, -A beacon in the darkness, guiding light, -The wind did blow, with a gentle sigh, -As I climbed higher, with a steady eye. - -The path was steep, the climb was tough, -But I pressed on, with a steadfast rough, -For the summit, I longed to see, -The view from the top, a sight to be. - -At last, I reached the peak, and stood, -With awe and wonder, I gazed aloud, -The world below, a sight to see, -A view that's worth the climb, you'll agree. - -Mt. Fuji, mountain grand, -A sight to see, a climb to command, -At midnight, in the dark of night, -I climbed your slopes, with all my might.""" - - assert output == expected_output - - def test_gritlm_offline_gen(): - input = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n" + input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" - llm = vllm.LLM(MODEL_NAME) + llm = vllm.LLM(MODEL_NAME, max_model_len=MAX_MODEL_LEN) sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=256) outputs = llm.generate(input, sampling_params=sampling_params) - validate_gen_output(outputs[0].outputs[0].text) + assert outputs[0].outputs[0].text == "The capital of France is Paris." @pytest.mark.asyncio async def test_gritlm_api_server_gen(client_generate: openai.AsyncOpenAI): - input = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n" + input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" outputs = await client_generate.completions.create( model=MODEL_NAME, @@ -216,4 +188,4 @@ async def test_gritlm_api_server_gen(client_generate: openai.AsyncOpenAI): temperature=0.0, ) - validate_gen_output(outputs.choices[0].text) + assert outputs.choices[0].text == "The capital of France is Paris." From 9a5cb50d87fbf5893fdb5dff7792232071532bd5 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 11 Dec 2024 05:18:21 +0900 Subject: [PATCH 14/18] Fix format Signed-off-by: Pooya Davoodi --- tests/models/embedding/language/test_gritlm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index 23b5f31e328e7..e49754bc44c20 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -24,6 +24,7 @@ MAX_MODEL_LEN = 4000 + def _arr(arr): """ Convert a list of integers to an array of integers. From 9c9deeaa08aedc19cbbf0bbc30ce0e7364a5055d Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 11 Dec 2024 08:30:25 +0900 Subject: [PATCH 15/18] Relax xformers req for generation Signed-off-by: Pooya Davoodi --- vllm/model_executor/models/gritlm.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index f9ebdb00f99ee..13b7a908fff00 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -183,15 +183,16 @@ class GritLM(LlamaForCausalLM): The class inherits from LlamaForCausalLM and provides a custom pooling layer. - The task "embedding" must be specified in the server arguments. - The main difference between the pooling layer in GritLM and the one in LlamaForCausalLM is that GritLM ignores the query instruction in the prompt when pooling the hidden states. - Prompt must be in the following format: + Embedding prompts should be in the following format: - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT". - Without instruction: "<|embed|>\nPROMPT". + + Generation prompts should be in the following format: + - "<|user|>\nPROMPT\n<|assistant|>\n" """ def __init__( @@ -206,11 +207,13 @@ def __init__( self._pooler = GritLMPooler(vllm_config.model_config) - for layer in self.model.layers: - if hasattr(layer, "self_attn"): - assert isinstance(layer.self_attn.attn.impl, XFormersImpl), ( - "GritLM is only supported by XFormers backend, " - "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS") + if self.task == "embedding": + for layer in self.model.layers: + if hasattr(layer, "self_attn"): + assert isinstance(layer.self_attn.attn.impl, XFormersImpl), ( + "GritLM embedding is only supported by XFormers backend, " + "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS" + ) def forward( self, From a273596b6ea126434520cb06c4c512dd04021dcd Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 11 Dec 2024 08:30:48 +0900 Subject: [PATCH 16/18] Move import below skip Signed-off-by: Pooya Davoodi --- tests/models/embedding/language/test_gritlm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index e49754bc44c20..06b132dffe166 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -11,14 +11,16 @@ import vllm import vllm.config + +# GritLM embedding implementation is only supported by XFormers backend. +pytest.mark.skipif( + not importlib.util.find_spec("xformers"), reason="GritLM requires XFormers" +) +os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" from vllm.model_executor.models.gritlm import GritLMPooler from ....utils import RemoteOpenAIServer -# GritLM implementation is only supported by XFormers backend. -pytest.mark.skipif(not importlib.util.find_spec("xformers"), - reason="GritLM requires XFormers") -os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" MODEL_NAME = "parasail-ai/GritLM-7B-vllm" From 703c09c96d3e5f18a77755d7996f17c21f46a1e8 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 11 Dec 2024 08:34:44 +0900 Subject: [PATCH 17/18] Fix format Signed-off-by: Pooya Davoodi --- tests/models/embedding/language/test_gritlm.py | 13 ++++++------- vllm/model_executor/models/gritlm.py | 12 +++++------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index 06b132dffe166..b6bc828b46426 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -12,15 +12,12 @@ import vllm import vllm.config -# GritLM embedding implementation is only supported by XFormers backend. -pytest.mark.skipif( - not importlib.util.find_spec("xformers"), reason="GritLM requires XFormers" -) -os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" -from vllm.model_executor.models.gritlm import GritLMPooler - from ....utils import RemoteOpenAIServer +# GritLM embedding implementation is only supported by XFormers backend. +pytest.mark.skipif(not importlib.util.find_spec("xformers"), + reason="GritLM requires XFormers") +os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" MODEL_NAME = "parasail-ai/GritLM-7B-vllm" @@ -35,6 +32,8 @@ def _arr(arr): def test_find_array(): + from vllm.model_executor.models.gritlm import GritLMPooler + # Create an LLM object to get the model config. llm = vllm.LLM(MODEL_NAME, task="embedding", max_model_len=MAX_MODEL_LEN) pooler = GritLMPooler(model_config=llm.llm_engine.model_config) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 13b7a908fff00..ec01a07c16a62 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -207,13 +207,11 @@ def __init__( self._pooler = GritLMPooler(vllm_config.model_config) - if self.task == "embedding": - for layer in self.model.layers: - if hasattr(layer, "self_attn"): - assert isinstance(layer.self_attn.attn.impl, XFormersImpl), ( - "GritLM embedding is only supported by XFormers backend, " - "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS" - ) + for layer in self.model.layers: + if self.task == "embedding" and hasattr(layer, "self_attn"): + assert isinstance(layer.self_attn.attn.impl, XFormersImpl), ( + "GritLM embedding is only supported by XFormers backend, " + "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS") def forward( self, From 66664457b87d08ddc04287fe785e1cff1ff08cec Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Wed, 11 Dec 2024 08:37:42 +0900 Subject: [PATCH 18/18] Add gritlm to generation doc Signed-off-by: Pooya Davoodi --- docs/source/models/supported_models.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 4071bb5e0fdda..ec5509e20187b 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -174,6 +174,11 @@ Text Generation - :code:`ibm-granite/granite-3.0-1b-a400m-base`, :code:`ibm-granite/granite-3.0-3b-a800m-instruct`, :code:`ibm/PowerMoE-3b`, etc. - ✅︎ - ✅︎ + * - :code:`GritLM` + - GritLM + - :code:`parasail-ai/GritLM-7B-vllm`. + - ✅︎ + - ✅︎ * - :code:`InternLMForCausalLM` - InternLM - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.