Skip to content

Commit

Permalink
Fix BOS check, move patterns to constructor
Browse files Browse the repository at this point in the history
Signed-off-by: Pooya Davoodi <[email protected]>
  • Loading branch information
pooyadavoodi committed Dec 8, 2024
1 parent 4941376 commit 5f32d7c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 30 deletions.
14 changes: 9 additions & 5 deletions tests/models/embedding/language/test_gritlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ def _arr(arr):


def test_find_array():
# Create an LLM object to get the model config.
llm = vllm.LLM(MODEL_NAME, task="embedding")
pooler = GritLMPooler(model_config=llm.llm_engine.model_config)

arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert GritLMPooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1

with pytest.raises(ValueError):
GritLMPooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)


@pytest.fixture(scope="module")
Expand Down
51 changes: 26 additions & 25 deletions vllm/model_executor/models/gritlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,25 @@ def __init__(self, model_config: ModelConfig):
)

# Collect the tokens needed for pattern matching.
# "▁<" is different from "_<". The former uses "▁" to indicate that
# the next token is the start of a word.
# "<0x0A>" is the newline token (i.e. "\n")."
self.token_ids = {
tok: tokenizer.convert_tokens_to_ids([tok])[0]
for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
}

@staticmethod
def _find_array(arr: array, target: array, start_idx: int) -> int:
def tokens_to_ids(tokens: list[str]) -> array:
return array("i", [self.token_ids[token] for token in tokens])

self.user_pattern_ids = tokens_to_ids(
["▁<", "|", "user", "|", ">", "<0x0A>"])
self.embed_newline_pattern_ids = tokens_to_ids(
["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"])
self.embed_pattern_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"])

def _find_array(self, arr: array, target: array, start_idx: int) -> int:
"""
Find the first occurrence of target in arr starting from start_idx.
Expand Down Expand Up @@ -74,41 +86,30 @@ def _get_instruction_len(self, prompt_token_ids: array) -> bool:
because the prompt is given as a list of token IDs.
"""

def tokens_to_ids(tokens: list[str]) -> List[int]:
return array("i", [self.token_ids[token] for token in tokens])

instruction_len = 0

found_bos_token = prompt_token_ids[0] == self.token_ids["<s>"]

# Return no instruction in case of missing BOS token.
if not found_bos_token:
if prompt_token_ids[0] != self.token_ids["<s>"]:
logger.warning("BOS token not found in prompt,"
"thus using empty string for instruction."
"GritLM requires BOS token in prompt.")
return instruction_len

# Find the user pattern in the prompt.
user_token_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"])
found_user_pattern = (__class__._find_array(prompt_token_ids,
user_token_ids,
start_idx=1) == 1)
# If user pattern is found in the prompt, that means there should be
# a newline token before the embed pattern.
embed_pattern_ids = self.embed_pattern_ids
if self._find_array(prompt_token_ids,
self.user_pattern_ids,
start_idx=1) == 1:
embed_pattern_ids = self.embed_newline_pattern_ids

# Find the embed pattern in the prompt.
if found_user_pattern:
# If user pattern is found, that means there should be
# a newline token before the embed pattern.
embed_token_ids = tokens_to_ids(
["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"])
else:
embed_token_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"])
found_embed_pattern_idx = __class__._find_array(prompt_token_ids,
embed_token_ids,
start_idx=1)
found_embed_pattern_idx = self._find_array(prompt_token_ids,
embed_pattern_ids,
start_idx=1)

if found_embed_pattern_idx != -1:
instruction_len = found_embed_pattern_idx + len(embed_token_ids)
instruction_len = found_embed_pattern_idx + len(embed_pattern_ids)
else:
logger.warning("Query instruction not found in prompt,"
"thus using BOS token as instruction instead."
Expand Down

0 comments on commit 5f32d7c

Please sign in to comment.