Skip to content

Commit

Permalink
cleanup job queue on failure
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Oct 5, 2024
1 parent 26753d2 commit 58b8d91
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions outlines/models/exllamav2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union

import torch
from typing_extensions import Unpack

from outlines.generate.api import GenerationParameters, SamplingParameters
Expand Down Expand Up @@ -29,7 +30,9 @@ def convert_token_to_string(self, token):
return token

def decode(self, token_ids: "torch.LongTensor") -> List[str]:
    """Decode a sequence of token ids back into text.

    Wraps the ids in a batch dimension and delegates to the underlying
    ExLlamaV2 tokenizer's `decode`, keeping special tokens out of the
    output (`decode_special_tokens=False`).

    Args:
        token_ids: 1-D sequence of token ids for a single generation.

    Returns:
        A list of decoded strings (one entry per batch row; a single
        row here since the input is wrapped in one batch dimension).
    """
    return self.exl2_tokenizer.decode(
        torch.tensor([token_ids]), decode_special_tokens=False
    )


class ExLlamaV2Model:
Expand Down Expand Up @@ -219,16 +222,21 @@ def stream(
next_text = [""] * batch_size

def token_generator() -> Iterator[str]:
    """Yield reformatted streaming output until every queued job finishes.

    Iterates the ExLlamaV2 dynamic generator, updating `next_text` for
    each job's streamed text (cleared again once that job hits EOS), and
    yields the reformatted batch after every iteration step.

    If anything fails mid-stream, cancel all still-pending jobs so the
    generator's job queue is not left in a stale state for subsequent
    generations, then re-raise the original error.
    """
    try:
        while self.generator.num_remaining_jobs():
            results = self.generator.iterate()
            for r in results:
                idx = order[r["serial"]]
                if r["stage"] == "streaming":
                    text = r.get("text", "")
                    next_text[idx] = text
                if r["eos"]:
                    next_text[idx] = ""
            yield self.reformat_output(next_text, sampling_parameters)
    except Exception:
        # Clean up the job queue on failure so the generator stays usable.
        for job in self.generator.pending_jobs:
            job.cancel()
        # Bare raise preserves the original traceback (unlike `raise e`).
        raise
    return

return token_generator()
Expand Down

0 comments on commit 58b8d91

Please sign in to comment.