Skip to content

Commit

Permalink
Fix stopping criteria for batch size > 1
Browse files: browse the repository at this point in the history
  • Loading branch information
skavulya committed Jun 13, 2024
1 parent a7db404 commit 52edadd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
8 changes: 4 additions & 4 deletions optimum/habana/transformers/generation/stopping_criteria.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,7 +39,7 @@ def gaudi_MaxLengthCriteria_call(
token_idx = kwargs.get("token_idx", None)
if token_idx is not None:
assert not kwargs["needs_tensor_output"]
-is_done = token_idx >= self.max_length
+return token_idx >= self.max_length
else:
cur_len = input_ids.shape[-1]
is_done = cur_len >= self.max_length
Expand All @@ -49,7 +49,7 @@ def gaudi_MaxLengthCriteria_call(
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
"exceptions, performance degradation, or nothing at all."
)
-return create_return_const_tensor(input_ids, is_done)
+    return create_return_const_tensor(input_ids, is_done)


def gaudi_MaxNewTokensCriteria_call(
Expand All @@ -58,10 +58,10 @@ def gaudi_MaxNewTokensCriteria_call(
token_idx = kwargs.get("token_idx", None)
if token_idx is not None:
assert not kwargs["needs_tensor_output"]
-is_done = token_idx >= self.max_length
+return token_idx >= self.max_length
else:
is_done = input_ids.shape[-1] >= self.max_length
-return create_return_const_tensor(input_ids, is_done)
+    return create_return_const_tensor(input_ids, is_done)


def gaudi_MaxTimeCriteria_call(
Expand Down
5 changes: 4 additions & 1 deletion optimum/habana/transformers/generation/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -132,7 +132,10 @@ def get_final_stopping_criteria(x):
if isinstance(x, bool):
return x
elif torch.is_tensor(x):
-return all(x)
+if x.dim() > 0:
+    return all(x)
+else:
+    return x
else:
raise TypeError(f"The stopping criteria should be either a boolean or a torch.tensor but got {type(x)}.")

Expand Down

0 comments on commit 52edadd

Please sign in to comment.