Handle selector exception (#73)
* fix(tgi): handle invalid generation config error and return to server

If the generation config is invalid, the selector raises an error. This is
caught by the prefill method, which skips generation for that slot, so the
error is handled by the router.
I was not able to reproduce the problem with a simple HTTP request to TGI,
but it seems possible to trigger it through the HTML form interface, so it
is better to handle it, even if it is unlikely to happen.

* fix(tgi): handle another exception in prefill

Returning an empty batch is better than crashing.
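
The "log and return an empty batch" behavior described above can be sketched as follows. This is a minimal, self-contained illustration, not the actual server code: the function signature, the `tgi.generator` logger name, and the returned placeholder batch are all assumptions made for the example.

```python
import logging

logger = logging.getLogger("tgi.generator")

def prefill(active_slots, new_requests, model_batch_size):
    """Sketch of a prefill capacity check that degrades gracefully."""
    if model_batch_size is not None and model_batch_size < len(active_slots) + len(new_requests):
        error = ValueError(
            f"Cannot prefill {len(new_requests)} new request(s)."
            f" Maximum batch size supported is: {model_batch_size}."
        )
        # Raising here would crash the server; log and return an empty batch
        # so the router keeps serving the already-active requests.
        logger.error(error)
        return [], None
    # Normal path: accept the new requests (placeholder batch object).
    return new_requests, {"requests": new_requests}
```

The key design choice is that an over-capacity batch is a recoverable condition for the server, so it is reported to the log rather than propagated as an exception.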
tengomucho authored Jul 12, 2024
1 parent eb1d7c9 commit 50ed7bd
Showing 1 changed file with 20 additions and 8 deletions.
@@ -449,10 +449,14 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         empty_slots = slots[Slot.State.EMPTY]
         model_batch_size = self.model.config.batch_size
         if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
-            raise ValueError(
+            # Raising a ValueError here would crash the server,
+            error = ValueError(
                 f"Cannot prefill {len(batch.requests)} new request(s)."
                 f" Maximum batch size supported is: {model_batch_size}."
             )
+            # so we just log the error and return an empty generation instead.
+            logger.error(error)
+            return [], None
         for slot in empty_slots:
             self.slots.remove(slot)
         # Assign each request to an empty slot
@@ -505,13 +509,21 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             input_ids[i, -truncation:] = tokenized_inputs.input_ids[i, -truncation:]
             slot_input_ids = input_ids[i : i + 1, :]
             # Padded input ids are also required to set logits processors and stopping criterias
-            selector = TokenSelector.create(
-                slot_input_ids,
-                slot.generation_config,
-                self.model,
-                self.model.config.sequence_length,
-                seed=slot.seed,
-            )
+            try:
+                selector = TokenSelector.create(
+                    slot_input_ids,
+                    slot.generation_config,
+                    self.model,
+                    self.model.config.sequence_length,
+                    seed=slot.seed,
+                )
+            except ValueError as e:
+                # This is very unlikely, but it could happen if the router does not check
+                # values beforehand. In that case, we just skip the slot and mark it as
+                # empty, so the error is not returned to the router.
+                logger.error(f"Invalid generation parameters for slot {slot.id}. Skipping it. Error: {e}")
+                slot.clear()
+                continue
             slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
             attention_mask[i, -truncation:] = tokenized_inputs.attention_mask[i, -truncation:]
             if self._supports_static_cache:
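
The per-slot guard in the second hunk follows a common pattern: validate each request in isolation and drop only the offending one, instead of letting a single bad request fail the whole batch. A minimal sketch of that pattern, with hypothetical stand-ins for `Slot` and `TokenSelector.create` (the `create_selector` validation rule below is invented for the example):

```python
import logging

logger = logging.getLogger("tgi.generator")

class Slot:
    """Hypothetical slot holding one request's generation parameters."""
    def __init__(self, slot_id, params):
        self.id = slot_id
        self.params = params
        self.cleared = False

    def clear(self):
        # Mark the slot empty so it can be reused.
        self.cleared = True

def create_selector(params):
    """Stand-in for TokenSelector.create: rejects invalid parameters."""
    if params.get("temperature", 1.0) <= 0:
        raise ValueError("temperature must be strictly positive")
    return params

def build_selectors(slots):
    selectors = []
    for slot in slots:
        try:
            selector = create_selector(slot.params)
        except ValueError as e:
            # Skip this slot and mark it empty rather than crashing the
            # server: the other requests in the batch are still served.
            logger.error(f"Invalid generation parameters for slot {slot.id}. Skipping it. Error: {e}")
            slot.clear()
            continue
        selectors.append((slot.id, selector))
    return selectors
```

The `continue` after `slot.clear()` is what keeps the failure local: the loop moves on to the next slot, and only the cleared slot's request is lost.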
