diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 245ee64a..facd511d 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -449,10 +449,14 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         empty_slots = slots[Slot.State.EMPTY]
         model_batch_size = self.model.config.batch_size
         if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
-            raise ValueError(
+            # If raising an error here wouldn't crash the server, we would raise a ValueError
+            error = ValueError(
                 f"Cannot prefill {len(batch.requests)} new request(s)."
                 f" Maximum batch size supported is: {model_batch_size}."
             )
+            # but since it would crash, we just log the error and return an empty generation
+            logger.error(error)
+            return [], None
         for slot in empty_slots:
             self.slots.remove(slot)
         # Assign each request to an empty slot
@@ -505,13 +509,21 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             input_ids[i, -truncation:] = tokenized_inputs.input_ids[i, -truncation:]
             slot_input_ids = input_ids[i : i + 1, :]
             # Padded input ids are also required to set logits processors and stopping criterias
-            selector = TokenSelector.create(
-                slot_input_ids,
-                slot.generation_config,
-                self.model,
-                self.model.config.sequence_length,
-                seed=slot.seed,
-            )
+            try:
+                selector = TokenSelector.create(
+                    slot_input_ids,
+                    slot.generation_config,
+                    self.model,
+                    self.model.config.sequence_length,
+                    seed=slot.seed,
+                )
+            except ValueError as e:
+                # This is very unlikely, but it could happen if the router does not check values beforehand.
+                # In that case, we just skip the slot and mark it as empty, which should prevent it from being
+                # returned to the router.
+                logger.error(f"Invalid generation parameters for slot {slot.id}. Skipping it. Error: {e}")
+                slot.clear()
+                continue
             slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
             attention_mask[i, -truncation:] = tokenized_inputs.attention_mask[i, -truncation:]
             if self._supports_static_cache:
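
The patch follows a common pattern for long-running inference servers: validation failures are logged and the offending request (or the whole prefill) is dropped rather than raised, so the server process stays alive and keeps serving the other slots. Below is a minimal, self-contained sketch of that pattern; Request, Slot, validate_params and prefill here are simplified, hypothetical stand-ins for the real types in generator.py, not the actual implementation.

import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple

logger = logging.getLogger(__name__)


# Hypothetical, simplified stand-ins for the real Request/Slot types.
@dataclass
class Request:
    id: int
    temperature: float


@dataclass
class Slot:
    id: int
    request: Optional[Request] = None

    def assign(self, request: Request) -> None:
        self.request = request

    def clear(self) -> None:
        # Mark the slot as empty again so it is not reported back to the router.
        self.request = None


def validate_params(request: Request) -> None:
    # Stand-in for the parameter validation done when building a token selector.
    if request.temperature <= 0:
        raise ValueError(f"temperature must be strictly positive, got {request.temperature}")


def prefill(requests: List[Request], slots: List[Slot], max_batch_size: int) -> Tuple[List[Slot], Optional[str]]:
    active = [s for s in slots if s.request is not None]
    if max_batch_size < len(active) + len(requests):
        # Log and refuse the whole prefill instead of raising, so the serving loop keeps running.
        logger.error(
            "Cannot prefill %d new request(s). Maximum batch size supported is: %d.",
            len(requests),
            max_batch_size,
        )
        return [], None
    empty = [s for s in slots if s.request is None]
    prefilled = []
    for slot, request in zip(empty, requests):
        slot.assign(request)
        try:
            validate_params(request)
        except ValueError as e:
            # Invalid parameters for a single request: skip it and free the slot,
            # leaving the remaining requests in the batch unaffected.
            logger.error("Invalid generation parameters for slot %d. Skipping it. Error: %s", slot.id, e)
            slot.clear()
            continue
        prefilled.append(slot)
    return prefilled, "cached-batch"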