Update bf16 accuracy fix in sample to keep upcast to float
skavulya committed Dec 21, 2024
1 parent 0608684 commit 090527c
Showing 1 changed file with 8 additions and 5 deletions.
optimum/habana/transformers/generation/utils.py (8 additions, 5 deletions)
@@ -2459,7 +2459,7 @@ def _sample(
             if token_idx is not None and outputs.logits.shape[-2] > 1:
                 # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size]
                 if self.config.is_encoder_decoder:
-                    next_token_logits = outputs.logits[:, token_idx - 1, :]
+                    next_token_logits = outputs.logits[:, token_idx - 1, :].float()
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
                 else:
                     if model_kwargs.get("num_virtual_tokens", 0) > 0:
@@ -2468,12 +2468,13 @@
                             output_idx = torch.tensor(outputs.logits.shape[-2], device=input_ids.device)
                         else:
                             output_idx = token_idx + outputs.logits.shape[-2] - input_ids.shape[-1]
-                        next_token_logits = torch.index_select(outputs.logits, -2, output_idx - 1).squeeze(-2)
+                        next_token_logits = torch.index_select(outputs.logits, -2, output_idx - 1).squeeze(-2).float()
                     else:
-                        next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
+                        next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2).float()
                     next_token_scores = logits_processor(input_ids, next_token_logits)
             else:
-                next_token_logits = outputs.logits[:, -1, :]
+                # .float() is needed to retain precision for later logits manipulations
+                next_token_logits = outputs.logits[:, -1, :].float()
                 if token_idx is not None and self.config.is_encoder_decoder:
                     # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size]
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
@@ -2503,7 +2504,9 @@ def _sample(

             # token selection
             if do_sample:
-                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                # Workaround on HPU for output quality issues with torch.multinomial for lower precision models
+                # Distribution sampled by torch.multinomial may be affected by next_token_logits upcast to float
+                probs = torch.nn.functional.softmax(next_token_scores, dim=-1).to(outputs.logits.dtype)
                 # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
             else:
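
For reference, a minimal sketch (not code from this repository; the tensor below is made up) of why the first two hunks keep next_token_logits upcast to float32 before the logits processors and softmax run: bf16 keeps only about 8 bits of mantissa, so probabilities computed directly in bf16 can round away small differences between candidate tokens.

    import torch

    # Made-up bf16 logits standing in for outputs.logits[:, -1, :]
    logits_bf16 = torch.randn(1, 32000, dtype=torch.bfloat16)

    # Softmax taken directly on bf16 values
    probs_bf16 = torch.nn.functional.softmax(logits_bf16, dim=-1)

    # Softmax after the float32 upcast this commit keeps in place
    probs_fp32 = torch.nn.functional.softmax(logits_bf16.float(), dim=-1)

    # The two distributions differ by roughly bf16's rounding error
    print((probs_bf16.float() - probs_fp32).abs().max())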

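The last hunk casts the float32 probabilities back to the model dtype before sampling. A minimal sketch of that pattern, assuming a bf16 model and a backend whose torch.multinomial accepts bf16 probabilities (as on HPU; other backends may require float input); names and shapes below are illustrative:

    import torch

    model_dtype = torch.bfloat16                      # assumed model dtype
    next_token_scores = torch.randn(2, 32000)         # float32 scores after the processors

    # Softmax in float32, then cast back so the distribution torch.multinomial
    # samples from is not altered by the float32 upcast of next_token_logits
    probs = torch.nn.functional.softmax(next_token_scores, dim=-1).to(model_dtype)

    # Assumes the multinomial kernel on this backend accepts bf16 probabilities
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    print(next_tokens)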