Skip to content

Commit

Permalink
Fix empty output when temperature is too low (vllm-project#2937)
Browse files Browse the repository at this point in the history
Co-authored-by: Cyrus Leung <[email protected]>
  • Loading branch information
CatherineSue and DarkLight1337 authored Aug 14, 2024
1 parent 199adbb commit c134a46
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ def forward(
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)

# Apply temperature scaling.
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k:
Expand Down
7 changes: 7 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
logger = init_logger(__name__)

_SAMPLING_EPS = 1e-5
_MAX_TEMP = 1e-2


class SamplingType(IntEnum):
Expand Down Expand Up @@ -145,6 +146,12 @@ def __init__(
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.repetition_penalty = repetition_penalty
if 0 < temperature < _MAX_TEMP:
logger.warning(
"temperature %s is less than %s, which may cause numerical "
"errors nan or inf in tensors. We have maxed it out to %s.",
temperature, _MAX_TEMP, _MAX_TEMP)
temperature = max(temperature, _MAX_TEMP)
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
Expand Down

0 comments on commit c134a46

Please sign in to comment.