Multi Beam Text Streamer #35436

Open · wants to merge 9 commits into main
2 changes: 2 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -354,6 +354,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] AsyncTextIteratorStreamer

[[autodoc]] MultiBeamTextStreamer

## Caches

[[autodoc]] Cache
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -125,6 +125,7 @@
"AsyncTextIteratorStreamer",
"CompileConfig",
"GenerationConfig",
"MultiBeamTextStreamer",
"TextIteratorStreamer",
"TextStreamer",
"WatermarkingConfig",
@@ -5070,6 +5071,7 @@
AsyncTextIteratorStreamer,
CompileConfig,
GenerationConfig,
MultiBeamTextStreamer,
TextIteratorStreamer,
TextStreamer,
WatermarkingConfig,
4 changes: 2 additions & 2 deletions src/transformers/generation/__init__.py
@@ -26,7 +26,7 @@
"SynthIDTextWatermarkingConfig",
"WatermarkingConfig",
],
"streamers": ["AsyncTextIteratorStreamer", "TextIteratorStreamer", "TextStreamer"],
"streamers": ["AsyncTextIteratorStreamer", "MultiBeamTextStreamer", "TextIteratorStreamer", "TextStreamer"],
}

try:
@@ -199,7 +199,7 @@
SynthIDTextWatermarkingConfig,
WatermarkingConfig,
)
from .streamers import AsyncTextIteratorStreamer, TextIteratorStreamer, TextStreamer
from .streamers import AsyncTextIteratorStreamer, MultiBeamTextStreamer, TextIteratorStreamer, TextStreamer

try:
if not is_torch_available():
6 changes: 6 additions & 0 deletions src/transformers/generation/beam_search.py
@@ -22,6 +22,7 @@

from ..utils import add_start_docstrings
from .beam_constraints import Constraint, ConstraintListState
from .streamers import MultiBeamBaseStreamer


PROCESS_INPUTS_DOCSTRING = r"""
@@ -223,6 +224,7 @@ def process(
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
streamer: Optional["MultiBeamBaseStreamer"] = None,
) -> Dict[str, torch.Tensor]:
# add up to the length which the next_scores is calculated on (including decoder prompt)
cur_len = input_ids.shape[-1] + 1
@@ -287,6 +289,10 @@
beam_indices=beam_index,
generated_len=cur_len - decoder_prompt_len,
)

if streamer is not None:
streamer.beam_finished(batch_beam_idx.item())

else:
# add next predicted token since it is not eos_token
next_beam_scores[batch_idx, beam_idx] = next_score
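The hunk above wires a notification hook into beam search: when a beam ends on the EOS token and its hypothesis is stored, `process` calls `streamer.beam_finished` with the batch-level beam index. A minimal sketch of a custom subclass that consumes this hook through the `MultiBeamBaseStreamer` contract added in `streamers.py` below (the class name `FinishedBeamCounter` is hypothetical, not part of this PR):

```python
from transformers.generation.streamers import MultiBeamBaseStreamer


class FinishedBeamCounter(MultiBeamBaseStreamer):
    """Hypothetical subclass that just counts how many beams have finished."""

    def __init__(self, num_beams: int):
        super().__init__(num_beams)
        self.finished_count = 0

    def put(self, value, beam_indices=None):
        pass  # this sketch ignores intermediate tokens

    def end(self):
        pass  # nothing to flush

    def beam_finished(self, beam_idx: int):
        # Called by BeamSearchScorer.process when beam `beam_idx` hits EOS
        self.finished_count += 1
```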
277 changes: 275 additions & 2 deletions src/transformers/generation/streamers.py
@@ -12,10 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from queue import Queue
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Dict, List, Optional


if TYPE_CHECKING:
@@ -36,6 +35,31 @@ def end(self):
raise NotImplementedError()


class MultiBeamBaseStreamer(BaseStreamer):
"""
Base class from which all multi-beam streamers should inherit.
Extends the BaseStreamer class with functionality specific to handling multiple beams.
"""

def __init__(self, num_beams: int):
super().__init__()
if not isinstance(num_beams, int) or num_beams <= 0:
raise ValueError(f"num_beams must be a positive integer, got {num_beams}")
self.num_beams = num_beams
self.current_beam = 0

def beam_finished(self, beam_idx: int):
"""
Called when a specific beam has finished generating.
Must be implemented by the derived class.

Args:
beam_idx (`int`):
Index of the beam that finished generating.
"""
raise NotImplementedError()


class TextStreamer(BaseStreamer):
"""
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
@@ -228,6 +252,255 @@ def __next__(self):
return value


class MultiBeamTextStreamer(MultiBeamBaseStreamer):
"""
A streamer that handles beam search generation, allowing real-time tracking and processing of multiple beam outputs.
This is useful for applications that need to monitor or display multiple candidate sequences during beam search
generation, such as interactive applications showing alternative generations in real-time.

<Tip warning={true}>

The API for the streamer classes is still under development and may change in the future.

</Tip>

Parameters:
tokenizer (`AutoTokenizer`):
The tokenizer used to decode the tokens.
num_beams (`int`):
The number of beams to handle during generation.
on_beam_update (`Callable[[int, str], None]`):
A callback function that gets called whenever a beam's text is updated.
The function receives two arguments:
- beam_idx (`int`): The index of the updated beam
- text (`str`): The current complete text for this beam
on_beam_finished (`Callable[[str], None]`, *optional*):
A callback function that gets called when a beam reaches the EOS token.
The function receives one argument:
- text (`str`): The final text of the finished beam
skip_prompt (`bool`, *optional*, defaults to `True`):
Whether to skip the prompt tokens in the generation output.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.

Examples:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, MultiBeamTextStreamer

>>> # Create a dictionary to store beam outputs
>>> beam_outputs = {}

>>> # Define callback functions that store outputs in the dictionary
>>> def on_beam_update(beam_idx: int, text: str):
... beam_outputs[f"beam_{beam_idx}"] = text

>>> def on_beam_finished(text: str):
... beam_outputs["completed"] = text

>>> # Initialize model, tokenizer and streamer
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")

>>> # Create streamer with 2 beams
>>> streamer = MultiBeamTextStreamer(
... tokenizer=tokenizer,
... num_beams=2,
... on_beam_update=on_beam_update,
... on_beam_finished=on_beam_finished
... )

>>> # Generate with beam search
>>> _ = model.generate(
... **inputs,
... streamer=streamer,
... num_beams=2,
... max_new_tokens=10
... )

>>> # Access the final outputs from the dictionary
>>> print(beam_outputs)
{
'beam_0': 'An increasing sequence: one, two, three, four,',
'beam_1': 'An increasing sequence: one, two, three, five,',
'completed': 'An increasing sequence: one, two, three, four,'
}
```

The streamer maintains internal state for each beam and provides real-time updates through the callback functions.
It handles beam switching during beam search and ensures proper tracking of beam histories. The streamer is particularly
useful for:

- Interactive applications showing multiple generation alternatives
- Debugging beam search behavior
- Creating UIs that display beam search progress
- Analyzing beam search decision patterns

Note that this streamer requires more memory than single-sequence streamers as it needs to maintain state for all beams.
For applications that only need the final best sequence, consider using `TextStreamer` instead.
"""

def __init__(
self,
tokenizer: "AutoTokenizer",
num_beams: int,
on_beam_update: Callable[[int, str], None],
on_beam_finished: Optional[Callable[[str], None]] = None,
skip_prompt: bool = True,
**decode_kwargs,
):
super().__init__(num_beams)
self.tokenizer = tokenizer
self.num_beams = num_beams
self.skip_prompt = skip_prompt
self.decode_kwargs = decode_kwargs
self.on_beam_update = on_beam_update
self.on_beam_finished = on_beam_finished

# Initialize storage for each beam
self.beam_tokens: Dict[int, List[int]] = {i: [] for i in range(num_beams)}
self.beam_texts: Dict[int, str] = {i: "" for i in range(num_beams)}
self.beam_print_lens: Dict[int, int] = {i: 0 for i in range(num_beams)}

# Track beam states at each position
self.beam_history: Dict[int, Dict[int, List[int]]] = {} # position -> beam_idx -> tokens
self.current_position = 0

# Track current state
self.next_tokens_are_prompt = True

# Store finished beams
self.finished_beams: List[str] = []

def _switch_beam_content(self, position: int, previous_beam_idx: int, new_beam_idx: int):
"""
Internal helper to handle beam content switching with position tracking.
"""
if new_beam_idx >= self.num_beams:
raise ValueError(f"Beam index {new_beam_idx} is out of range (num_beams={self.num_beams})")

if previous_beam_idx != new_beam_idx:
# Get the correct historical state for the previous beam at this position
if position > 0 and position in self.beam_history:
source_tokens = self.beam_history[position][previous_beam_idx].copy()
else:
source_tokens = self.beam_tokens[previous_beam_idx].copy()

# Update tokens for the new beam
self.beam_tokens[new_beam_idx] = source_tokens

# Update text and calculate new state
text = self.tokenizer.decode(source_tokens, **self.decode_kwargs)
self.beam_texts[new_beam_idx] = text
self.beam_print_lens[new_beam_idx] = len(text)

# Notify handler of the beam update
self.on_beam_update(new_beam_idx, text)

def put(self, values, beam_indices=None):
"""
Handle new tokens for all beams at once.
Args:
values: List or array-like of shape (num_beams, 1) containing the next token for each beam
beam_indices: Optional list/array/tensor containing the previous beam indices for each current beam
"""
# Convert values to list if it's a tensor or array
if hasattr(values, "tolist"):
values = values.tolist()

# Validate input shape
if len(values) == 1 and isinstance(values[0], list) and len(values[0]) > 1:
values = [[token] for token in values[0]]
else:
if not isinstance(values, list) or not all(isinstance(row, list) and len(row) == 1 for row in values):
raise ValueError("Expected values to be a list of lists, each inner list having length 1")

if len(values) > self.num_beams:
raise ValueError(
f"Number of beams in values ({len(values)}) exceeds initialized num_beams ({self.num_beams})"
)

# Handle beam_indices
if beam_indices is None:
# Create a simple list of indices from 0 to num_beams-1
beam_indices = list(range(len(values)))
else:
# Convert beam_indices to list if it's a tensor or array
if hasattr(beam_indices, "tolist"):
beam_indices = beam_indices.tolist()
elif not isinstance(beam_indices, list):
beam_indices = list(beam_indices)

if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return

# Save current state before modifications
current_state = {beam_idx: self.beam_tokens[beam_idx].copy() for beam_idx in range(self.num_beams)}
self.beam_history[self.current_position] = current_state

# Handle beam switching
for i in range(len(beam_indices)):
self._switch_beam_content(self.current_position, beam_indices[i], i)

# Iterate through each beam
for beam_idx in range(len(values)):
# Get token for current beam
value = values[beam_idx]

# Add new tokens to current beam
self.beam_tokens[beam_idx].extend(value)

# Decode the entire sequence for current beam
text = self.tokenizer.decode(self.beam_tokens[beam_idx], **self.decode_kwargs)

# Update beam text and calculate printable portion
self.beam_texts[beam_idx] = text
self.beam_print_lens[beam_idx] = len(text)

# Notify handler of the beam update with new text
self.on_beam_update(beam_idx, text)

self.current_position += 1

def beam_finished(self, beam_idx: int):
"""Mark a beam as finished and notify the handler."""
if beam_idx in self.beam_texts:
self.finished_beams.append(self.beam_texts[beam_idx])

# Notify handler that the beam is finished
if self.on_beam_finished:
self.on_beam_finished(self.finished_beams[-1])

def end(self):
"""Finish streaming and handle any remaining beams."""
try:
# Clean up all beam-related storage
self.beam_tokens.clear()
self.beam_texts.clear()
self.beam_print_lens.clear()
self.finished_beams.clear()

# Clean up position tracking
self.beam_history.clear()
self.current_position = 0

# Reset state variables
self.next_tokens_are_prompt = True

# Reinitialize storage for potential reuse
self.beam_tokens = {i: [] for i in range(self.num_beams)}
self.beam_texts = {i: "" for i in range(self.num_beams)}
self.beam_print_lens = {i: 0 for i in range(self.num_beams)}
self.finished_beams = []
self.beam_history = {}

except Exception as e:
print(f"Error during cleanup: {str(e)}")
raise

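Beyond the doctest in the class docstring, a minimal sketch of a console consumer that redraws every beam as it updates; the callback bodies are illustrative, while the constructor and `generate` arguments follow the API defined in this PR:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, MultiBeamTextStreamer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")

latest = {}  # beam_idx -> most recent text for that beam


def show(beam_idx: int, text: str):
    # Redraw all beams on each update
    latest[beam_idx] = text
    for idx in sorted(latest):
        print(f"[beam {idx}] {latest[idx]}")


streamer = MultiBeamTextStreamer(
    tokenizer=tokenizer,
    num_beams=3,
    on_beam_update=show,
    on_beam_finished=lambda text: print(f"[done] {text}"),
)
model.generate(**inputs, streamer=streamer, num_beams=3, max_new_tokens=20)
```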

class AsyncTextIteratorStreamer(TextStreamer):
"""
Streamer that stores print-ready text in a queue, to be used by a downstream application as an async iterator.