Skip to content

Commit

Permalink
Improve _find_list
Browse files (browse the repository at this point in the history)
  • Loading branch information
pooyadavoodi committed Dec 5, 2024
1 parent bd04e90 commit 5f07c31
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 30 deletions.
12 changes: 6 additions & 6 deletions tests/models/embedding/language/test_gritlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ def _arr(arr):
return array("i", arr)


def test_find_list():
def test_find_array():
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert GritLMPooler._find_list(arr, _arr([3, 5]), start_idx=0) == -1
assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert GritLMPooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1

with pytest.raises(ValueError):
GritLMPooler._find_list(arr, _arr([3, 4, 5]), start_idx=-1)
GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)


@pytest.fixture(scope="module")
Expand Down
38 changes: 14 additions & 24 deletions vllm/model_executor/models/gritlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@ def __init__(self, model_config: ModelConfig):
}

@staticmethod
def _find_list(arr: array, target: array, start_idx: int) -> int:
def _find_array(arr: array, target: array, start_idx: int) -> int:
"""
Find the first starting index where the search_list appears
as a consecutive subsequence in main_list.
Find the first occurrence of target in arr starting from start_idx.
Args:
arr: The array to search within
Expand All @@ -55,25 +54,14 @@ def _find_list(arr: array, target: array, start_idx: int) -> int:
"""
if start_idx < 0:
raise ValueError("start_idx must be non-negative")

found_index = -1

# Handle edge cases
if not target or not arr:
return found_index
raise ValueError("Empty arr or target not allowed")

# Length of lists
arr_len = len(arr)
target_len = len(target)

# Iterate through possible starting positions
for i in range(start_idx, arr_len - target_len + 1):
# Check if the subsequence matches
for i in range(start_idx, len(arr) - target_len + 1):
if arr[i:i + target_len] == target:
found_index = i
break

return found_index
return i
return -1

def _get_instruction_len(self, prompt_token_ids: array) -> bool:
"""
Expand Down Expand Up @@ -102,20 +90,22 @@ def tokens_to_ids(tokens: list[str]) -> List[int]:

# Find the user pattern in the prompt.
user_token_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"])
found_user_pattern = (__class__._find_list(prompt_token_ids,
user_token_ids,
start_idx=1) == 1)
found_user_pattern = (__class__._find_array(prompt_token_ids,
user_token_ids,
start_idx=1) == 1)

# Find the embed pattern in the prompt.
if found_user_pattern:
# If user pattern is found, that means there should be
# a newline token before the embed pattern.
embed_token_ids = tokens_to_ids(
["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"])
else:
embed_token_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"])
found_embed_pattern_idx = __class__._find_list(prompt_token_ids,
embed_token_ids,
start_idx=1)
found_embed_pattern_idx = __class__._find_array(prompt_token_ids,
embed_token_ids,
start_idx=1)

if found_embed_pattern_idx != -1:
instruction_len = found_embed_pattern_idx + len(embed_token_ids)
Expand Down

0 comments on commit 5f07c31

Please sign in to comment.