From 5f07c319ab944ac527c266a03502edd1b6dfcd56 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Fri, 6 Dec 2024 08:22:57 +0900 Subject: [PATCH] Improve _find_list --- .../models/embedding/language/test_gritlm.py | 12 +++--- vllm/model_executor/models/gritlm.py | 38 +++++++------------ 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index a10db2bd07775..22574069b4c27 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -26,16 +26,16 @@ def _arr(arr): return array("i", arr) -def test_find_list(): +def test_find_array(): arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=0) == 3 - assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=1) == 3 - assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=5) == -1 - assert GritLMPooler._find_list(arr, _arr([3, 5]), start_idx=0) == -1 + assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 + assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 + assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 + assert GritLMPooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 with pytest.raises(ValueError): - GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=-1) + GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) @pytest.fixture(scope="module") diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index a5b5ca215434a..16a728295fbaa 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -40,10 +40,9 @@ def __init__(self, model_config: ModelConfig): } @staticmethod - def _find_list(arr: array, target: array, start_idx: int) -> int: + def _find_array(arr: array, target: array, start_idx: int) -> int: """ - Find the first starting index where the search_list appears - as a consecutive subsequence in main_list. + Find the first occurrence of target in arr starting from start_idx. Args: arr: The array to search within @@ -55,25 +54,14 @@ def _find_list(arr: array, target: array, start_idx: int) -> int: """ if start_idx < 0: raise ValueError("start_idx must be non-negative") - - found_index = -1 - - # Handle edge cases if not target or not arr: - return found_index + raise ValueError("Empty arr or target not allowed") - # Length of lists - arr_len = len(arr) target_len = len(target) - - # Iterate through possible starting positions - for i in range(start_idx, arr_len - target_len + 1): - # Check if the subsequence matches + for i in range(start_idx, len(arr) - target_len + 1): if arr[i:i + target_len] == target: - found_index = i - break - - return found_index + return i + return -1 def _get_instruction_len(self, prompt_token_ids: array) -> bool: """ @@ -102,20 +90,22 @@ def tokens_to_ids(tokens: list[str]) -> List[int]: # Find the user pattern in the prompt. user_token_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"]) - found_user_pattern = (__class__._find_list(prompt_token_ids, - user_token_ids, - start_idx=1) == 1) + found_user_pattern = (__class__._find_array(prompt_token_ids, + user_token_ids, + start_idx=1) == 1) # Find the embed pattern in the prompt. if found_user_pattern: + # If user pattern is found, that means there should be + # a newline token before the embed pattern. embed_token_ids = tokens_to_ids( ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) else: embed_token_ids = tokens_to_ids( ["▁<", "|", "embed", "|", ">", "<0x0A>"]) - found_embed_pattern_idx = __class__._find_list(prompt_token_ids, - embed_token_ids, - start_idx=1) + found_embed_pattern_idx = __class__._find_array(prompt_token_ids, + embed_token_ids, + start_idx=1) if found_embed_pattern_idx != -1: instruction_len = found_embed_pattern_idx + len(embed_token_ids)