gradio_code_demo.py
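"""Gradio chat demo for the OpenBA/OpenBA-Code seq2seq model.

Loads the checkpoint in half precision on a single GPU and serves a
Chatbot interface with top-p and temperature sliders.
"""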
import re

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the OpenBA-Code checkpoint in half precision on the GPU.
tokenizer = AutoTokenizer.from_pretrained("OpenBA/OpenBA-Code", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("OpenBA/OpenBA-Code", trust_remote_code=True).half().cuda()
model.eval()
# Helper: replace every case-insensitive occurrence of from_str with to_str.
def case_insensitive_replace(input_str, from_str, to_str):
    pattern = re.compile(re.escape(from_str), re.IGNORECASE)
    return pattern.sub(to_str, input_str)
# Build the model input from the chat state; the history is currently ignored.
def history2input(chat_history, message):
    return message
def gpu_respond(message, top_p, temp, chat_history):
    """Generate a reply and append the (user, bot) turn to the chat history."""
    input_text = history2input(chat_history, message)
    print("input:", input_text)
    bot_message = generate(input_text, top_p, temp)
    print("message:", bot_message)
    print('-' * 30)
    chat_history.append((message, bot_message))
    # Return an empty string to clear the textbox, plus the updated history.
    return "", chat_history
def generate(input_text, top_p=0.7, temp=0.95):
    """Sample a completion from OpenBA-Code for a single prompt."""
    # OpenBA expects a "<S> ... <extra_id_0>" span-infilling prompt format.
    inputs = tokenizer("<S> " + input_text + " <extra_id_0>", return_tensors='pt')
    for k in inputs:
        inputs[k] = inputs[k].cuda()
    outputs = model.generate(
        **inputs,
        do_sample=True,
        max_new_tokens=1024,
        temperature=temp,
        top_p=top_p,
    )
    # Skip the decoder start token when decoding.
    response = tokenizer.decode(outputs[0][1:], spaces_between_special_tokens=False) + '\n'
    return response
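# Example usage (a sketch; actual completions vary with sampling):
#   generate("Write a quicksort function in Python")  # -> generated code as a string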
if __name__ == "__main__":
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.ClearButton([msg, chatbot])
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top P")
        temp = gr.Slider(minimum=0.01, maximum=1.0, value=0.95, label="Temperature")
        # Submitting the textbox clears it and appends the new turn to the chatbot.
        msg.submit(gpu_respond, [msg, top_p, temp, chatbot], [msg, chatbot])
    # concurrency_count is the Gradio 3.x queue API.
    demo.queue(concurrency_count=3)
    demo.launch(share=True)
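# Run with: python gradio_code_demo.py
# share=True additionally serves the demo through a temporary public Gradio link.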