Fix IndexError caused by invalid token IDs in CFGGuide
RohitRathore1 committed Nov 7, 2024
1 parent 5f39ded commit 668ea42
Showing 2 changed files with 57 additions and 35 deletions.
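The IndexError named in the commit title arises when token ids are used as indices into the logits tensor although some of those ids fall outside the logits' vocabulary dimension. A minimal, self-contained PyTorch sketch of that failure mode (the vocabulary size and token ids below are made up for illustration):

```python
import torch

# Hypothetical sizes and ids, for illustration only.
vocab_size = 8
logits = torch.randn(1, vocab_size)
mask = torch.ones_like(logits, dtype=torch.bool)

allowed_tokens = torch.tensor([3, 7, 11])  # 11 does not exist in this vocabulary
# mask[0, allowed_tokens] = False  # IndexError: index 11 is out of bounds for dimension 1

# The remedy used in this commit: drop ids that cannot index the logits before masking.
allowed_tokens = allowed_tokens[allowed_tokens < logits.size(1)]
mask[0, allowed_tokens] = False

logits = logits.masked_fill(mask, float("-inf"))  # only ids 3 and 7 remain finite
```

Filtering against `logits.size(1)` before indexing, as the second changed file does, is what prevents the exception.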
40 changes: 22 additions & 18 deletions outlines/fsm/guide.py
@@ -116,7 +116,10 @@ def __init__(self, cfg_string: str, tokenizer):
 
         self.cfg_string = cfg_string
         self.tokenizer = tokenizer
+
+        # Set eos_token_id if available
         self.eos_token_id = self.tokenizer.eos_token_id
+
         self.parser = PartialLark(
             cfg_string,
             parser="lalr",
@@ -149,14 +152,20 @@ def get_next_instruction(self, state: CFGState) -> Instruction:
         """
 
         if state.parser_state is None:
-            return Write(torch.tensor([self.eos_token_id]))
+            if self.eos_token_id is not None:
+                return Write(torch.tensor([self.eos_token_id]))
+            else:
+                return None  # No instruction if eos_token_id is not set
 
         valid_tokens = list(
-            self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values())
+            self.iter_valid_token_ids(state, list(self.tokenizer.vocabulary.values()))
         )
-        if len(valid_tokens) == 1:
+        if not valid_tokens:
+            return None  # No valid tokens to generate
+        elif len(valid_tokens) == 1:
             return Write(torch.tensor(valid_tokens))
-        return Generate(torch.tensor(valid_tokens))
+        else:
+            return Generate(torch.tensor(valid_tokens))
 
     def iter_valid_token_ids(
         self, state: CFGState, candidate_token_ids: list
@@ -177,11 +186,12 @@ def iter_valid_token_ids(
             Valid token ids.
         """
         if state.parser_state is None:
-            yield self.eos_token_id
+            if self.eos_token_id is not None:
+                yield self.eos_token_id
             return
 
         for token_id in candidate_token_ids:
-            if token_id == self.eos_token_id:
+            if token_id == self.eos_token_id and self.eos_token_id is not None:
                 if self.can_terminate_state(state):
                     yield token_id
             else:
@@ -234,20 +244,14 @@ def _get_parser_state_token_applied(
         """
         parser_state = copy.copy(state.parser_state)  # prevent side effects
 
-        # normalize
-        if state.prev_token is None:
-            new_token_str = self.tokenizer.decode([token_id])[0]
-        else:
-            prev_token_str = self.tokenizer.decode([[state.prev_token]])[0]
-            combined_token_str = self.tokenizer.decode([[state.prev_token, token_id]])[
-                0
-            ]
-            new_token_str = combined_token_str[len(prev_token_str) :]
-
-        if new_token_str == "":
+        # Decode the token
+        token_str = self.tokenizer.decode([token_id])
+        if not token_str:
             raise ValueError("empty next token")
 
-        # update parser with new token
+        new_token_str = token_str[0]  # Assuming decode returns a list
+
+        # Update parser with new token
         parser_state.lexer.state.text += new_token_str
         self.parser.parse_from_state(parser_state, is_end=False)
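To restate the new branching in `get_next_instruction` without the parser machinery, here is a simplified, stand-alone sketch. `Write` and `Generate` below are local stand-ins for the instruction types defined in `outlines/fsm/guide.py`, and all inputs are hypothetical:

```python
from dataclasses import dataclass
from typing import List, Optional
import torch

# Local stand-ins so the sketch runs on its own; not the library's classes.
@dataclass
class Write:
    tokens: torch.Tensor

@dataclass
class Generate:
    tokens: torch.Tensor

def next_instruction(parsing_done: bool, eos_token_id: Optional[int],
                     valid_tokens: List[int]):
    """Mirror of the branching introduced in CFGGuide.get_next_instruction."""
    if parsing_done:
        # Only emit EOS if the tokenizer actually defines one.
        return Write(torch.tensor([eos_token_id])) if eos_token_id is not None else None
    if not valid_tokens:
        return None                                  # nothing legal to generate
    if len(valid_tokens) == 1:
        return Write(torch.tensor(valid_tokens))     # forced continuation
    return Generate(torch.tensor(valid_tokens))      # model chooses among them

print(next_instruction(False, 2, [5]))        # Write(tokens=tensor([5]))
print(next_instruction(False, 2, []))         # None
print(next_instruction(True, None, [1, 2]))   # None (no EOS defined)
```

The key behavioral change is that `None` becomes a possible return value, which is why the logits processors in the second file must now tolerate a missing instruction.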
52 changes: 35 additions & 17 deletions outlines/processors/structured.py
@@ -89,14 +89,17 @@ def process_logits(
         if self._seq_start_idx is None:
             self._seq_start_idx = len(input_ids[0])
 
-        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
+        sequence_states: List[Any] = []  # vector of states corresponding to `input_ids`
 
         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
             curr_state_key = hash(tuple(gen_ids.tolist()))
 
             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
+                prev_state = self._guide_states.get(
+                    prev_state_key, self.guide.initial_state
+                )
                 curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state
 
@@ -107,19 +110,26 @@ def process_logits(
         allowed_tokens_batch = []
         batch_indices = []
         for i, guide_state in enumerate(sequence_states):
-            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
-                mask.device, non_blocking=True
-            )
+            instruction = self.guide.get_next_instruction(guide_state)
+            if instruction is None:
+                continue  # Skip if no instruction is available
+            allowed_tokens = instruction.tokens
+            if allowed_tokens is None:
+                continue  # Skip if no tokens are allowed
+            allowed_tokens = allowed_tokens.to(mask.device, non_blocking=True)
+
+            # Filter out invalid token IDs
+            allowed_tokens = allowed_tokens[allowed_tokens < logits.size(1)]
             allowed_tokens_batch.append(allowed_tokens)
-            batch_indices.append(
-                torch.full_like(allowed_tokens, i)
-            )  # Store batch index for each allowed token
+            batch_indices.append(torch.full_like(allowed_tokens, i))
 
-        allowed_tokens_concat = torch.cat(allowed_tokens_batch)
-        batch_indices_concat = torch.cat(batch_indices)
+        if allowed_tokens_batch:
+            allowed_tokens_concat = torch.cat(allowed_tokens_batch)
+            batch_indices_concat = torch.cat(batch_indices)
 
-        mask[batch_indices_concat, allowed_tokens_concat] = False
-        logits.masked_fill_(mask, float("-inf"))
+            mask[batch_indices_concat, allowed_tokens_concat] = False
+
+        logits = logits.masked_fill(mask, float("-inf"))
 
         return logits
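The masking pattern in this hunk can be exercised in isolation. The sketch below uses made-up logits and allowed-token sets rather than the processor itself; it shows why both the `< logits.size(1)` filter and the empty-batch guard matter:

```python
import torch

# Toy batch: two sequences over a vocabulary of 6 (values invented for illustration).
logits = torch.randn(2, 6)
allowed_per_seq = [torch.tensor([1, 4]),
                   torch.tensor([], dtype=torch.long)]  # second sequence allows nothing

mask = torch.ones_like(logits, dtype=torch.bool)
allowed_tokens_batch, batch_indices = [], []
for i, allowed in enumerate(allowed_per_seq):
    allowed = allowed[allowed < logits.size(1)]        # drop out-of-range ids
    allowed_tokens_batch.append(allowed)
    batch_indices.append(torch.full_like(allowed, i))  # row index for each allowed id

if allowed_tokens_batch:                               # guard against an empty batch
    mask[torch.cat(batch_indices), torch.cat(allowed_tokens_batch)] = False

logits = logits.masked_fill(mask, float("-inf"))
print(logits)  # row 0 keeps columns 1 and 4; row 1 is fully -inf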

@@ -221,26 +231,34 @@ def process_logits(
         if self._seq_start_idx is None:
             self._seq_start_idx = len(input_ids[0])
 
-        sequence_states: List = []  # vector of states corresponding to `input_ids`
+        sequence_states: List[Any] = []  # vector of states corresponding to `input_ids`
 
         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
             curr_state_key = hash(tuple(gen_ids.tolist()))
 
             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
+                prev_state = self._guide_states.get(
+                    prev_state_key, self.guide.initial_state
+                )
                 curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state
 
             sequence_states.append(self._guide_states[curr_state_key])
 
         mask = torch.full_like(logits, -math.inf)
         for i, guide_state in enumerate(sequence_states):
-            first_legal_token = next(
+            valid_tokens = list(
                 self.guide.iter_valid_token_ids(
-                    guide_state, torch.argsort(logits[i], descending=True)
+                    guide_state, torch.arange(logits.size(1), device=logits.device)
                 )
             )
-            mask[i, [first_legal_token]] = logits[i, [first_legal_token]]
+            if valid_tokens:
+                # Keep only valid tokens
+                mask[i, valid_tokens] = logits[i, valid_tokens]
+            else:
+                # No valid tokens; generation should stop
+                mask[i] = logits[i]
 
         return mask
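For the rewritten loop above, a small self-contained example (with invented logits and a hypothetical valid-token set) shows that keeping every valid token's score, rather than only the first legal token found by sorting, still lets greedy decoding pick the best valid id while leaving the rest available for sampling or beam search:

```python
import math
import torch

# Made-up single-row logits and valid-token set.
logits = torch.tensor([[0.3, 2.0, -1.0, 0.7]])
valid_tokens = [0, 3]          # suppose the guide only allows these ids

mask = torch.full_like(logits, -math.inf)
mask[0, valid_tokens] = logits[0, valid_tokens]  # keep every valid token's score

# Greedy decoding over the masked row still picks the best *valid* token,
# which is what the old "first legal token" search achieved.
print(mask.argmax(dim=-1))     # tensor([3])
```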
