Commit
add hook sample of transformers v4.45
Showing 2 changed files with 227 additions and 1 deletion.
@@ -0,0 +1,225 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
import time
import torch
import logging as log
from dataclasses import dataclass
from torch import nn
from typing import Optional, Tuple, Union
from transformers.generation.stopping_criteria import (
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.streamers import BaseStreamer
from transformers.utils import ModelOutput
from transformers.generation.configuration_utils import GenerationConfig
import llm_bench_utils.hook_greedy_search as hook_greedy


logger = log.getLogger(__name__)

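
# Local copies of the generation output classes returned when
# return_dict_in_generate=True, mirroring the upstream @dataclass
# definitions in transformers v4.45.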
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


@dataclass
class GenerateEncoderDecoderOutput(ModelOutput):
    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]


# Transformers version: v4.45.0
# Copied from https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/generation/utils.py#L2925
# Modified to collect per-token latency
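# Two latencies are recorded for every generated token:
#   hook_greedy.tm_infer_list - time spent in the model forward pass only
#   hook_greedy.tm_list       - time for the whole decoding step (input prep + forward + token selection)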
def new_sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
        A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    max_length = generation_config.max_length
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    while self._has_unfinished_sequences(
        this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
    ):
        tic = time.perf_counter()
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        # forward pass to get next token
        tic_infer = time.perf_counter()
        outputs = self(**model_inputs, return_dict=True)
        hook_greedy.tm_infer_list.append(time.perf_counter() - tic_infer)

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
        # (the clone itself is always small)
        next_token_logits = outputs.logits.clone()[:, -1, :].float()

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # token selection
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1
        hook_greedy.tm_list.append(time.perf_counter() - tic)
        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
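
For context, a minimal sketch of how this hook could be attached to a model. In transformers v4.45, `generate()` dispatches non-beam decoding to the private `_sample` method (the function this file copies, per the upstream link above), so rebinding that method is the natural hook point. The module path `llm_bench_utils.hook_sample` and the `gpt2` checkpoint are assumptions used purely for illustration; only `hook_greedy.tm_list` and `hook_greedy.tm_infer_list` are taken from the file itself.

import types

from transformers import AutoModelForCausalLM, AutoTokenizer

import llm_bench_utils.hook_greedy_search as hook_greedy
from llm_bench_utils.hook_sample import new_sample  # module path assumed

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# generate() in transformers v4.45 routes sampling/greedy decoding through the
# private _sample method, so binding new_sample over it activates the timing hook.
model._sample = types.MethodType(new_sample, model)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
model.generate(**inputs, do_sample=True, max_new_tokens=8)

print(hook_greedy.tm_list)        # whole-step latency per generated token, seconds
print(hook_greedy.tm_infer_list)  # forward-pass latency per generated token, seconds

After the call, each list holds one entry per generated token, so dividing the token count by the sum of a list gives a tokens-per-second figure for that stage.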