from queue import Queue, Empty
from threading import Thread
import time
import json

import torch
from flask import Flask, request, jsonify, render_template
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast

app = Flask(__name__)

print("model loading...")

# Model & tokenizer loading
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    './KoGPT2',
    bos_token='</s>', eos_token='</s>', unk_token='<unk>',
    pad_token='<pad>', mask_token='<mask>')
model = GPT2LMHeadModel.from_pretrained('skt/kogpt2-base-v2')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

requests_queue = Queue()  # incoming request queue
BATCH_SIZE = 1            # max number of requests handled per batch
CHECK_INTERVAL = 0.1      # seconds between queue polls

print("model loading complete")

def handle_requests_by_batch():
    """Background worker: drain the queue in batches and attach outputs."""
    while True:
        request_batch = []
        while len(request_batch) < BATCH_SIZE:
            try:
                request_batch.append(requests_queue.get(timeout=CHECK_INTERVAL))
            except Empty:
                continue

        for req in request_batch:
            try:
                req['output'] = make_text(req['input'][0], req['input'][1])
            except Exception as e:
                # Store a JSON-serializable error, not the exception object.
                req['output'] = {'Error': str(e)}


# Run the worker as a daemon thread so it does not block interpreter exit.
handler = Thread(target=handle_requests_by_batch, daemon=True)
handler.start()

def make_text(text, length):
    """Generate a continuation of `text`, roughly `length` tokens long."""
    try:
        input_ids = tokenizer.encode(text, return_tensors='pt')
        input_ids = input_ids.to(device)

        # Total generation budget: prompt tokens + requested continuation.
        min_length = len(input_ids.tolist()[0])
        length = length if length > 0 else 1
        length += min_length

        gen_ids = model.generate(input_ids,
                                 max_length=length,
                                 do_sample=True,
                                 pad_token_id=tokenizer.pad_token_id,
                                 eos_token_id=tokenizer.eos_token_id,
                                 bos_token_id=tokenizer.bos_token_id,
                                 top_p=0.95,
                                 top_k=50)

        result = dict()
        for idx, sample_output in enumerate(gen_ids):
            result[idx] = tokenizer.decode(sample_output.tolist(),
                                           skip_special_tokens=True)

        return result
    except Exception as e:
        # Re-raise so the worker thread reports the error; jsonify() cannot
        # be used here because this runs outside any request context.
        print('Error occurred while generating text!', e)
        raise

@app.route('/generate', methods=['POST'])
def generate():
    # Shed load early if the queue is already backed up.
    if requests_queue.qsize() > BATCH_SIZE:
        return jsonify({'Error': 'Too Many Requests. Please try again later'}), 429

    try:
        text = request.form['text']
        length = int(request.form['length'])
        args = [text, length]
    except Exception:
        return jsonify({'Error': 'Invalid request'}), 400

    # Enqueue the request and poll until the worker attaches an output.
    req = {'input': args}
    requests_queue.put(req)

    while 'output' not in req:
        time.sleep(CHECK_INTERVAL)

    return json.dumps(req['output'], ensure_ascii=False)
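
# Example client call (a minimal sketch, assuming the default host/port set
# in app.run() at the bottom of this file; the Korean prompt is only an
# illustrative placeholder):
#
#   import requests
#   resp = requests.post('http://localhost:5000/generate',
#                        data={'text': '근처에 맛집이', 'length': 50})
#   print(resp.text)
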
@app.route('/healthz', methods=['GET'])
def health_check():
    return 'Health', 200


@app.route('/')
def main():
    return render_template('main.html'), 200


if __name__ == '__main__':
    app.run(port=5000, host='0.0.0.0')