From 5f32d7cdc750279a290ca5b52a9580c7a013ed68 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Sat, 7 Dec 2024 05:24:39 +0900 Subject: [PATCH] Fix BOS check, move patterns to constructor Signed-off-by: Pooya Davoodi --- .../models/embedding/language/test_gritlm.py | 14 +++-- vllm/model_executor/models/gritlm.py | 51 ++++++++++--------- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index 22574069b4c27..b5ae5ee992765 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -27,15 +27,19 @@ def _arr(arr): def test_find_array(): + # Create an LLM object to get the model config. + llm = vllm.LLM(MODEL_NAME, task="embedding") + pooler = GritLMPooler(model_config=llm.llm_engine.model_config) + arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 - assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 - assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 - assert GritLMPooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 + assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 + assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 + assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 + assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 with pytest.raises(ValueError): - GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) + pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) @pytest.fixture(scope="module") diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 16a728295fbaa..55bfad2abea03 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -34,13 +34,25 @@ def __init__(self, model_config: ModelConfig): ) # Collect the tokens needed for pattern matching. 
+ # "▁<" is different from "_<". The former uses "▁" to indicate that + # the next token is the start of a word. + # "<0x0A>" is the newline token (i.e. "\n"). self.token_ids = { tok: tokenizer.convert_tokens_to_ids([tok])[0] for tok in ["", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"] } - @staticmethod - def _find_array(arr: array, target: array, start_idx: int) -> int: + def tokens_to_ids(tokens: list[str]) -> array: + return array("i", [self.token_ids[token] for token in tokens]) + + self.user_pattern_ids = tokens_to_ids( + ["▁<", "|", "user", "|", ">", "<0x0A>"]) + self.embed_newline_pattern_ids = tokens_to_ids( + ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) + self.embed_pattern_ids = tokens_to_ids( + ["▁<", "|", "embed", "|", ">", "<0x0A>"]) + + def _find_array(self, arr: array, target: array, start_idx: int) -> int: """ Find the first occurrence of target in arr starting from start_idx. @@ -74,41 +86,30 @@ def _get_instruction_len(self, prompt_token_ids: array) -> bool: because the prompt is given as a list of token IDs. """ - def tokens_to_ids(tokens: list[str]) -> List[int]: - return array("i", [self.token_ids[token] for token in tokens]) - instruction_len = 0 - found_bos_token = prompt_token_ids[0] == self.token_ids[""] - # Return no instruction in case of missing BOS token. - if not found_bos_token: + if prompt_token_ids[0] != self.token_ids[""]: logger.warning("BOS token not found in prompt," "thus using empty string for instruction." "GritLM requires BOS token in prompt.") return instruction_len - # Find the user pattern in the prompt. - user_token_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"]) - found_user_pattern = (__class__._find_array(prompt_token_ids, - user_token_ids, - start_idx=1) == 1) + # If user pattern is found in the prompt, that means there should be + # a newline token before the embed pattern. 
+ embed_pattern_ids = self.embed_pattern_ids + if self._find_array(prompt_token_ids, + self.user_pattern_ids, + start_idx=1) == 1: + embed_pattern_ids = self.embed_newline_pattern_ids # Find the embed pattern in the prompt. - if found_user_pattern: - # If user pattern is found, that means there should be - # a newline token before the embed pattern. - embed_token_ids = tokens_to_ids( - ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) - else: - embed_token_ids = tokens_to_ids( - ["▁<", "|", "embed", "|", ">", "<0x0A>"]) - found_embed_pattern_idx = __class__._find_array(prompt_token_ids, - embed_token_ids, - start_idx=1) + found_embed_pattern_idx = self._find_array(prompt_token_ids, + embed_pattern_ids, + start_idx=1) if found_embed_pattern_idx != -1: - instruction_len = found_embed_pattern_idx + len(embed_token_ids) + instruction_len = found_embed_pattern_idx + len(embed_pattern_ids) else: logger.warning("Query instruction not found in prompt," "thus using BOS token as instruction instead."