Remove the need to copy all tokens during basic generation #852

Merged: 1 commit, May 1, 2024
8 changes: 4 additions & 4 deletions outlines/fsm/fsm.py

@@ -1,5 +1,5 @@
 import warnings
-from typing import TYPE_CHECKING, List, NewType
+from typing import TYPE_CHECKING, Iterable, NewType, Optional

 from outlines.fsm.guide import CFGGuide, RegexGuide, StopAtEOSGuide

@@ -20,7 +20,7 @@ def __init__(self, tokenizer: "Tokenizer"):
         )
         super().__init__(tokenizer)

-    def allowed_token_ids(self, state: FSMState) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
         next_instruction = self.get_next_instruction(state)
         return next_instruction.tokens

@@ -39,7 +39,7 @@ def __init__(self, regex_string: str, tokenizer):
         )
         super().__init__(regex_string, tokenizer)

-    def allowed_token_ids(self, state: FSMState) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
         next_instruction = self.get_next_instruction(state)
         return next_instruction.tokens

@@ -58,7 +58,7 @@ def __init__(self, cfg_string: str, tokenizer):
         )
         super().__init__(cfg_string, tokenizer)

-    def allowed_token_ids(self, state: FSMState) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]:
         return self.get_next_instruction(state).tokens

     def next_state(self, state: FSMState, token_id: int) -> FSMState:
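The effect of these three identical signature changes: at an unconstrained generation step, allowed_token_ids may now return None instead of a vocabulary-sized list, so no copy of the full vocabulary is built on every step. A minimal sketch of the new contract, using a toy stand-in rather than the real classes (the ToyFSM name, state numbering, and token ids are illustrative, modeled on the updated tests below):

from typing import Iterable, Optional


class ToyFSM:
    """Illustrative stand-in for StopAtEosFSM: EOS id 2, states 0 (start) and 1 (final)."""

    eos_token_id = 2

    def allowed_token_ids(self, state: int) -> Optional[Iterable[int]]:
        if state == 1:
            # Final state: only EOS may be emitted.
            return [self.eos_token_id]
        # Unconstrained step: None replaces a copy of the full vocabulary.
        return None


fsm = ToyFSM()
assert fsm.allowed_token_ids(0) is None
assert fsm.allowed_token_ids(1) == [2]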
15 changes: 11 additions & 4 deletions outlines/fsm/guide.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Protocol, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union

 import interegular
 from lark import Lark

@@ -38,10 +38,11 @@ class Generate:
     Attributes
     ----------
     tokens
-        The tokens that lead to a valid completion if generated.
+        The tokens that lead to a valid completion if generated. A value
+        of ``None`` indicates that all tokens are allowed.
     """

-    tokens: List[int]
+    tokens: Optional[List[int]]


 Instruction = Union[Write, Generate]

@@ -89,7 +90,7 @@ def __init__(self, tokenizer: "Tokenizer"):
     def get_next_instruction(self, state: int) -> Instruction:
         if self.is_final_state(state):
             return Write([self.eos_token_id])
-        return Generate(list(self.vocabulary))
+        return Generate(None)

     def get_next_state(self, state: int, token_id: int) -> int:
         if token_id == self.eos_token_id or state == self.final_state:

@@ -330,6 +331,9 @@ def get_next_instruction(self, state: int) -> Instruction:
             proposer = self.regex_fsm

         instruction = proposer.get_next_instruction(state)
+
+        assert instruction.tokens is not None
+
         if isinstance(instruction, Write):
             proposal += instruction.tokens
         else:

@@ -365,6 +369,9 @@ def get_next_instruction(self, state: int) -> Instruction:
                 self.reset_state = True

             instruction = self.regex_fsm.get_next_instruction(self.start_state)
+
+            assert instruction.tokens is not None
+
             if isinstance(instruction, Write):
                 proposal += instruction.tokens
             else:
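For orientation, a self-contained sketch of the Write/Generate instruction pair and how a call site reads under the new None convention; the dataclasses are re-declared here so the snippet runs standalone, but in the library they live in outlines/fsm/guide.py:

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Write:
    tokens: List[int]


@dataclass
class Generate:
    # None means "all tokens are allowed" (the new convention).
    tokens: Optional[List[int]]


Instruction = Union[Write, Generate]


def describe(instruction: Instruction) -> str:
    # Mirrors how callers of get_next_instruction branch on the result.
    if isinstance(instruction, Write):
        return f"append {instruction.tokens} verbatim"
    if instruction.tokens is None:
        return "sample freely from the whole vocabulary"
    return f"sample from {len(instruction.tokens)} allowed ids"


assert describe(Generate(None)) == "sample freely from the whole vocabulary"
assert describe(Write([2])) == "append [2] verbatim"

This is also why the two CFG hunks above add `assert instruction.tokens is not None`: those code paths extend a concrete proposal list, so they must rule out the unconstrained case before iterating.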
11 changes: 8 additions & 3 deletions outlines/generate/generator.py

@@ -1,6 +1,6 @@
 import dataclasses
 import math
-from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Tuple

 if TYPE_CHECKING:
     import torch

@@ -134,7 +134,9 @@ def get_next_fsm_states(
     ]


-def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> "torch.Tensor":
+def get_allowed_tokens(
+    fsms: List["Guide"], fsm_states: List[int]
+) -> List[Optional[Iterable[int]]]:
     """Get the new instructions for each sequence from the finite-state machine.

     Parameters

@@ -302,5 +304,8 @@ def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor":

     biased_logits = torch.full_like(logits, -math.inf, device=logits.device)
     for i, ids in enumerate(allowed_token_ids):
-        biased_logits[i, ids] = logits[i, ids]
+        if ids is not None:
+            biased_logits[i, ids] = logits[i, ids]
+        else:
+            biased_logits[i] = logits[i]
     return biased_logits
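A runnable illustration of the new fast path in bias_logits, assuming a batch of two sequences over a five-token vocabulary (the numbers are illustrative; the function body mirrors the diff above):

import math

import torch


def bias_logits(logits: torch.Tensor, allowed_token_ids) -> torch.Tensor:
    biased_logits = torch.full_like(logits, -math.inf, device=logits.device)
    for i, ids in enumerate(allowed_token_ids):
        if ids is not None:
            biased_logits[i, ids] = logits[i, ids]
        else:
            # None: copy the whole row through without materializing ids.
            biased_logits[i] = logits[i]
    return biased_logits


logits = torch.randn(2, 5)
out = bias_logits(logits, [None, [0, 3]])
assert torch.equal(out[0], logits[0])  # unconstrained sequence is untouched
assert torch.isinf(out[1, [1, 2, 4]]).all()  # disallowed ids masked to -inf
assert torch.equal(out[1, [0, 3]], logits[1, [0, 3]])

Rows whose entry is None are copied through wholesale, which avoids indexing by a list containing every token id in the vocabulary.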
4 changes: 2 additions & 2 deletions outlines/integrations/transformers.py

@@ -26,7 +26,7 @@
 """

 from collections import defaultdict
-from typing import DefaultDict, List, Optional, Type, Union
+from typing import DefaultDict, Iterable, Optional, Type, Union

 import torch
 from pydantic import BaseModel

@@ -84,7 +84,7 @@ def __init__(
         # apply the FSM to the generated tokens.
         self._prefix = [-1]

-    def __call__(self, batch_id: int, sent: torch.Tensor) -> List[int]:
+    def __call__(self, batch_id: int, sent: torch.Tensor) -> Optional[Iterable[int]]:
         """Use the FSM to bias the logits before sampling the next token.

         Parameters
2 changes: 1 addition & 1 deletion tests/fsm/test_fsm.py

@@ -11,7 +11,7 @@ class MockTokenizer:
     with pytest.warns(UserWarning):
         fsm = StopAtEosFSM(MockTokenizer())

-    assert fsm.allowed_token_ids(fsm.start_state) == [1, 2]
+    assert fsm.allowed_token_ids(fsm.start_state) is None
     assert fsm.allowed_token_ids(fsm.final_state) == [2]
     assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
     assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
2 changes: 1 addition & 1 deletion tests/fsm/test_guide.py

@@ -12,7 +12,7 @@ class MockTokenizer:

     instruction = fsm.get_next_instruction(fsm.start_state)
     assert isinstance(instruction, Generate)
-    assert instruction.tokens == [1, 2]
+    assert instruction.tokens is None

     instruction = fsm.get_next_instruction(fsm.final_state)
     assert isinstance(instruction, Write)