Skip to content

Commit

Permalink
Add test for invalid token id in allowed tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
RohitRathore1 committed Nov 28, 2024
1 parent 668ea42 commit fc41bf7
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/fsm/test_cfg_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,45 @@ def test_cfg_grammar_sample(request, sample_name, tokenizer_name, cleanup_lark_i
state = cfg_guide.get_next_state(state, token_id)
final_instruction = cfg_guide.get_next_instruction(state)
assert tokenizer.eos_token_id in final_instruction.tokens


# Add the first test: Unit test with mock tokenizer
def test_invalid_eos_token_id_handling():
from outlines.fsm.guide import CFGGuide

# Mock tokenizer with limited vocabulary and invalid eos_token_id
class MockTokenizer:
vocabulary = {"a": 0, "b": 1}
token_to_id = vocabulary
id_to_token = {v: k for k, v in vocabulary.items()}
special_tokens = {}
eos_token_id = len(vocabulary) # Invalid eos_token_id

def decode(self, token_ids):
return [self.id_to_token.get(token_id, "") for token_id in token_ids]

# Define a simple CFG
cfg_string = r"""
?start: "a" "b"
"""

# Initialize the guide with the mock tokenizer
tokenizer = MockTokenizer()
guide = CFGGuide(cfg_string, tokenizer)

# Build the initial state
state = guide.initial_state
instruction = guide.get_next_instruction(state)
valid_tokens = instruction.tokens

# Check that valid_tokens do not contain invalid token IDs
invalid_tokens = [
token_id for token_id in valid_tokens if token_id >= len(tokenizer.vocabulary)
]
assert not invalid_tokens, f"Found invalid token IDs: {invalid_tokens}"

try:
next_token_id = valid_tokens[0] # Take the first valid token
next_state = guide.get_next_state(state, next_token_id) # noqa: F841
except IndexError as e:
pytest.fail(f"IndexError encountered: {e}")

0 comments on commit fc41bf7

Please sign in to comment.