Update the stopping criteria for input embeds to improve performance
skavulya committed Jun 17, 2024
1 parent f55d6df commit 977f40e
Showing 1 changed file with 43 additions and 21 deletions.
optimum/habana/transformers/generation/utils.py (43 additions, 21 deletions)
```diff
@@ -1720,9 +1720,11 @@ def _greedy_search(
 
         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape
-        has_inputs_embeds = "inputs_embeds" in model_kwargs
-        if has_inputs_embeds:
+        inputs_embeds_offset = 0
+        if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
+            inputs_embeds_offset = input_ids.shape[1] - cur_len
 
         this_peer_finished = False
         if not ignore_eos:
             unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
```
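The rewrite replaces the boolean `has_inputs_embeds` flag with an integer `inputs_embeds_offset` recording how far the `input_ids` length and the `inputs_embeds` length diverge. A minimal sketch of the arithmetic, with hypothetical tensor shapes chosen only for illustration:

```python
import torch

# Hypothetical shapes, chosen only to make the arithmetic concrete.
input_ids = torch.zeros((1, 20), dtype=torch.long)  # token buffer the stopping criteria index into
model_kwargs = {"inputs_embeds": torch.zeros((1, 16, 8))}  # prompt supplied as embeddings (hidden size irrelevant here)

batch_size, cur_len = input_ids.shape  # cur_len = 20
inputs_embeds_offset = 0
if "inputs_embeds" in model_kwargs:
    cur_len = model_kwargs["inputs_embeds"].shape[1]  # cur_len = 16 (embedding positions)
    inputs_embeds_offset = input_ids.shape[1] - cur_len  # 20 - 16 = 4

# After n decoding steps, cur_len has advanced to 16 + n; adding the offset
# converts it back into an index in input_ids terms for the stopping criteria:
n = 3
stop_tkn_idx = (cur_len + n) + inputs_embeds_offset
assert stop_tkn_idx == input_ids.shape[1] + n  # 23
```

Carrying the offset lets every later step derive the stopping index with one addition instead of re-branching on whether embeddings were supplied.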
```diff
@@ -1857,14 +1859,17 @@ def _greedy_search(
             else:
                 model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]
             cur_len = cur_len + 1
+            stop_tkn_idx = cur_len + inputs_embeds_offset
 
             if ignore_eos:
-                this_peer_finished = get_final_stopping_criteria(stopping_criteria(
-                    input_ids, scores, token_idx=token_idx if has_inputs_embeds else cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id
-                ))
+                this_peer_finished = get_final_stopping_criteria(
+                    stopping_criteria(
+                        input_ids, scores, token_idx=stop_tkn_idx, ignore_eos=ignore_eos, eos_token_id=eos_token_id
+                    )
+                )
             else:
                 unfinished_sequences = unfinished_sequences & ~stopping_criteria(
-                    input_ids, scores, token_idx=token_idx if has_inputs_embeds else cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id
+                    input_ids, scores, token_idx=stop_tkn_idx, ignore_eos=ignore_eos, eos_token_id=eos_token_id
                 )
                 this_peer_finished = unfinished_sequences.max() == 0
```
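For context on how the index is consumed: length-based stopping in this code path keys off the `token_idx` keyword rather than `input_ids.shape[-1]`, which may be padded under static shapes. A simplified, hypothetical criterion in that spirit (not the actual Gaudi implementation, which takes additional arguments):

```python
import torch

class MaxLengthSketch:
    """Hypothetical stand-in for a token_idx-aware max-length stopping criterion."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, input_ids: torch.LongTensor, scores=None, token_idx=None, **kwargs) -> bool:
        if token_idx is not None:
            # Static-shape / inputs_embeds path: trust the caller's effective length.
            return token_idx >= self.max_length
        # Dynamic path: fall back to the tensor's actual length.
        return input_ids.shape[-1] >= self.max_length

criterion = MaxLengthSketch(max_length=20)
padded_ids = torch.zeros((1, 64), dtype=torch.long)  # padded buffer; real length tracked separately
print(criterion(padded_ids, token_idx=18))  # False: effective length still below max_length
print(criterion(padded_ids, token_idx=21))  # True: time to stop
```

Passing `stop_tkn_idx = cur_len + inputs_embeds_offset` keeps that effective length correct when the prompt arrived as embeddings.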

```diff
@@ -2143,9 +2148,11 @@ def _sample(
         # keep track of which sequences are already finished
         # TODO: no ignore_eos check here since there is a compilation error, will add ignore_eos here if fixed
         batch_size, cur_len = input_ids.shape
-        has_inputs_embeds = "inputs_embeds" in model_kwargs
-        if has_inputs_embeds:
+        inputs_embeds_offset = 0
+        if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
+            inputs_embeds_offset = input_ids.shape[1] - cur_len
 
         this_peer_finished = False
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
```
```diff
@@ -2278,13 +2285,17 @@ def _sample(
             else:
                 model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]
 
+            stop_tkn_idx = cur_len + inputs_embeds_offset
+
             if ignore_eos:
-                this_peer_finished = get_final_stopping_criteria(stopping_criteria(
-                    input_ids, scores, token_idx=token_idx if has_inputs_embeds else cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id
-                ))
+                this_peer_finished = get_final_stopping_criteria(
+                    stopping_criteria(
+                        input_ids, scores, token_idx=stop_tkn_idx, ignore_eos=ignore_eos, eos_token_id=eos_token_id
+                    )
+                )
             else:
                 unfinished_sequences = unfinished_sequences & ~stopping_criteria(
-                    input_ids, scores, token_idx=token_idx if has_inputs_embeds else cur_len, ignore_eos=ignore_eos, eos_token_id=eos_token_id
+                    input_ids, scores, token_idx=stop_tkn_idx, ignore_eos=ignore_eos, eos_token_id=eos_token_id
                 )
                 this_peer_finished = unfinished_sequences.max() == 0
```

```diff
@@ -2547,9 +2558,11 @@ def _beam_search(
         num_beams = beam_scorer.num_beams
 
         batch_beam_size, cur_len = input_ids.shape
-        has_inputs_embeds = "inputs_embeds" in model_kwargs
-        if has_inputs_embeds:
+        inputs_embeds_offset = 0
+        if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
+            inputs_embeds_offset = input_ids.shape[1] - cur_len
 
         token_idx = model_kwargs.get("token_idx", None)
         if token_idx is not None:
             # Update cur_len in case of static shapes
```
```diff
@@ -2881,6 +2894,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
 
             # increase cur_len
             cur_len = cur_len + 1
+            stop_tkn_idx = cur_len + inputs_embeds_offset
 
             hb_profer.step()
             if self.generation_config.static_shapes:
@@ -2893,7 +2907,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
                     and num_eos_tokens >= num_beams_tensor
                 ):
                     break
-                elif get_final_stopping_criteria(stopping_criteria(input_ids, scores, token_idx=token_idx if has_inputs_embeds else cur_len)):
+                elif get_final_stopping_criteria(stopping_criteria(input_ids, scores, token_idx=stop_tkn_idx)):
                     break
             elif get_final_stopping_criteria(stopping_criteria(input_ids, scores)) or (
                 beam_scorer.is_done and not lazy_mode
```
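The static-shapes branch above first checks an early exit that fires once every beam has emitted an EOS token, and only then consults the generic criteria with `stop_tkn_idx`. A rough sketch of that comparison with made-up values (`num_eos_tokens` and `num_beams_tensor` are computed earlier in the real method):

```python
import torch

num_beams = 4
eos_token_id = 2
# Hypothetical last token selected by each beam at this step.
latest_tokens = torch.tensor([2, 2, 2, 2])

num_eos_tokens = (latest_tokens == eos_token_id).sum()  # tensor(4)
num_beams_tensor = torch.tensor(num_beams)

if num_eos_tokens >= num_beams_tensor:
    print("all beams ended in EOS: break out of the generation loop")
```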
```diff
@@ -3481,9 +3495,11 @@ def _constrained_beam_search(
         num_beams = constrained_beam_scorer.num_beams
 
         batch_beam_size, cur_len = input_ids.shape
-        has_inputs_embeds = "inputs_embeds" in model_kwargs
-        if has_inputs_embeds:
+        inputs_embeds_offset = 0
+        if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
+            inputs_embeds_offset = input_ids.shape[1] - cur_len
 
         token_idx = model_kwargs.get("token_idx", None)
         if token_idx is not None:
             # Update cur_len in case of static shapes
@@ -3640,11 +3656,12 @@ def _constrained_beam_search(
 
             # increase cur_len
             cur_len = cur_len + 1
+            stop_tkn_idx = cur_len + inputs_embeds_offset
 
             hb_profer.step()
 
             if constrained_beam_scorer.is_done or get_final_stopping_criteria(
-                stopping_criteria(input_ids, scores, token_idx=token_idx if has_inputs_embeds else cur_len)
+                stopping_criteria(input_ids, scores, token_idx=stop_tkn_idx)
             ):
                 this_peer_finished = True
```

```diff
@@ -3898,8 +3915,11 @@ def _assisted_decoding(
 
         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape
+        inputs_embeds_offset = 0
         if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
+            inputs_embeds_offset = input_ids.shape[1] - cur_len
 
         if not ignore_eos:
             unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
             model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
```
```diff
@@ -4082,9 +4102,11 @@ def _assisted_decoding(
             )
 
             if ignore_eos:
-                this_peer_finished = get_final_stopping_criteria(stopping_criteria(
-                    input_ids, scores, token_idx=None, ignore_eos=ignore_eos, eos_token_id=eos_token_id
-                ))
+                this_peer_finished = get_final_stopping_criteria(
+                    stopping_criteria(
+                        input_ids, scores, token_idx=None, ignore_eos=ignore_eos, eos_token_id=eos_token_id
+                    )
+                )
             else:
                 unfinished_sequences = unfinished_sequences & ~stopping_criteria(
                     input_ids, scores, token_idx=None, ignore_eos=ignore_eos, eos_token_id=eos_token_id
```
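For reference, the scenario these changes target is generation that starts from embeddings. A hypothetical end-to-end call (the model name and sizes are placeholders; the call pattern is the standard transformers `generate` API that optimum-habana overrides):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; any decoder-only model whose generate() accepts
# inputs_embeds behaves the same way for this purpose.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("A prompt that will be passed as embeddings", return_tensors="pt")
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

# With inputs_embeds present, cur_len inside the decoding loops follows the
# embedding length, and the offset introduced by this commit maps it back to
# an index into input_ids for the stopping criteria.
outputs = model.generate(
    input_ids=inputs.input_ids,
    inputs_embeds=inputs_embeds,
    max_new_tokens=16,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```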
