feat(transformers): various enhancements to the transformers backend #2468

Merged: 1 commit, Jun 3, 2024
59 changes: 36 additions & 23 deletions backend/python/transformers/backend.py
File mode changed: 100755 → 100644
@@ -22,9 +22,9 @@

XPU=os.environ.get("XPU", "0") == "1"
if XPU:
-    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
else:
-    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer
+    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria


_ONE_DAY_IN_SECONDS = 60 * 60 * 24
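For context, the newly imported StoppingCriteriaList and StopStringCriteria are the Hugging Face generation utilities this PR uses to honor stop prompts. Below is a minimal, self-contained sketch of how they are typically wired into generate(); the model id, prompt, and stop strings are illustrative assumptions, not code from this PR:

# Sketch only: stop generation as soon as one of the given strings appears.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteriaList,
    StopStringCriteria,
)

model_id = "gpt2"  # assumption: any causal LM behaves the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt")
criteria = StoppingCriteriaList(
    [StopStringCriteria(tokenizer=tokenizer, stop_strings=[".", "\n"])]
)
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    stopping_criteria=criteria,  # halts once a stop string is generated
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))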
@@ -246,28 +246,28 @@ def Embedding(self, request, context):

# Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
-# print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
-# print("Embeddings:", sentence_embeddings, file=sys.stderr)
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
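The mean_pooling helper called above is defined outside this hunk; for readers skimming the diff, the conventional implementation is an attention-mask-weighted average over token embeddings, roughly as sketched here (an assumption about the usual pattern, not code from this PR):

# Sketch of a conventional mean-pooling implementation (illustrative assumption).
import torch

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # last hidden state: (batch, tokens, dim)
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts  # one vector per sentence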

async def _predict(self, request, context, streaming=False):
set_seed(request.Seed)
-if request.TopP == 0:
-    request.TopP = 0.9
+if request.TopP < 0 or request.TopP > 1:
+    request.TopP = 1

-if request.TopK == 0:
-    request.TopK = 40
+if request.TopK <= 0:
+    request.TopK = 50

+if request.Temperature > 0:
+    sample=True
+else:
+    sample=False
+    request.TopP == None
+    request.TopK == None
+    request.Temperature == None

prompt = request.Prompt
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
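For context, apply_chat_template (used above when UseTokenizerTemplate is set) renders a list of chat messages into the model's own prompt format. A hedged standalone sketch; the model id and messages are illustrative assumptions, and any tokenizer that ships a chat template works the same way:

from transformers import AutoTokenizer

# Assumption: this checkpoint ships a chat template; substitute any chat model.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return a string instead of token ids
    add_generation_prompt=True,  # append the assistant-turn marker
)
print(prompt)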

-eos_token_id = self.tokenizer.eos_token_id
-if request.StopPrompts:
-    eos_token_id = []
-    for word in request.StopPrompts:
-        eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))

inputs = self.tokenizer(prompt, return_tensors="pt")

if request.Tokens > 0:
@@ -281,6 +281,14 @@ async def _predict(self, request, context, streaming=False):
inputs = inputs.to("xpu")
streaming = False

+criteria=[]
+if request.StopPrompts:
+    criteria = StoppingCriteriaList(
+        [
+            StopStringCriteria(tokenizer=self.tokenizer, stop_strings=request.StopPrompts),
+        ]
+    )

if streaming:
streamer=TextIteratorStreamer(self.tokenizer,
skip_prompt=True,
@@ -290,11 +298,14 @@
temperature=request.Temperature,
top_p=request.TopP,
top_k=request.TopK,
-do_sample=True,
+do_sample=sample,
attention_mask=inputs["attention_mask"],
-eos_token_id=eos_token_id,
+eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id,
-streamer=streamer)
+streamer=streamer,
+stopping_criteria=criteria,
+use_cache=True,
+)
thread=Thread(target=self.model.generate, kwargs=config)
thread.start()
generated_text = ""
@@ -311,18 +322,20 @@
temperature=request.Temperature,
top_p=request.TopP,
top_k=request.TopK,
-do_sample=True,
+do_sample=sample,
pad_token=self.tokenizer.eos_token_id)
else:
-outputs = self.model.generate(inputs["input_ids"],
+outputs = self.model.generate(**inputs,
max_new_tokens=max_tokens,
temperature=request.Temperature,
top_p=request.TopP,
top_k=request.TopK,
-do_sample=True,
-attention_mask=inputs["attention_mask"],
-eos_token_id=eos_token_id,
-pad_token_id=self.tokenizer.eos_token_id)
+do_sample=sample,
+eos_token_id=self.tokenizer.eos_token_id,
+pad_token_id=self.tokenizer.eos_token_id,
+stopping_criteria=criteria,
+use_cache=True,
+)
generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]

if streaming:
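For context, the streaming branch above follows the standard TextIteratorStreamer pattern: generate() runs in a background thread while the streamer yields decoded text as it is produced. A self-contained sketch of that pattern; the model id and prompt are illustrative assumptions, not code from this PR:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any causal LM
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Once upon a time", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
    **inputs,
    max_new_tokens=32,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    streamer=streamer,
)

thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
for chunk in streamer:  # chunks arrive while generate() is still running
    print(chunk, end="", flush=True)
thread.join()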